feat: configurable server handlers
This commit is contained in:
parent
cbabf54144
commit
8b3cdd49b6
46
server.go
46
server.go
@ -34,7 +34,7 @@ type Server struct {
|
|||||||
IdleTimeout time.Duration // connection timeout when no activity, none if empty
|
IdleTimeout time.Duration // connection timeout when no activity, none if empty
|
||||||
MaxTimeout time.Duration // absolute connection timeout, none if empty
|
MaxTimeout time.Duration // absolute connection timeout, none if empty
|
||||||
|
|
||||||
channelHandlers map[string]channelHandler
|
channelHandlers map[string]ChannelHandler // fallback channel handlers
|
||||||
requestHandlers map[string]RequestHandler
|
requestHandlers map[string]RequestHandler
|
||||||
|
|
||||||
listenerWg sync.WaitGroup
|
listenerWg sync.WaitGroup
|
||||||
@ -48,8 +48,7 @@ type RequestHandler interface {
|
|||||||
HandleRequest(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte)
|
HandleRequest(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte)
|
||||||
}
|
}
|
||||||
|
|
||||||
// internal for now
|
type ChannelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context)
|
||||||
type channelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context)
|
|
||||||
|
|
||||||
func (srv *Server) ensureHostSigner() error {
|
func (srv *Server) ensureHostSigner() error {
|
||||||
if len(srv.HostSigners) == 0 {
|
if len(srv.HostSigners) == 0 {
|
||||||
@ -65,13 +64,17 @@ func (srv *Server) ensureHostSigner() error {
|
|||||||
func (srv *Server) ensureHandlers() {
|
func (srv *Server) ensureHandlers() {
|
||||||
srv.mu.Lock()
|
srv.mu.Lock()
|
||||||
defer srv.mu.Unlock()
|
defer srv.mu.Unlock()
|
||||||
srv.requestHandlers = map[string]RequestHandler{
|
if srv.requestHandlers == nil {
|
||||||
"tcpip-forward": forwardedTCPHandler{},
|
srv.requestHandlers = map[string]RequestHandler{
|
||||||
"cancel-tcpip-forward": forwardedTCPHandler{},
|
"tcpip-forward": forwardedTCPHandler{},
|
||||||
|
"cancel-tcpip-forward": forwardedTCPHandler{},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
srv.channelHandlers = map[string]channelHandler{
|
if srv.channelHandlers == nil {
|
||||||
"session": sessionHandler,
|
srv.channelHandlers = map[string]ChannelHandler{
|
||||||
"direct-tcpip": directTcpipHandler,
|
"session": sessionHandler,
|
||||||
|
"direct-tcpip": directTcpipHandler,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -170,12 +173,6 @@ func (srv *Server) Serve(l net.Listener) error {
|
|||||||
if srv.Handler == nil {
|
if srv.Handler == nil {
|
||||||
srv.Handler = DefaultHandler
|
srv.Handler = DefaultHandler
|
||||||
}
|
}
|
||||||
if srv.channelHandlers == nil {
|
|
||||||
srv.channelHandlers = map[string]channelHandler{
|
|
||||||
"session": sessionHandler,
|
|
||||||
"direct-tcpip": directTcpipHandler,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
var tempDelay time.Duration
|
var tempDelay time.Duration
|
||||||
|
|
||||||
srv.trackListener(l, true)
|
srv.trackListener(l, true)
|
||||||
@ -206,6 +203,18 @@ func (srv *Server) Serve(l net.Listener) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (srv *Server) SetChannelHandler(kind string, handler ChannelHandler) {
|
||||||
|
srv.ensureHandlers()
|
||||||
|
srv.mu.Lock()
|
||||||
|
defer srv.mu.Unlock()
|
||||||
|
srv.channelHandlers[kind] = handler
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *Server) GetChannelHandler(kind string) ChannelHandler {
|
||||||
|
srv.ensureHandlers()
|
||||||
|
return srv.channelHandlers[kind]
|
||||||
|
}
|
||||||
|
|
||||||
func (srv *Server) handleConn(newConn net.Conn) {
|
func (srv *Server) handleConn(newConn net.Conn) {
|
||||||
if srv.ConnCallback != nil {
|
if srv.ConnCallback != nil {
|
||||||
cbConn := srv.ConnCallback(newConn)
|
cbConn := srv.ConnCallback(newConn)
|
||||||
@ -240,7 +249,12 @@ func (srv *Server) handleConn(newConn net.Conn) {
|
|||||||
go srv.handleRequests(ctx, reqs)
|
go srv.handleRequests(ctx, reqs)
|
||||||
for ch := range chans {
|
for ch := range chans {
|
||||||
handler, found := srv.channelHandlers[ch.ChannelType()]
|
handler, found := srv.channelHandlers[ch.ChannelType()]
|
||||||
if !found {
|
if !found || handler == nil {
|
||||||
|
if defaultHandler, found := srv.channelHandlers["default"]; found {
|
||||||
|
handler = defaultHandler
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if handler == nil {
|
||||||
ch.Reject(gossh.UnknownChannelType, "unsupported channel type")
|
ch.Reject(gossh.UnknownChannelType, "unsupported channel type")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -19,7 +19,7 @@ func (srv *Server) serveOnce(l net.Listener) error {
|
|||||||
if e != nil {
|
if e != nil {
|
||||||
return e
|
return e
|
||||||
}
|
}
|
||||||
srv.channelHandlers = map[string]channelHandler{
|
srv.channelHandlers = map[string]ChannelHandler{
|
||||||
"session": sessionHandler,
|
"session": sessionHandler,
|
||||||
"direct-tcpip": directTcpipHandler,
|
"direct-tcpip": directTcpipHandler,
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user