Merge pull request #108 from gliderlabs/configurable-handlers

Configurable channel handlers

Closes #89, #71
This commit is contained in:
Kaleb Elwert 2019-06-19 00:26:59 -07:00 committed by GitHub
commit f199e8cd1e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 68 additions and 36 deletions

@ -37,8 +37,15 @@ type Server struct {
IdleTimeout time.Duration // connection timeout when no activity, none if empty IdleTimeout time.Duration // connection timeout when no activity, none if empty
MaxTimeout time.Duration // absolute connection timeout, none if empty MaxTimeout time.Duration // absolute connection timeout, none if empty
channelHandlers map[string]channelHandler // ChannelHandlers allow overriding the built-in session handlers or provide
requestHandlers map[string]RequestHandler // extensions to the protocol, such as tcpip forwarding. By default only the
// "session" handler is enabled.
ChannelHandlers map[string]ChannelHandler
// RequestHandlers allow overriding the server-level request handlers or
// provide extensions to the protocol, such as tcpip forwarding. By default
// no handlers are enabled.
RequestHandlers map[string]RequestHandler
listenerWg sync.WaitGroup listenerWg sync.WaitGroup
mu sync.Mutex mu sync.Mutex
@ -47,12 +54,32 @@ type Server struct {
connWg sync.WaitGroup connWg sync.WaitGroup
doneChan chan struct{} doneChan chan struct{}
} }
type RequestHandler interface { type RequestHandler interface {
HandleRequest(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte)
} }
// internal for now type RequestHandlerFunc func(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte)
type channelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context)
func (f RequestHandlerFunc) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte) {
return f(ctx, srv, req)
}
var DefaultRequestHandlers = map[string]RequestHandler{}
type ChannelHandler interface {
HandleSSHChannel(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context)
}
type ChannelHandlerFunc func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context)
func (f ChannelHandlerFunc) HandleSSHChannel(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) {
f(srv, conn, newChan, ctx)
}
var DefaultChannelHandlers = map[string]ChannelHandler{
"session": ChannelHandlerFunc(DefaultSessionHandler),
}
func (srv *Server) ensureHostSigner() error { func (srv *Server) ensureHostSigner() error {
if len(srv.HostSigners) == 0 { if len(srv.HostSigners) == 0 {
@ -68,13 +95,17 @@ func (srv *Server) ensureHostSigner() error {
func (srv *Server) ensureHandlers() { func (srv *Server) ensureHandlers() {
srv.mu.Lock() srv.mu.Lock()
defer srv.mu.Unlock() defer srv.mu.Unlock()
srv.requestHandlers = map[string]RequestHandler{ if srv.RequestHandlers == nil {
"tcpip-forward": forwardedTCPHandler{}, srv.RequestHandlers = map[string]RequestHandler{}
"cancel-tcpip-forward": forwardedTCPHandler{}, for k, v := range DefaultRequestHandlers {
srv.RequestHandlers[k] = v
}
} }
srv.channelHandlers = map[string]channelHandler{ if srv.ChannelHandlers == nil {
"session": sessionHandler, srv.ChannelHandlers = map[string]ChannelHandler{}
"direct-tcpip": directTcpipHandler, for k, v := range DefaultChannelHandlers {
srv.ChannelHandlers[k] = v
}
} }
} }
@ -186,12 +217,6 @@ func (srv *Server) Serve(l net.Listener) error {
if srv.Handler == nil { if srv.Handler == nil {
srv.Handler = DefaultHandler srv.Handler = DefaultHandler
} }
if srv.channelHandlers == nil {
srv.channelHandlers = map[string]channelHandler{
"session": sessionHandler,
"direct-tcpip": directTcpipHandler,
}
}
var tempDelay time.Duration var tempDelay time.Duration
srv.trackListener(l, true) srv.trackListener(l, true)
@ -255,30 +280,32 @@ func (srv *Server) handleConn(newConn net.Conn) {
//go gossh.DiscardRequests(reqs) //go gossh.DiscardRequests(reqs)
go srv.handleRequests(ctx, reqs) go srv.handleRequests(ctx, reqs)
for ch := range chans { for ch := range chans {
handler, found := srv.channelHandlers[ch.ChannelType()] handler := srv.ChannelHandlers[ch.ChannelType()]
if !found { if handler == nil {
handler = srv.ChannelHandlers["default"]
}
if handler == nil {
ch.Reject(gossh.UnknownChannelType, "unsupported channel type") ch.Reject(gossh.UnknownChannelType, "unsupported channel type")
continue continue
} }
go handler(srv, sshConn, ch, ctx) go handler.HandleSSHChannel(srv, sshConn, ch, ctx)
} }
} }
func (srv *Server) handleRequests(ctx Context, in <-chan *gossh.Request) { func (srv *Server) handleRequests(ctx Context, in <-chan *gossh.Request) {
for req := range in { for req := range in {
handler, found := srv.requestHandlers[req.Type] handler := srv.RequestHandlers[req.Type]
if !found { if handler == nil {
if req.WantReply { handler = srv.RequestHandlers["default"]
req.Reply(false, nil) }
} if handler == nil {
req.Reply(false, nil)
continue continue
} }
/*reqCtx, cancel := context.WithCancel(ctx) /*reqCtx, cancel := context.WithCancel(ctx)
defer cancel() */ defer cancel() */
ret, payload := handler.HandleRequest(ctx, srv, req) ret, payload := handler.HandleSSHRequest(ctx, srv, req)
if req.WantReply { req.Reply(ret, payload)
req.Reply(ret, payload)
}
} }
} }

@ -77,7 +77,7 @@ type Session interface {
// when there is no signal channel specified // when there is no signal channel specified
const maxSigBufSize = 128 const maxSigBufSize = 128
func sessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) { func DefaultSessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) {
ch, reqs, err := newChan.Accept() ch, reqs, err := newChan.Accept()
if err != nil { if err != nil {
// TODO: trigger event callback // TODO: trigger event callback

@ -19,9 +19,9 @@ func (srv *Server) serveOnce(l net.Listener) error {
if e != nil { if e != nil {
return e return e
} }
srv.channelHandlers = map[string]channelHandler{ srv.ChannelHandlers = map[string]ChannelHandler{
"session": sessionHandler, "session": ChannelHandlerFunc(DefaultSessionHandler),
"direct-tcpip": directTcpipHandler, "direct-tcpip": ChannelHandlerFunc(DirectTCPIPHandler),
} }
srv.handleConn(conn) srv.handleConn(conn)
return nil return nil

@ -23,7 +23,9 @@ type localForwardChannelData struct {
OriginPort uint32 OriginPort uint32
} }
func directTcpipHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) { // DirectTCPIPHandler can be enabled by adding it to the server's
// ChannelHandlers under direct-tcpip.
func DirectTCPIPHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) {
d := localForwardChannelData{} d := localForwardChannelData{}
if err := gossh.Unmarshal(newChan.ExtraData(), &d); err != nil { if err := gossh.Unmarshal(newChan.ExtraData(), &d); err != nil {
newChan.Reject(gossh.ConnectionFailed, "error parsing forward data: "+err.Error()) newChan.Reject(gossh.ConnectionFailed, "error parsing forward data: "+err.Error())
@ -84,12 +86,15 @@ type remoteForwardChannelData struct {
OriginPort uint32 OriginPort uint32
} }
type forwardedTCPHandler struct { // ForwardedTCPHandler can be enabled by creating a ForwardedTCPHandler and
// adding it to the server's RequestHandlers under tcpip-forward and
// cancel-tcpip-forward.
type ForwardedTCPHandler struct {
forwards map[string]net.Listener forwards map[string]net.Listener
sync.Mutex sync.Mutex
} }
func (h forwardedTCPHandler) HandleRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) { func (h ForwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) {
h.Lock() h.Lock()
if h.forwards == nil { if h.forwards == nil {
h.forwards = make(map[string]net.Listener) h.forwards = make(map[string]net.Listener)