Merge remote-tracking branch 'moul/dev/moul/configurable-handlers' into configurable-handlers
This commit is contained in:
commit
75b695471d
46
server.go
46
server.go
@ -37,7 +37,7 @@ type Server struct {
|
||||
IdleTimeout time.Duration // connection timeout when no activity, none if empty
|
||||
MaxTimeout time.Duration // absolute connection timeout, none if empty
|
||||
|
||||
channelHandlers map[string]channelHandler
|
||||
channelHandlers map[string]ChannelHandler // fallback channel handlers
|
||||
requestHandlers map[string]RequestHandler
|
||||
|
||||
listenerWg sync.WaitGroup
|
||||
@ -51,8 +51,7 @@ type RequestHandler interface {
|
||||
HandleRequest(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte)
|
||||
}
|
||||
|
||||
// internal for now
|
||||
type channelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context)
|
||||
type ChannelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context)
|
||||
|
||||
func (srv *Server) ensureHostSigner() error {
|
||||
if len(srv.HostSigners) == 0 {
|
||||
@ -68,13 +67,17 @@ func (srv *Server) ensureHostSigner() error {
|
||||
func (srv *Server) ensureHandlers() {
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
srv.requestHandlers = map[string]RequestHandler{
|
||||
"tcpip-forward": forwardedTCPHandler{},
|
||||
"cancel-tcpip-forward": forwardedTCPHandler{},
|
||||
if srv.requestHandlers == nil {
|
||||
srv.requestHandlers = map[string]RequestHandler{
|
||||
"tcpip-forward": forwardedTCPHandler{},
|
||||
"cancel-tcpip-forward": forwardedTCPHandler{},
|
||||
}
|
||||
}
|
||||
srv.channelHandlers = map[string]channelHandler{
|
||||
"session": sessionHandler,
|
||||
"direct-tcpip": directTcpipHandler,
|
||||
if srv.channelHandlers == nil {
|
||||
srv.channelHandlers = map[string]ChannelHandler{
|
||||
"session": sessionHandler,
|
||||
"direct-tcpip": directTcpipHandler,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -186,12 +189,6 @@ func (srv *Server) Serve(l net.Listener) error {
|
||||
if srv.Handler == nil {
|
||||
srv.Handler = DefaultHandler
|
||||
}
|
||||
if srv.channelHandlers == nil {
|
||||
srv.channelHandlers = map[string]channelHandler{
|
||||
"session": sessionHandler,
|
||||
"direct-tcpip": directTcpipHandler,
|
||||
}
|
||||
}
|
||||
var tempDelay time.Duration
|
||||
|
||||
srv.trackListener(l, true)
|
||||
@ -222,6 +219,18 @@ func (srv *Server) Serve(l net.Listener) error {
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *Server) SetChannelHandler(kind string, handler ChannelHandler) {
|
||||
srv.ensureHandlers()
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
srv.channelHandlers[kind] = handler
|
||||
}
|
||||
|
||||
func (srv *Server) ChannelHandler(kind string) ChannelHandler {
|
||||
srv.ensureHandlers()
|
||||
return srv.channelHandlers[kind]
|
||||
}
|
||||
|
||||
func (srv *Server) handleConn(newConn net.Conn) {
|
||||
if srv.ConnCallback != nil {
|
||||
cbConn := srv.ConnCallback(newConn)
|
||||
@ -256,7 +265,12 @@ func (srv *Server) handleConn(newConn net.Conn) {
|
||||
go srv.handleRequests(ctx, reqs)
|
||||
for ch := range chans {
|
||||
handler, found := srv.channelHandlers[ch.ChannelType()]
|
||||
if !found {
|
||||
if !found || handler == nil {
|
||||
if defaultHandler, found := srv.channelHandlers["default"]; found {
|
||||
handler = defaultHandler
|
||||
}
|
||||
}
|
||||
if handler == nil {
|
||||
ch.Reject(gossh.UnknownChannelType, "unsupported channel type")
|
||||
continue
|
||||
}
|
||||
|
@ -19,7 +19,7 @@ func (srv *Server) serveOnce(l net.Listener) error {
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
srv.channelHandlers = map[string]channelHandler{
|
||||
srv.channelHandlers = map[string]ChannelHandler{
|
||||
"session": sessionHandler,
|
||||
"direct-tcpip": directTcpipHandler,
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user