diff --git a/server.go b/server.go index 56dec27..5119bec 100644 --- a/server.go +++ b/server.go @@ -47,11 +47,26 @@ type Server struct { connWg sync.WaitGroup doneChan chan struct{} } + type RequestHandler interface { - HandleRequest(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte) + HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte) } -type ChannelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) +type RequestHandlerFunc func(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte) + +func (f RequestHandlerFunc) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte) { + return f(ctx, srv, req) +} + +type ChannelHandler interface { + HandleSSHChannel(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) +} + +type ChannelHandlerFunc func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) + +func (f ChannelHandlerFunc) HandleSSHChannel(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) { + f(srv, conn, newChan, ctx) +} func (srv *Server) ensureHostSigner() error { if len(srv.HostSigners) == 0 { @@ -75,8 +90,8 @@ func (srv *Server) ensureHandlers() { } if srv.channelHandlers == nil { srv.channelHandlers = map[string]ChannelHandler{ - "session": sessionHandler, - "direct-tcpip": directTcpipHandler, + "session": ChannelHandlerFunc(sessionHandler), + "direct-tcpip": ChannelHandlerFunc(directTcpipHandler), } } } @@ -274,7 +289,7 @@ func (srv *Server) handleConn(newConn net.Conn) { ch.Reject(gossh.UnknownChannelType, "unsupported channel type") continue } - go handler(srv, sshConn, ch, ctx) + go handler.HandleSSHChannel(srv, sshConn, ch, ctx) } } @@ -289,7 +304,7 @@ func (srv *Server) handleRequests(ctx Context, in <-chan *gossh.Request) { } /*reqCtx, cancel := context.WithCancel(ctx) defer cancel() */ - ret, payload := handler.HandleRequest(ctx, srv, req) + ret, payload := handler.HandleSSHRequest(ctx, srv, req) if req.WantReply { req.Reply(ret, payload) } diff --git a/session_test.go b/session_test.go index 82e0d2e..90034a7 100644 --- a/session_test.go +++ b/session_test.go @@ -20,8 +20,8 @@ func (srv *Server) serveOnce(l net.Listener) error { return e } srv.channelHandlers = map[string]ChannelHandler{ - "session": sessionHandler, - "direct-tcpip": directTcpipHandler, + "session": ChannelHandlerFunc(sessionHandler), + "direct-tcpip": ChannelHandlerFunc(directTcpipHandler), } srv.handleConn(conn) return nil diff --git a/tcpip.go b/tcpip.go index 64542ff..4afbf2d 100644 --- a/tcpip.go +++ b/tcpip.go @@ -89,7 +89,7 @@ type forwardedTCPHandler struct { sync.Mutex } -func (h forwardedTCPHandler) HandleRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) { +func (h forwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) { h.Lock() if h.forwards == nil { h.forwards = make(map[string]net.Listener)