Merge remote-tracking branch 'moul/dev/moul/configurable-handlers' into configurable-handlers

This commit is contained in:
Kaleb Elwert 2019-06-12 10:14:10 -07:00
commit 75b695471d
2 changed files with 31 additions and 17 deletions

@ -37,7 +37,7 @@ type Server struct {
IdleTimeout time.Duration // connection timeout when no activity, none if empty
MaxTimeout time.Duration // absolute connection timeout, none if empty
channelHandlers map[string]channelHandler
channelHandlers map[string]ChannelHandler // fallback channel handlers
requestHandlers map[string]RequestHandler
listenerWg sync.WaitGroup
@ -51,8 +51,7 @@ type RequestHandler interface {
HandleRequest(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte)
}
// internal for now
type channelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context)
type ChannelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context)
func (srv *Server) ensureHostSigner() error {
if len(srv.HostSigners) == 0 {
@ -68,13 +67,17 @@ func (srv *Server) ensureHostSigner() error {
func (srv *Server) ensureHandlers() {
srv.mu.Lock()
defer srv.mu.Unlock()
srv.requestHandlers = map[string]RequestHandler{
"tcpip-forward": forwardedTCPHandler{},
"cancel-tcpip-forward": forwardedTCPHandler{},
if srv.requestHandlers == nil {
srv.requestHandlers = map[string]RequestHandler{
"tcpip-forward": forwardedTCPHandler{},
"cancel-tcpip-forward": forwardedTCPHandler{},
}
}
srv.channelHandlers = map[string]channelHandler{
"session": sessionHandler,
"direct-tcpip": directTcpipHandler,
if srv.channelHandlers == nil {
srv.channelHandlers = map[string]ChannelHandler{
"session": sessionHandler,
"direct-tcpip": directTcpipHandler,
}
}
}
@ -186,12 +189,6 @@ func (srv *Server) Serve(l net.Listener) error {
if srv.Handler == nil {
srv.Handler = DefaultHandler
}
if srv.channelHandlers == nil {
srv.channelHandlers = map[string]channelHandler{
"session": sessionHandler,
"direct-tcpip": directTcpipHandler,
}
}
var tempDelay time.Duration
srv.trackListener(l, true)
@ -222,6 +219,18 @@ func (srv *Server) Serve(l net.Listener) error {
}
}
func (srv *Server) SetChannelHandler(kind string, handler ChannelHandler) {
srv.ensureHandlers()
srv.mu.Lock()
defer srv.mu.Unlock()
srv.channelHandlers[kind] = handler
}
func (srv *Server) ChannelHandler(kind string) ChannelHandler {
srv.ensureHandlers()
return srv.channelHandlers[kind]
}
func (srv *Server) handleConn(newConn net.Conn) {
if srv.ConnCallback != nil {
cbConn := srv.ConnCallback(newConn)
@ -256,7 +265,12 @@ func (srv *Server) handleConn(newConn net.Conn) {
go srv.handleRequests(ctx, reqs)
for ch := range chans {
handler, found := srv.channelHandlers[ch.ChannelType()]
if !found {
if !found || handler == nil {
if defaultHandler, found := srv.channelHandlers["default"]; found {
handler = defaultHandler
}
}
if handler == nil {
ch.Reject(gossh.UnknownChannelType, "unsupported channel type")
continue
}

@ -19,7 +19,7 @@ func (srv *Server) serveOnce(l net.Listener) error {
if e != nil {
return e
}
srv.channelHandlers = map[string]channelHandler{
srv.channelHandlers = map[string]ChannelHandler{
"session": sessionHandler,
"direct-tcpip": directTcpipHandler,
}