Clean up Channel and Request Handler interfaces

This commit is contained in:
Kaleb Elwert 2019-06-12 10:29:36 -07:00
parent 75b695471d
commit 77856273e0
3 changed files with 24 additions and 9 deletions

@ -47,11 +47,26 @@ type Server struct {
connWg sync.WaitGroup connWg sync.WaitGroup
doneChan chan struct{} doneChan chan struct{}
} }
type RequestHandler interface { type RequestHandler interface {
HandleRequest(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte)
} }
type ChannelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) type RequestHandlerFunc func(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte)
func (f RequestHandlerFunc) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte) {
return f(ctx, srv, req)
}
type ChannelHandler interface {
HandleSSHChannel(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context)
}
type ChannelHandlerFunc func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context)
func (f ChannelHandlerFunc) HandleSSHChannel(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) {
f(srv, conn, newChan, ctx)
}
func (srv *Server) ensureHostSigner() error { func (srv *Server) ensureHostSigner() error {
if len(srv.HostSigners) == 0 { if len(srv.HostSigners) == 0 {
@ -75,8 +90,8 @@ func (srv *Server) ensureHandlers() {
} }
if srv.channelHandlers == nil { if srv.channelHandlers == nil {
srv.channelHandlers = map[string]ChannelHandler{ srv.channelHandlers = map[string]ChannelHandler{
"session": sessionHandler, "session": ChannelHandlerFunc(sessionHandler),
"direct-tcpip": directTcpipHandler, "direct-tcpip": ChannelHandlerFunc(directTcpipHandler),
} }
} }
} }
@ -274,7 +289,7 @@ func (srv *Server) handleConn(newConn net.Conn) {
ch.Reject(gossh.UnknownChannelType, "unsupported channel type") ch.Reject(gossh.UnknownChannelType, "unsupported channel type")
continue continue
} }
go handler(srv, sshConn, ch, ctx) go handler.HandleSSHChannel(srv, sshConn, ch, ctx)
} }
} }
@ -289,7 +304,7 @@ func (srv *Server) handleRequests(ctx Context, in <-chan *gossh.Request) {
} }
/*reqCtx, cancel := context.WithCancel(ctx) /*reqCtx, cancel := context.WithCancel(ctx)
defer cancel() */ defer cancel() */
ret, payload := handler.HandleRequest(ctx, srv, req) ret, payload := handler.HandleSSHRequest(ctx, srv, req)
if req.WantReply { if req.WantReply {
req.Reply(ret, payload) req.Reply(ret, payload)
} }

@ -20,8 +20,8 @@ func (srv *Server) serveOnce(l net.Listener) error {
return e return e
} }
srv.channelHandlers = map[string]ChannelHandler{ srv.channelHandlers = map[string]ChannelHandler{
"session": sessionHandler, "session": ChannelHandlerFunc(sessionHandler),
"direct-tcpip": directTcpipHandler, "direct-tcpip": ChannelHandlerFunc(directTcpipHandler),
} }
srv.handleConn(conn) srv.handleConn(conn)
return nil return nil

@ -89,7 +89,7 @@ type forwardedTCPHandler struct {
sync.Mutex sync.Mutex
} }
func (h forwardedTCPHandler) HandleRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) { func (h forwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) {
h.Lock() h.Lock()
if h.forwards == nil { if h.forwards == nil {
h.forwards = make(map[string]net.Listener) h.forwards = make(map[string]net.Listener)