From 8b3cdd49b6d2f0c7aa025abc8453ea231848332c Mon Sep 17 00:00:00 2001 From: Manfred Touron Date: Fri, 16 Nov 2018 10:56:12 +0100 Subject: [PATCH 1/2] feat: configurable server handlers --- server.go | 46 ++++++++++++++++++++++++++++++---------------- session_test.go | 2 +- 2 files changed, 31 insertions(+), 17 deletions(-) diff --git a/server.go b/server.go index 31e3353..b990243 100644 --- a/server.go +++ b/server.go @@ -34,7 +34,7 @@ type Server struct { IdleTimeout time.Duration // connection timeout when no activity, none if empty MaxTimeout time.Duration // absolute connection timeout, none if empty - channelHandlers map[string]channelHandler + channelHandlers map[string]ChannelHandler // fallback channel handlers requestHandlers map[string]RequestHandler listenerWg sync.WaitGroup @@ -48,8 +48,7 @@ type RequestHandler interface { HandleRequest(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte) } -// internal for now -type channelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) +type ChannelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) func (srv *Server) ensureHostSigner() error { if len(srv.HostSigners) == 0 { @@ -65,13 +64,17 @@ func (srv *Server) ensureHostSigner() error { func (srv *Server) ensureHandlers() { srv.mu.Lock() defer srv.mu.Unlock() - srv.requestHandlers = map[string]RequestHandler{ - "tcpip-forward": forwardedTCPHandler{}, - "cancel-tcpip-forward": forwardedTCPHandler{}, + if srv.requestHandlers == nil { + srv.requestHandlers = map[string]RequestHandler{ + "tcpip-forward": forwardedTCPHandler{}, + "cancel-tcpip-forward": forwardedTCPHandler{}, + } } - srv.channelHandlers = map[string]channelHandler{ - "session": sessionHandler, - "direct-tcpip": directTcpipHandler, + if srv.channelHandlers == nil { + srv.channelHandlers = map[string]ChannelHandler{ + "session": sessionHandler, + "direct-tcpip": directTcpipHandler, + } } } @@ -170,12 +173,6 @@ func (srv *Server) Serve(l net.Listener) error { if srv.Handler == nil { srv.Handler = DefaultHandler } - if srv.channelHandlers == nil { - srv.channelHandlers = map[string]channelHandler{ - "session": sessionHandler, - "direct-tcpip": directTcpipHandler, - } - } var tempDelay time.Duration srv.trackListener(l, true) @@ -206,6 +203,18 @@ func (srv *Server) Serve(l net.Listener) error { } } +func (srv *Server) SetChannelHandler(kind string, handler ChannelHandler) { + srv.ensureHandlers() + srv.mu.Lock() + defer srv.mu.Unlock() + srv.channelHandlers[kind] = handler +} + +func (srv *Server) GetChannelHandler(kind string) ChannelHandler { + srv.ensureHandlers() + return srv.channelHandlers[kind] +} + func (srv *Server) handleConn(newConn net.Conn) { if srv.ConnCallback != nil { cbConn := srv.ConnCallback(newConn) @@ -240,7 +249,12 @@ func (srv *Server) handleConn(newConn net.Conn) { go srv.handleRequests(ctx, reqs) for ch := range chans { handler, found := srv.channelHandlers[ch.ChannelType()] - if !found { + if !found || handler == nil { + if defaultHandler, found := srv.channelHandlers["default"]; found { + handler = defaultHandler + } + } + if handler == nil { ch.Reject(gossh.UnknownChannelType, "unsupported channel type") continue } diff --git a/session_test.go b/session_test.go index 06c7724..82e0d2e 100644 --- a/session_test.go +++ b/session_test.go @@ -19,7 +19,7 @@ func (srv *Server) serveOnce(l net.Listener) error { if e != nil { return e } - srv.channelHandlers = map[string]channelHandler{ + srv.channelHandlers = map[string]ChannelHandler{ "session": sessionHandler, "direct-tcpip": directTcpipHandler, } From 570aa23f40f362f3cb19c14cfbaf0e60112feea0 Mon Sep 17 00:00:00 2001 From: Jose Diaz-Gonzalez Date: Sun, 23 Dec 2018 18:05:39 -0500 Subject: [PATCH 2/2] fix: use idiomatic go --- server.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server.go b/server.go index b990243..a3f5e41 100644 --- a/server.go +++ b/server.go @@ -210,7 +210,7 @@ func (srv *Server) SetChannelHandler(kind string, handler ChannelHandler) { srv.channelHandlers[kind] = handler } -func (srv *Server) GetChannelHandler(kind string) ChannelHandler { +func (srv *Server) ChannelHandler(kind string) ChannelHandler { srv.ensureHandlers() return srv.channelHandlers[kind] }