diff --git a/server.go b/server.go index 50eb632..dc61fb4 100644 --- a/server.go +++ b/server.go @@ -37,8 +37,15 @@ type Server struct { IdleTimeout time.Duration // connection timeout when no activity, none if empty MaxTimeout time.Duration // absolute connection timeout, none if empty - channelHandlers map[string]channelHandler - requestHandlers map[string]RequestHandler + // ChannelHandlers allow overriding the built-in session handlers or provide + // extensions to the protocol, such as tcpip forwarding. By default only the + // "session" handler is enabled. + ChannelHandlers map[string]ChannelHandler + + // RequestHandlers allow overriding the server-level request handlers or + // provide extensions to the protocol, such as tcpip forwarding. By default + // no handlers are enabled. + RequestHandlers map[string]RequestHandler listenerWg sync.WaitGroup mu sync.Mutex @@ -47,12 +54,32 @@ type Server struct { connWg sync.WaitGroup doneChan chan struct{} } + type RequestHandler interface { - HandleRequest(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte) + HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte) } -// internal for now -type channelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) +type RequestHandlerFunc func(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte) + +func (f RequestHandlerFunc) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte) { + return f(ctx, srv, req) +} + +var DefaultRequestHandlers = map[string]RequestHandler{} + +type ChannelHandler interface { + HandleSSHChannel(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) +} + +type ChannelHandlerFunc func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) + +func (f ChannelHandlerFunc) HandleSSHChannel(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) { + f(srv, conn, newChan, ctx) +} + +var DefaultChannelHandlers = map[string]ChannelHandler{ + "session": ChannelHandlerFunc(DefaultSessionHandler), +} func (srv *Server) ensureHostSigner() error { if len(srv.HostSigners) == 0 { @@ -68,13 +95,17 @@ func (srv *Server) ensureHostSigner() error { func (srv *Server) ensureHandlers() { srv.mu.Lock() defer srv.mu.Unlock() - srv.requestHandlers = map[string]RequestHandler{ - "tcpip-forward": forwardedTCPHandler{}, - "cancel-tcpip-forward": forwardedTCPHandler{}, + if srv.RequestHandlers == nil { + srv.RequestHandlers = map[string]RequestHandler{} + for k, v := range DefaultRequestHandlers { + srv.RequestHandlers[k] = v + } } - srv.channelHandlers = map[string]channelHandler{ - "session": sessionHandler, - "direct-tcpip": directTcpipHandler, + if srv.ChannelHandlers == nil { + srv.ChannelHandlers = map[string]ChannelHandler{} + for k, v := range DefaultChannelHandlers { + srv.ChannelHandlers[k] = v + } } } @@ -186,12 +217,6 @@ func (srv *Server) Serve(l net.Listener) error { if srv.Handler == nil { srv.Handler = DefaultHandler } - if srv.channelHandlers == nil { - srv.channelHandlers = map[string]channelHandler{ - "session": sessionHandler, - "direct-tcpip": directTcpipHandler, - } - } var tempDelay time.Duration srv.trackListener(l, true) @@ -255,30 +280,32 @@ func (srv *Server) handleConn(newConn net.Conn) { //go gossh.DiscardRequests(reqs) go srv.handleRequests(ctx, reqs) for ch := range chans { - handler, found := srv.channelHandlers[ch.ChannelType()] - if !found { + handler := srv.ChannelHandlers[ch.ChannelType()] + if handler == nil { + handler = srv.ChannelHandlers["default"] + } + if handler == nil { ch.Reject(gossh.UnknownChannelType, "unsupported channel type") continue } - go handler(srv, sshConn, ch, ctx) + go handler.HandleSSHChannel(srv, sshConn, ch, ctx) } } func (srv *Server) handleRequests(ctx Context, in <-chan *gossh.Request) { for req := range in { - handler, found := srv.requestHandlers[req.Type] - if !found { - if req.WantReply { - req.Reply(false, nil) - } + handler := srv.RequestHandlers[req.Type] + if handler == nil { + handler = srv.RequestHandlers["default"] + } + if handler == nil { + req.Reply(false, nil) continue } /*reqCtx, cancel := context.WithCancel(ctx) defer cancel() */ - ret, payload := handler.HandleRequest(ctx, srv, req) - if req.WantReply { - req.Reply(ret, payload) - } + ret, payload := handler.HandleSSHRequest(ctx, srv, req) + req.Reply(ret, payload) } } diff --git a/session.go b/session.go index 19ddda6..a6085f3 100644 --- a/session.go +++ b/session.go @@ -77,7 +77,7 @@ type Session interface { // when there is no signal channel specified const maxSigBufSize = 128 -func sessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) { +func DefaultSessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) { ch, reqs, err := newChan.Accept() if err != nil { // TODO: trigger event callback diff --git a/session_test.go b/session_test.go index 06c7724..f47ff8a 100644 --- a/session_test.go +++ b/session_test.go @@ -19,9 +19,9 @@ func (srv *Server) serveOnce(l net.Listener) error { if e != nil { return e } - srv.channelHandlers = map[string]channelHandler{ - "session": sessionHandler, - "direct-tcpip": directTcpipHandler, + srv.ChannelHandlers = map[string]ChannelHandler{ + "session": ChannelHandlerFunc(DefaultSessionHandler), + "direct-tcpip": ChannelHandlerFunc(DirectTCPIPHandler), } srv.handleConn(conn) return nil diff --git a/tcpip.go b/tcpip.go index 64542ff..2a7f33d 100644 --- a/tcpip.go +++ b/tcpip.go @@ -23,7 +23,9 @@ type localForwardChannelData struct { OriginPort uint32 } -func directTcpipHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) { +// DirectTCPIPHandler can be enabled by adding it to the server's +// ChannelHandlers under direct-tcpip. +func DirectTCPIPHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) { d := localForwardChannelData{} if err := gossh.Unmarshal(newChan.ExtraData(), &d); err != nil { newChan.Reject(gossh.ConnectionFailed, "error parsing forward data: "+err.Error()) @@ -84,12 +86,15 @@ type remoteForwardChannelData struct { OriginPort uint32 } -type forwardedTCPHandler struct { +// ForwardedTCPHandler can be enabled by creating a ForwardedTCPHandler and +// adding it to the server's RequestHandlers under tcpip-forward and +// cancel-tcpip-forward. +type ForwardedTCPHandler struct { forwards map[string]net.Listener sync.Mutex } -func (h forwardedTCPHandler) HandleRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) { +func (h ForwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) { h.Lock() if h.forwards == nil { h.forwards = make(map[string]net.Listener)