diff --git a/server.go b/server.go index 50eb632..56dec27 100644 --- a/server.go +++ b/server.go @@ -37,7 +37,7 @@ type Server struct { IdleTimeout time.Duration // connection timeout when no activity, none if empty MaxTimeout time.Duration // absolute connection timeout, none if empty - channelHandlers map[string]channelHandler + channelHandlers map[string]ChannelHandler // fallback channel handlers requestHandlers map[string]RequestHandler listenerWg sync.WaitGroup @@ -51,8 +51,7 @@ type RequestHandler interface { HandleRequest(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte) } -// internal for now -type channelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) +type ChannelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) func (srv *Server) ensureHostSigner() error { if len(srv.HostSigners) == 0 { @@ -68,13 +67,17 @@ func (srv *Server) ensureHostSigner() error { func (srv *Server) ensureHandlers() { srv.mu.Lock() defer srv.mu.Unlock() - srv.requestHandlers = map[string]RequestHandler{ - "tcpip-forward": forwardedTCPHandler{}, - "cancel-tcpip-forward": forwardedTCPHandler{}, + if srv.requestHandlers == nil { + srv.requestHandlers = map[string]RequestHandler{ + "tcpip-forward": forwardedTCPHandler{}, + "cancel-tcpip-forward": forwardedTCPHandler{}, + } } - srv.channelHandlers = map[string]channelHandler{ - "session": sessionHandler, - "direct-tcpip": directTcpipHandler, + if srv.channelHandlers == nil { + srv.channelHandlers = map[string]ChannelHandler{ + "session": sessionHandler, + "direct-tcpip": directTcpipHandler, + } } } @@ -186,12 +189,6 @@ func (srv *Server) Serve(l net.Listener) error { if srv.Handler == nil { srv.Handler = DefaultHandler } - if srv.channelHandlers == nil { - srv.channelHandlers = map[string]channelHandler{ - "session": sessionHandler, - "direct-tcpip": directTcpipHandler, - } - } var tempDelay time.Duration srv.trackListener(l, true) @@ -222,6 +219,18 @@ func (srv *Server) Serve(l net.Listener) error { } } +func (srv *Server) SetChannelHandler(kind string, handler ChannelHandler) { + srv.ensureHandlers() + srv.mu.Lock() + defer srv.mu.Unlock() + srv.channelHandlers[kind] = handler +} + +func (srv *Server) ChannelHandler(kind string) ChannelHandler { + srv.ensureHandlers() + return srv.channelHandlers[kind] +} + func (srv *Server) handleConn(newConn net.Conn) { if srv.ConnCallback != nil { cbConn := srv.ConnCallback(newConn) @@ -256,7 +265,12 @@ func (srv *Server) handleConn(newConn net.Conn) { go srv.handleRequests(ctx, reqs) for ch := range chans { handler, found := srv.channelHandlers[ch.ChannelType()] - if !found { + if !found || handler == nil { + if defaultHandler, found := srv.channelHandlers["default"]; found { + handler = defaultHandler + } + } + if handler == nil { ch.Reject(gossh.UnknownChannelType, "unsupported channel type") continue } diff --git a/session_test.go b/session_test.go index 06c7724..82e0d2e 100644 --- a/session_test.go +++ b/session_test.go @@ -19,7 +19,7 @@ func (srv *Server) serveOnce(l net.Listener) error { if e != nil { return e } - srv.channelHandlers = map[string]channelHandler{ + srv.channelHandlers = map[string]ChannelHandler{ "session": sessionHandler, "direct-tcpip": directTcpipHandler, }