feat: configurable server handlers
This commit is contained in:
parent
cbabf54144
commit
8b3cdd49b6
46
server.go
46
server.go
@ -34,7 +34,7 @@ type Server struct {
|
||||
IdleTimeout time.Duration // connection timeout when no activity, none if empty
|
||||
MaxTimeout time.Duration // absolute connection timeout, none if empty
|
||||
|
||||
channelHandlers map[string]channelHandler
|
||||
channelHandlers map[string]ChannelHandler // fallback channel handlers
|
||||
requestHandlers map[string]RequestHandler
|
||||
|
||||
listenerWg sync.WaitGroup
|
||||
@ -48,8 +48,7 @@ type RequestHandler interface {
|
||||
HandleRequest(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte)
|
||||
}
|
||||
|
||||
// internal for now
|
||||
type channelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context)
|
||||
type ChannelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context)
|
||||
|
||||
func (srv *Server) ensureHostSigner() error {
|
||||
if len(srv.HostSigners) == 0 {
|
||||
@ -65,13 +64,17 @@ func (srv *Server) ensureHostSigner() error {
|
||||
func (srv *Server) ensureHandlers() {
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
srv.requestHandlers = map[string]RequestHandler{
|
||||
"tcpip-forward": forwardedTCPHandler{},
|
||||
"cancel-tcpip-forward": forwardedTCPHandler{},
|
||||
if srv.requestHandlers == nil {
|
||||
srv.requestHandlers = map[string]RequestHandler{
|
||||
"tcpip-forward": forwardedTCPHandler{},
|
||||
"cancel-tcpip-forward": forwardedTCPHandler{},
|
||||
}
|
||||
}
|
||||
srv.channelHandlers = map[string]channelHandler{
|
||||
"session": sessionHandler,
|
||||
"direct-tcpip": directTcpipHandler,
|
||||
if srv.channelHandlers == nil {
|
||||
srv.channelHandlers = map[string]ChannelHandler{
|
||||
"session": sessionHandler,
|
||||
"direct-tcpip": directTcpipHandler,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -170,12 +173,6 @@ func (srv *Server) Serve(l net.Listener) error {
|
||||
if srv.Handler == nil {
|
||||
srv.Handler = DefaultHandler
|
||||
}
|
||||
if srv.channelHandlers == nil {
|
||||
srv.channelHandlers = map[string]channelHandler{
|
||||
"session": sessionHandler,
|
||||
"direct-tcpip": directTcpipHandler,
|
||||
}
|
||||
}
|
||||
var tempDelay time.Duration
|
||||
|
||||
srv.trackListener(l, true)
|
||||
@ -206,6 +203,18 @@ func (srv *Server) Serve(l net.Listener) error {
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *Server) SetChannelHandler(kind string, handler ChannelHandler) {
|
||||
srv.ensureHandlers()
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
srv.channelHandlers[kind] = handler
|
||||
}
|
||||
|
||||
func (srv *Server) GetChannelHandler(kind string) ChannelHandler {
|
||||
srv.ensureHandlers()
|
||||
return srv.channelHandlers[kind]
|
||||
}
|
||||
|
||||
func (srv *Server) handleConn(newConn net.Conn) {
|
||||
if srv.ConnCallback != nil {
|
||||
cbConn := srv.ConnCallback(newConn)
|
||||
@ -240,7 +249,12 @@ func (srv *Server) handleConn(newConn net.Conn) {
|
||||
go srv.handleRequests(ctx, reqs)
|
||||
for ch := range chans {
|
||||
handler, found := srv.channelHandlers[ch.ChannelType()]
|
||||
if !found {
|
||||
if !found || handler == nil {
|
||||
if defaultHandler, found := srv.channelHandlers["default"]; found {
|
||||
handler = defaultHandler
|
||||
}
|
||||
}
|
||||
if handler == nil {
|
||||
ch.Reject(gossh.UnknownChannelType, "unsupported channel type")
|
||||
continue
|
||||
}
|
||||
|
@ -19,7 +19,7 @@ func (srv *Server) serveOnce(l net.Listener) error {
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
srv.channelHandlers = map[string]channelHandler{
|
||||
srv.channelHandlers = map[string]ChannelHandler{
|
||||
"session": sessionHandler,
|
||||
"direct-tcpip": directTcpipHandler,
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user