Remove Handler getters and setters
This commit is contained in:
parent
77856273e0
commit
c9e327ebeb
46
server.go
46
server.go
@ -37,8 +37,15 @@ type Server struct {
|
|||||||
IdleTimeout time.Duration // connection timeout when no activity, none if empty
|
IdleTimeout time.Duration // connection timeout when no activity, none if empty
|
||||||
MaxTimeout time.Duration // absolute connection timeout, none if empty
|
MaxTimeout time.Duration // absolute connection timeout, none if empty
|
||||||
|
|
||||||
channelHandlers map[string]ChannelHandler // fallback channel handlers
|
// ChannelHandlers allow overriding the built-in session handlers or provide
|
||||||
requestHandlers map[string]RequestHandler
|
// extensions to the protocol, such as tcpip forwarding. By default only the
|
||||||
|
// "session" handler is enabled.
|
||||||
|
ChannelHandlers map[string]ChannelHandler
|
||||||
|
|
||||||
|
// RequestHandlers allow overriding the server-level request handlers or
|
||||||
|
// provide extensions to the protocol, such as tcpip forwarding. By default
|
||||||
|
// no handlers are enabled.
|
||||||
|
RequestHandlers map[string]RequestHandler
|
||||||
|
|
||||||
listenerWg sync.WaitGroup
|
listenerWg sync.WaitGroup
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
@ -82,14 +89,14 @@ func (srv *Server) ensureHostSigner() error {
|
|||||||
func (srv *Server) ensureHandlers() {
|
func (srv *Server) ensureHandlers() {
|
||||||
srv.mu.Lock()
|
srv.mu.Lock()
|
||||||
defer srv.mu.Unlock()
|
defer srv.mu.Unlock()
|
||||||
if srv.requestHandlers == nil {
|
if srv.RequestHandlers == nil {
|
||||||
srv.requestHandlers = map[string]RequestHandler{
|
srv.RequestHandlers = map[string]RequestHandler{
|
||||||
"tcpip-forward": forwardedTCPHandler{},
|
"tcpip-forward": forwardedTCPHandler{},
|
||||||
"cancel-tcpip-forward": forwardedTCPHandler{},
|
"cancel-tcpip-forward": forwardedTCPHandler{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if srv.channelHandlers == nil {
|
if srv.ChannelHandlers == nil {
|
||||||
srv.channelHandlers = map[string]ChannelHandler{
|
srv.ChannelHandlers = map[string]ChannelHandler{
|
||||||
"session": ChannelHandlerFunc(sessionHandler),
|
"session": ChannelHandlerFunc(sessionHandler),
|
||||||
"direct-tcpip": ChannelHandlerFunc(directTcpipHandler),
|
"direct-tcpip": ChannelHandlerFunc(directTcpipHandler),
|
||||||
}
|
}
|
||||||
@ -234,18 +241,6 @@ func (srv *Server) Serve(l net.Listener) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) SetChannelHandler(kind string, handler ChannelHandler) {
|
|
||||||
srv.ensureHandlers()
|
|
||||||
srv.mu.Lock()
|
|
||||||
defer srv.mu.Unlock()
|
|
||||||
srv.channelHandlers[kind] = handler
|
|
||||||
}
|
|
||||||
|
|
||||||
func (srv *Server) ChannelHandler(kind string) ChannelHandler {
|
|
||||||
srv.ensureHandlers()
|
|
||||||
return srv.channelHandlers[kind]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (srv *Server) handleConn(newConn net.Conn) {
|
func (srv *Server) handleConn(newConn net.Conn) {
|
||||||
if srv.ConnCallback != nil {
|
if srv.ConnCallback != nil {
|
||||||
cbConn := srv.ConnCallback(newConn)
|
cbConn := srv.ConnCallback(newConn)
|
||||||
@ -279,11 +274,9 @@ func (srv *Server) handleConn(newConn net.Conn) {
|
|||||||
//go gossh.DiscardRequests(reqs)
|
//go gossh.DiscardRequests(reqs)
|
||||||
go srv.handleRequests(ctx, reqs)
|
go srv.handleRequests(ctx, reqs)
|
||||||
for ch := range chans {
|
for ch := range chans {
|
||||||
handler, found := srv.channelHandlers[ch.ChannelType()]
|
handler := srv.ChannelHandlers[ch.ChannelType()]
|
||||||
if !found || handler == nil {
|
if handler == nil {
|
||||||
if defaultHandler, found := srv.channelHandlers["default"]; found {
|
handler = srv.ChannelHandlers["default"]
|
||||||
handler = defaultHandler
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if handler == nil {
|
if handler == nil {
|
||||||
ch.Reject(gossh.UnknownChannelType, "unsupported channel type")
|
ch.Reject(gossh.UnknownChannelType, "unsupported channel type")
|
||||||
@ -295,8 +288,11 @@ func (srv *Server) handleConn(newConn net.Conn) {
|
|||||||
|
|
||||||
func (srv *Server) handleRequests(ctx Context, in <-chan *gossh.Request) {
|
func (srv *Server) handleRequests(ctx Context, in <-chan *gossh.Request) {
|
||||||
for req := range in {
|
for req := range in {
|
||||||
handler, found := srv.requestHandlers[req.Type]
|
handler := srv.RequestHandlers[req.Type]
|
||||||
if !found {
|
if handler == nil {
|
||||||
|
handler = srv.RequestHandlers["default"]
|
||||||
|
}
|
||||||
|
if handler == nil {
|
||||||
if req.WantReply {
|
if req.WantReply {
|
||||||
req.Reply(false, nil)
|
req.Reply(false, nil)
|
||||||
}
|
}
|
||||||
|
@ -19,7 +19,7 @@ func (srv *Server) serveOnce(l net.Listener) error {
|
|||||||
if e != nil {
|
if e != nil {
|
||||||
return e
|
return e
|
||||||
}
|
}
|
||||||
srv.channelHandlers = map[string]ChannelHandler{
|
srv.ChannelHandlers = map[string]ChannelHandler{
|
||||||
"session": ChannelHandlerFunc(sessionHandler),
|
"session": ChannelHandlerFunc(sessionHandler),
|
||||||
"direct-tcpip": ChannelHandlerFunc(directTcpipHandler),
|
"direct-tcpip": ChannelHandlerFunc(directTcpipHandler),
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user