From c9e327ebeb94a9d63f9fd8da80c2283c5a4c110e Mon Sep 17 00:00:00 2001 From: Kaleb Elwert Date: Wed, 12 Jun 2019 10:41:07 -0700 Subject: [PATCH] Remove Handler getters and setters --- server.go | 46 +++++++++++++++++++++------------------------- session_test.go | 2 +- 2 files changed, 22 insertions(+), 26 deletions(-) diff --git a/server.go b/server.go index 5119bec..934c230 100644 --- a/server.go +++ b/server.go @@ -37,8 +37,15 @@ type Server struct { IdleTimeout time.Duration // connection timeout when no activity, none if empty MaxTimeout time.Duration // absolute connection timeout, none if empty - channelHandlers map[string]ChannelHandler // fallback channel handlers - requestHandlers map[string]RequestHandler + // ChannelHandlers allow overriding the built-in session handlers or provide + // extensions to the protocol, such as tcpip forwarding. By default only the + // "session" handler is enabled. + ChannelHandlers map[string]ChannelHandler + + // RequestHandlers allow overriding the server-level request handlers or + // provide extensions to the protocol, such as tcpip forwarding. By default + // no handlers are enabled. + RequestHandlers map[string]RequestHandler listenerWg sync.WaitGroup mu sync.Mutex @@ -82,14 +89,14 @@ func (srv *Server) ensureHostSigner() error { func (srv *Server) ensureHandlers() { srv.mu.Lock() defer srv.mu.Unlock() - if srv.requestHandlers == nil { - srv.requestHandlers = map[string]RequestHandler{ + if srv.RequestHandlers == nil { + srv.RequestHandlers = map[string]RequestHandler{ "tcpip-forward": forwardedTCPHandler{}, "cancel-tcpip-forward": forwardedTCPHandler{}, } } - if srv.channelHandlers == nil { - srv.channelHandlers = map[string]ChannelHandler{ + if srv.ChannelHandlers == nil { + srv.ChannelHandlers = map[string]ChannelHandler{ "session": ChannelHandlerFunc(sessionHandler), "direct-tcpip": ChannelHandlerFunc(directTcpipHandler), } @@ -234,18 +241,6 @@ func (srv *Server) Serve(l net.Listener) error { } } -func (srv *Server) SetChannelHandler(kind string, handler ChannelHandler) { - srv.ensureHandlers() - srv.mu.Lock() - defer srv.mu.Unlock() - srv.channelHandlers[kind] = handler -} - -func (srv *Server) ChannelHandler(kind string) ChannelHandler { - srv.ensureHandlers() - return srv.channelHandlers[kind] -} - func (srv *Server) handleConn(newConn net.Conn) { if srv.ConnCallback != nil { cbConn := srv.ConnCallback(newConn) @@ -279,11 +274,9 @@ func (srv *Server) handleConn(newConn net.Conn) { //go gossh.DiscardRequests(reqs) go srv.handleRequests(ctx, reqs) for ch := range chans { - handler, found := srv.channelHandlers[ch.ChannelType()] - if !found || handler == nil { - if defaultHandler, found := srv.channelHandlers["default"]; found { - handler = defaultHandler - } + handler := srv.ChannelHandlers[ch.ChannelType()] + if handler == nil { + handler = srv.ChannelHandlers["default"] } if handler == nil { ch.Reject(gossh.UnknownChannelType, "unsupported channel type") @@ -295,8 +288,11 @@ func (srv *Server) handleConn(newConn net.Conn) { func (srv *Server) handleRequests(ctx Context, in <-chan *gossh.Request) { for req := range in { - handler, found := srv.requestHandlers[req.Type] - if !found { + handler := srv.RequestHandlers[req.Type] + if handler == nil { + handler = srv.RequestHandlers["default"] + } + if handler == nil { if req.WantReply { req.Reply(false, nil) } diff --git a/session_test.go b/session_test.go index 90034a7..4396c5a 100644 --- a/session_test.go +++ b/session_test.go @@ -19,7 +19,7 @@ func (srv *Server) serveOnce(l net.Listener) error { if e != nil { return e } - srv.channelHandlers = map[string]ChannelHandler{ + srv.ChannelHandlers = map[string]ChannelHandler{ "session": ChannelHandlerFunc(sessionHandler), "direct-tcpip": ChannelHandlerFunc(directTcpipHandler), }