From 8b3cdd49b6d2f0c7aa025abc8453ea231848332c Mon Sep 17 00:00:00 2001 From: Manfred Touron Date: Fri, 16 Nov 2018 10:56:12 +0100 Subject: [PATCH 1/6] feat: configurable server handlers --- server.go | 46 ++++++++++++++++++++++++++++++---------------- session_test.go | 2 +- 2 files changed, 31 insertions(+), 17 deletions(-) diff --git a/server.go b/server.go index 31e3353..b990243 100644 --- a/server.go +++ b/server.go @@ -34,7 +34,7 @@ type Server struct { IdleTimeout time.Duration // connection timeout when no activity, none if empty MaxTimeout time.Duration // absolute connection timeout, none if empty - channelHandlers map[string]channelHandler + channelHandlers map[string]ChannelHandler // fallback channel handlers requestHandlers map[string]RequestHandler listenerWg sync.WaitGroup @@ -48,8 +48,7 @@ type RequestHandler interface { HandleRequest(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte) } -// internal for now -type channelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) +type ChannelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) func (srv *Server) ensureHostSigner() error { if len(srv.HostSigners) == 0 { @@ -65,13 +64,17 @@ func (srv *Server) ensureHostSigner() error { func (srv *Server) ensureHandlers() { srv.mu.Lock() defer srv.mu.Unlock() - srv.requestHandlers = map[string]RequestHandler{ - "tcpip-forward": forwardedTCPHandler{}, - "cancel-tcpip-forward": forwardedTCPHandler{}, + if srv.requestHandlers == nil { + srv.requestHandlers = map[string]RequestHandler{ + "tcpip-forward": forwardedTCPHandler{}, + "cancel-tcpip-forward": forwardedTCPHandler{}, + } } - srv.channelHandlers = map[string]channelHandler{ - "session": sessionHandler, - "direct-tcpip": directTcpipHandler, + if srv.channelHandlers == nil { + srv.channelHandlers = map[string]ChannelHandler{ + "session": sessionHandler, + "direct-tcpip": directTcpipHandler, + } } } @@ -170,12 +173,6 @@ func (srv *Server) Serve(l net.Listener) error { if srv.Handler == nil { srv.Handler = DefaultHandler } - if srv.channelHandlers == nil { - srv.channelHandlers = map[string]channelHandler{ - "session": sessionHandler, - "direct-tcpip": directTcpipHandler, - } - } var tempDelay time.Duration srv.trackListener(l, true) @@ -206,6 +203,18 @@ func (srv *Server) Serve(l net.Listener) error { } } +func (srv *Server) SetChannelHandler(kind string, handler ChannelHandler) { + srv.ensureHandlers() + srv.mu.Lock() + defer srv.mu.Unlock() + srv.channelHandlers[kind] = handler +} + +func (srv *Server) GetChannelHandler(kind string) ChannelHandler { + srv.ensureHandlers() + return srv.channelHandlers[kind] +} + func (srv *Server) handleConn(newConn net.Conn) { if srv.ConnCallback != nil { cbConn := srv.ConnCallback(newConn) @@ -240,7 +249,12 @@ func (srv *Server) handleConn(newConn net.Conn) { go srv.handleRequests(ctx, reqs) for ch := range chans { handler, found := srv.channelHandlers[ch.ChannelType()] - if !found { + if !found || handler == nil { + if defaultHandler, found := srv.channelHandlers["default"]; found { + handler = defaultHandler + } + } + if handler == nil { ch.Reject(gossh.UnknownChannelType, "unsupported channel type") continue } diff --git a/session_test.go b/session_test.go index 06c7724..82e0d2e 100644 --- a/session_test.go +++ b/session_test.go @@ -19,7 +19,7 @@ func (srv *Server) serveOnce(l net.Listener) error { if e != nil { return e } - srv.channelHandlers = map[string]channelHandler{ + srv.channelHandlers = map[string]ChannelHandler{ "session": sessionHandler, "direct-tcpip": directTcpipHandler, } From 570aa23f40f362f3cb19c14cfbaf0e60112feea0 Mon Sep 17 00:00:00 2001 From: Jose Diaz-Gonzalez Date: Sun, 23 Dec 2018 18:05:39 -0500 Subject: [PATCH 2/6] fix: use idiomatic go --- server.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server.go b/server.go index b990243..a3f5e41 100644 --- a/server.go +++ b/server.go @@ -210,7 +210,7 @@ func (srv *Server) SetChannelHandler(kind string, handler ChannelHandler) { srv.channelHandlers[kind] = handler } -func (srv *Server) GetChannelHandler(kind string) ChannelHandler { +func (srv *Server) ChannelHandler(kind string) ChannelHandler { srv.ensureHandlers() return srv.channelHandlers[kind] } From 77856273e04e3d7953348fbd71e29628debed3f8 Mon Sep 17 00:00:00 2001 From: Kaleb Elwert Date: Wed, 12 Jun 2019 10:29:36 -0700 Subject: [PATCH 3/6] Clean up Channel and Request Handler interfaces --- server.go | 27 +++++++++++++++++++++------ session_test.go | 4 ++-- tcpip.go | 2 +- 3 files changed, 24 insertions(+), 9 deletions(-) diff --git a/server.go b/server.go index 56dec27..5119bec 100644 --- a/server.go +++ b/server.go @@ -47,11 +47,26 @@ type Server struct { connWg sync.WaitGroup doneChan chan struct{} } + type RequestHandler interface { - HandleRequest(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte) + HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte) } -type ChannelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) +type RequestHandlerFunc func(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte) + +func (f RequestHandlerFunc) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte) { + return f(ctx, srv, req) +} + +type ChannelHandler interface { + HandleSSHChannel(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) +} + +type ChannelHandlerFunc func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) + +func (f ChannelHandlerFunc) HandleSSHChannel(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) { + f(srv, conn, newChan, ctx) +} func (srv *Server) ensureHostSigner() error { if len(srv.HostSigners) == 0 { @@ -75,8 +90,8 @@ func (srv *Server) ensureHandlers() { } if srv.channelHandlers == nil { srv.channelHandlers = map[string]ChannelHandler{ - "session": sessionHandler, - "direct-tcpip": directTcpipHandler, + "session": ChannelHandlerFunc(sessionHandler), + "direct-tcpip": ChannelHandlerFunc(directTcpipHandler), } } } @@ -274,7 +289,7 @@ func (srv *Server) handleConn(newConn net.Conn) { ch.Reject(gossh.UnknownChannelType, "unsupported channel type") continue } - go handler(srv, sshConn, ch, ctx) + go handler.HandleSSHChannel(srv, sshConn, ch, ctx) } } @@ -289,7 +304,7 @@ func (srv *Server) handleRequests(ctx Context, in <-chan *gossh.Request) { } /*reqCtx, cancel := context.WithCancel(ctx) defer cancel() */ - ret, payload := handler.HandleRequest(ctx, srv, req) + ret, payload := handler.HandleSSHRequest(ctx, srv, req) if req.WantReply { req.Reply(ret, payload) } diff --git a/session_test.go b/session_test.go index 82e0d2e..90034a7 100644 --- a/session_test.go +++ b/session_test.go @@ -20,8 +20,8 @@ func (srv *Server) serveOnce(l net.Listener) error { return e } srv.channelHandlers = map[string]ChannelHandler{ - "session": sessionHandler, - "direct-tcpip": directTcpipHandler, + "session": ChannelHandlerFunc(sessionHandler), + "direct-tcpip": ChannelHandlerFunc(directTcpipHandler), } srv.handleConn(conn) return nil diff --git a/tcpip.go b/tcpip.go index 64542ff..4afbf2d 100644 --- a/tcpip.go +++ b/tcpip.go @@ -89,7 +89,7 @@ type forwardedTCPHandler struct { sync.Mutex } -func (h forwardedTCPHandler) HandleRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) { +func (h forwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) { h.Lock() if h.forwards == nil { h.forwards = make(map[string]net.Listener) From c9e327ebeb94a9d63f9fd8da80c2283c5a4c110e Mon Sep 17 00:00:00 2001 From: Kaleb Elwert Date: Wed, 12 Jun 2019 10:41:07 -0700 Subject: [PATCH 4/6] Remove Handler getters and setters --- server.go | 46 +++++++++++++++++++++------------------------- session_test.go | 2 +- 2 files changed, 22 insertions(+), 26 deletions(-) diff --git a/server.go b/server.go index 5119bec..934c230 100644 --- a/server.go +++ b/server.go @@ -37,8 +37,15 @@ type Server struct { IdleTimeout time.Duration // connection timeout when no activity, none if empty MaxTimeout time.Duration // absolute connection timeout, none if empty - channelHandlers map[string]ChannelHandler // fallback channel handlers - requestHandlers map[string]RequestHandler + // ChannelHandlers allow overriding the built-in session handlers or provide + // extensions to the protocol, such as tcpip forwarding. By default only the + // "session" handler is enabled. + ChannelHandlers map[string]ChannelHandler + + // RequestHandlers allow overriding the server-level request handlers or + // provide extensions to the protocol, such as tcpip forwarding. By default + // no handlers are enabled. + RequestHandlers map[string]RequestHandler listenerWg sync.WaitGroup mu sync.Mutex @@ -82,14 +89,14 @@ func (srv *Server) ensureHostSigner() error { func (srv *Server) ensureHandlers() { srv.mu.Lock() defer srv.mu.Unlock() - if srv.requestHandlers == nil { - srv.requestHandlers = map[string]RequestHandler{ + if srv.RequestHandlers == nil { + srv.RequestHandlers = map[string]RequestHandler{ "tcpip-forward": forwardedTCPHandler{}, "cancel-tcpip-forward": forwardedTCPHandler{}, } } - if srv.channelHandlers == nil { - srv.channelHandlers = map[string]ChannelHandler{ + if srv.ChannelHandlers == nil { + srv.ChannelHandlers = map[string]ChannelHandler{ "session": ChannelHandlerFunc(sessionHandler), "direct-tcpip": ChannelHandlerFunc(directTcpipHandler), } @@ -234,18 +241,6 @@ func (srv *Server) Serve(l net.Listener) error { } } -func (srv *Server) SetChannelHandler(kind string, handler ChannelHandler) { - srv.ensureHandlers() - srv.mu.Lock() - defer srv.mu.Unlock() - srv.channelHandlers[kind] = handler -} - -func (srv *Server) ChannelHandler(kind string) ChannelHandler { - srv.ensureHandlers() - return srv.channelHandlers[kind] -} - func (srv *Server) handleConn(newConn net.Conn) { if srv.ConnCallback != nil { cbConn := srv.ConnCallback(newConn) @@ -279,11 +274,9 @@ func (srv *Server) handleConn(newConn net.Conn) { //go gossh.DiscardRequests(reqs) go srv.handleRequests(ctx, reqs) for ch := range chans { - handler, found := srv.channelHandlers[ch.ChannelType()] - if !found || handler == nil { - if defaultHandler, found := srv.channelHandlers["default"]; found { - handler = defaultHandler - } + handler := srv.ChannelHandlers[ch.ChannelType()] + if handler == nil { + handler = srv.ChannelHandlers["default"] } if handler == nil { ch.Reject(gossh.UnknownChannelType, "unsupported channel type") @@ -295,8 +288,11 @@ func (srv *Server) handleConn(newConn net.Conn) { func (srv *Server) handleRequests(ctx Context, in <-chan *gossh.Request) { for req := range in { - handler, found := srv.requestHandlers[req.Type] - if !found { + handler := srv.RequestHandlers[req.Type] + if handler == nil { + handler = srv.RequestHandlers["default"] + } + if handler == nil { if req.WantReply { req.Reply(false, nil) } diff --git a/session_test.go b/session_test.go index 90034a7..4396c5a 100644 --- a/session_test.go +++ b/session_test.go @@ -19,7 +19,7 @@ func (srv *Server) serveOnce(l net.Listener) error { if e != nil { return e } - srv.channelHandlers = map[string]ChannelHandler{ + srv.ChannelHandlers = map[string]ChannelHandler{ "session": ChannelHandlerFunc(sessionHandler), "direct-tcpip": ChannelHandlerFunc(directTcpipHandler), } From 465d1bd2c79d99098fd76071ca12f1e492237558 Mon Sep 17 00:00:00 2001 From: Kaleb Elwert Date: Wed, 12 Jun 2019 10:42:05 -0700 Subject: [PATCH 5/6] Clean up Request replies --- server.go | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/server.go b/server.go index 934c230..41ba87e 100644 --- a/server.go +++ b/server.go @@ -293,17 +293,13 @@ func (srv *Server) handleRequests(ctx Context, in <-chan *gossh.Request) { handler = srv.RequestHandlers["default"] } if handler == nil { - if req.WantReply { - req.Reply(false, nil) - } + req.Reply(false, nil) continue } /*reqCtx, cancel := context.WithCancel(ctx) defer cancel() */ ret, payload := handler.HandleSSHRequest(ctx, srv, req) - if req.WantReply { - req.Reply(ret, payload) - } + req.Reply(ret, payload) } } From dd61f8b0d5629a711ef3582d7a5c3775714ef74d Mon Sep 17 00:00:00 2001 From: Kaleb Elwert Date: Wed, 12 Jun 2019 10:52:26 -0700 Subject: [PATCH 6/6] Disable port forwarding by default Fixes #68 --- server.go | 18 ++++++++++++------ session.go | 2 +- session_test.go | 4 ++-- tcpip.go | 11 ++++++++--- 4 files changed, 23 insertions(+), 12 deletions(-) diff --git a/server.go b/server.go index 41ba87e..dc61fb4 100644 --- a/server.go +++ b/server.go @@ -65,6 +65,8 @@ func (f RequestHandlerFunc) HandleSSHRequest(ctx Context, srv *Server, req *goss return f(ctx, srv, req) } +var DefaultRequestHandlers = map[string]RequestHandler{} + type ChannelHandler interface { HandleSSHChannel(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) } @@ -75,6 +77,10 @@ func (f ChannelHandlerFunc) HandleSSHChannel(srv *Server, conn *gossh.ServerConn f(srv, conn, newChan, ctx) } +var DefaultChannelHandlers = map[string]ChannelHandler{ + "session": ChannelHandlerFunc(DefaultSessionHandler), +} + func (srv *Server) ensureHostSigner() error { if len(srv.HostSigners) == 0 { signer, err := generateSigner() @@ -90,15 +96,15 @@ func (srv *Server) ensureHandlers() { srv.mu.Lock() defer srv.mu.Unlock() if srv.RequestHandlers == nil { - srv.RequestHandlers = map[string]RequestHandler{ - "tcpip-forward": forwardedTCPHandler{}, - "cancel-tcpip-forward": forwardedTCPHandler{}, + srv.RequestHandlers = map[string]RequestHandler{} + for k, v := range DefaultRequestHandlers { + srv.RequestHandlers[k] = v } } if srv.ChannelHandlers == nil { - srv.ChannelHandlers = map[string]ChannelHandler{ - "session": ChannelHandlerFunc(sessionHandler), - "direct-tcpip": ChannelHandlerFunc(directTcpipHandler), + srv.ChannelHandlers = map[string]ChannelHandler{} + for k, v := range DefaultChannelHandlers { + srv.ChannelHandlers[k] = v } } } diff --git a/session.go b/session.go index 19ddda6..a6085f3 100644 --- a/session.go +++ b/session.go @@ -77,7 +77,7 @@ type Session interface { // when there is no signal channel specified const maxSigBufSize = 128 -func sessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) { +func DefaultSessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) { ch, reqs, err := newChan.Accept() if err != nil { // TODO: trigger event callback diff --git a/session_test.go b/session_test.go index 4396c5a..f47ff8a 100644 --- a/session_test.go +++ b/session_test.go @@ -20,8 +20,8 @@ func (srv *Server) serveOnce(l net.Listener) error { return e } srv.ChannelHandlers = map[string]ChannelHandler{ - "session": ChannelHandlerFunc(sessionHandler), - "direct-tcpip": ChannelHandlerFunc(directTcpipHandler), + "session": ChannelHandlerFunc(DefaultSessionHandler), + "direct-tcpip": ChannelHandlerFunc(DirectTCPIPHandler), } srv.handleConn(conn) return nil diff --git a/tcpip.go b/tcpip.go index 4afbf2d..2a7f33d 100644 --- a/tcpip.go +++ b/tcpip.go @@ -23,7 +23,9 @@ type localForwardChannelData struct { OriginPort uint32 } -func directTcpipHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) { +// DirectTCPIPHandler can be enabled by adding it to the server's +// ChannelHandlers under direct-tcpip. +func DirectTCPIPHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) { d := localForwardChannelData{} if err := gossh.Unmarshal(newChan.ExtraData(), &d); err != nil { newChan.Reject(gossh.ConnectionFailed, "error parsing forward data: "+err.Error()) @@ -84,12 +86,15 @@ type remoteForwardChannelData struct { OriginPort uint32 } -type forwardedTCPHandler struct { +// ForwardedTCPHandler can be enabled by creating a ForwardedTCPHandler and +// adding it to the server's RequestHandlers under tcpip-forward and +// cancel-tcpip-forward. +type ForwardedTCPHandler struct { forwards map[string]net.Listener sync.Mutex } -func (h forwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) { +func (h ForwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) { h.Lock() if h.forwards == nil { h.forwards = make(map[string]net.Listener)