Merge pull request #108 from gliderlabs/configurable-handlers
Configurable channel handlers Closes #89, #71
This commit is contained in:
commit
f199e8cd1e
85
server.go
85
server.go
@ -37,8 +37,15 @@ type Server struct {
|
|||||||
IdleTimeout time.Duration // connection timeout when no activity, none if empty
|
IdleTimeout time.Duration // connection timeout when no activity, none if empty
|
||||||
MaxTimeout time.Duration // absolute connection timeout, none if empty
|
MaxTimeout time.Duration // absolute connection timeout, none if empty
|
||||||
|
|
||||||
channelHandlers map[string]channelHandler
|
// ChannelHandlers allow overriding the built-in session handlers or provide
|
||||||
requestHandlers map[string]RequestHandler
|
// extensions to the protocol, such as tcpip forwarding. By default only the
|
||||||
|
// "session" handler is enabled.
|
||||||
|
ChannelHandlers map[string]ChannelHandler
|
||||||
|
|
||||||
|
// RequestHandlers allow overriding the server-level request handlers or
|
||||||
|
// provide extensions to the protocol, such as tcpip forwarding. By default
|
||||||
|
// no handlers are enabled.
|
||||||
|
RequestHandlers map[string]RequestHandler
|
||||||
|
|
||||||
listenerWg sync.WaitGroup
|
listenerWg sync.WaitGroup
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
@ -47,12 +54,32 @@ type Server struct {
|
|||||||
connWg sync.WaitGroup
|
connWg sync.WaitGroup
|
||||||
doneChan chan struct{}
|
doneChan chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
type RequestHandler interface {
|
type RequestHandler interface {
|
||||||
HandleRequest(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte)
|
HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte)
|
||||||
}
|
}
|
||||||
|
|
||||||
// internal for now
|
type RequestHandlerFunc func(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte)
|
||||||
type channelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context)
|
|
||||||
|
func (f RequestHandlerFunc) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte) {
|
||||||
|
return f(ctx, srv, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
var DefaultRequestHandlers = map[string]RequestHandler{}
|
||||||
|
|
||||||
|
type ChannelHandler interface {
|
||||||
|
HandleSSHChannel(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChannelHandlerFunc func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context)
|
||||||
|
|
||||||
|
func (f ChannelHandlerFunc) HandleSSHChannel(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) {
|
||||||
|
f(srv, conn, newChan, ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
var DefaultChannelHandlers = map[string]ChannelHandler{
|
||||||
|
"session": ChannelHandlerFunc(DefaultSessionHandler),
|
||||||
|
}
|
||||||
|
|
||||||
func (srv *Server) ensureHostSigner() error {
|
func (srv *Server) ensureHostSigner() error {
|
||||||
if len(srv.HostSigners) == 0 {
|
if len(srv.HostSigners) == 0 {
|
||||||
@ -68,13 +95,17 @@ func (srv *Server) ensureHostSigner() error {
|
|||||||
func (srv *Server) ensureHandlers() {
|
func (srv *Server) ensureHandlers() {
|
||||||
srv.mu.Lock()
|
srv.mu.Lock()
|
||||||
defer srv.mu.Unlock()
|
defer srv.mu.Unlock()
|
||||||
srv.requestHandlers = map[string]RequestHandler{
|
if srv.RequestHandlers == nil {
|
||||||
"tcpip-forward": forwardedTCPHandler{},
|
srv.RequestHandlers = map[string]RequestHandler{}
|
||||||
"cancel-tcpip-forward": forwardedTCPHandler{},
|
for k, v := range DefaultRequestHandlers {
|
||||||
|
srv.RequestHandlers[k] = v
|
||||||
|
}
|
||||||
}
|
}
|
||||||
srv.channelHandlers = map[string]channelHandler{
|
if srv.ChannelHandlers == nil {
|
||||||
"session": sessionHandler,
|
srv.ChannelHandlers = map[string]ChannelHandler{}
|
||||||
"direct-tcpip": directTcpipHandler,
|
for k, v := range DefaultChannelHandlers {
|
||||||
|
srv.ChannelHandlers[k] = v
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -186,12 +217,6 @@ func (srv *Server) Serve(l net.Listener) error {
|
|||||||
if srv.Handler == nil {
|
if srv.Handler == nil {
|
||||||
srv.Handler = DefaultHandler
|
srv.Handler = DefaultHandler
|
||||||
}
|
}
|
||||||
if srv.channelHandlers == nil {
|
|
||||||
srv.channelHandlers = map[string]channelHandler{
|
|
||||||
"session": sessionHandler,
|
|
||||||
"direct-tcpip": directTcpipHandler,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
var tempDelay time.Duration
|
var tempDelay time.Duration
|
||||||
|
|
||||||
srv.trackListener(l, true)
|
srv.trackListener(l, true)
|
||||||
@ -255,30 +280,32 @@ func (srv *Server) handleConn(newConn net.Conn) {
|
|||||||
//go gossh.DiscardRequests(reqs)
|
//go gossh.DiscardRequests(reqs)
|
||||||
go srv.handleRequests(ctx, reqs)
|
go srv.handleRequests(ctx, reqs)
|
||||||
for ch := range chans {
|
for ch := range chans {
|
||||||
handler, found := srv.channelHandlers[ch.ChannelType()]
|
handler := srv.ChannelHandlers[ch.ChannelType()]
|
||||||
if !found {
|
if handler == nil {
|
||||||
|
handler = srv.ChannelHandlers["default"]
|
||||||
|
}
|
||||||
|
if handler == nil {
|
||||||
ch.Reject(gossh.UnknownChannelType, "unsupported channel type")
|
ch.Reject(gossh.UnknownChannelType, "unsupported channel type")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
go handler(srv, sshConn, ch, ctx)
|
go handler.HandleSSHChannel(srv, sshConn, ch, ctx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) handleRequests(ctx Context, in <-chan *gossh.Request) {
|
func (srv *Server) handleRequests(ctx Context, in <-chan *gossh.Request) {
|
||||||
for req := range in {
|
for req := range in {
|
||||||
handler, found := srv.requestHandlers[req.Type]
|
handler := srv.RequestHandlers[req.Type]
|
||||||
if !found {
|
if handler == nil {
|
||||||
if req.WantReply {
|
handler = srv.RequestHandlers["default"]
|
||||||
req.Reply(false, nil)
|
}
|
||||||
}
|
if handler == nil {
|
||||||
|
req.Reply(false, nil)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
/*reqCtx, cancel := context.WithCancel(ctx)
|
/*reqCtx, cancel := context.WithCancel(ctx)
|
||||||
defer cancel() */
|
defer cancel() */
|
||||||
ret, payload := handler.HandleRequest(ctx, srv, req)
|
ret, payload := handler.HandleSSHRequest(ctx, srv, req)
|
||||||
if req.WantReply {
|
req.Reply(ret, payload)
|
||||||
req.Reply(ret, payload)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -77,7 +77,7 @@ type Session interface {
|
|||||||
// when there is no signal channel specified
|
// when there is no signal channel specified
|
||||||
const maxSigBufSize = 128
|
const maxSigBufSize = 128
|
||||||
|
|
||||||
func sessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) {
|
func DefaultSessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) {
|
||||||
ch, reqs, err := newChan.Accept()
|
ch, reqs, err := newChan.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// TODO: trigger event callback
|
// TODO: trigger event callback
|
||||||
|
@ -19,9 +19,9 @@ func (srv *Server) serveOnce(l net.Listener) error {
|
|||||||
if e != nil {
|
if e != nil {
|
||||||
return e
|
return e
|
||||||
}
|
}
|
||||||
srv.channelHandlers = map[string]channelHandler{
|
srv.ChannelHandlers = map[string]ChannelHandler{
|
||||||
"session": sessionHandler,
|
"session": ChannelHandlerFunc(DefaultSessionHandler),
|
||||||
"direct-tcpip": directTcpipHandler,
|
"direct-tcpip": ChannelHandlerFunc(DirectTCPIPHandler),
|
||||||
}
|
}
|
||||||
srv.handleConn(conn)
|
srv.handleConn(conn)
|
||||||
return nil
|
return nil
|
||||||
|
11
tcpip.go
11
tcpip.go
@ -23,7 +23,9 @@ type localForwardChannelData struct {
|
|||||||
OriginPort uint32
|
OriginPort uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
func directTcpipHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) {
|
// DirectTCPIPHandler can be enabled by adding it to the server's
|
||||||
|
// ChannelHandlers under direct-tcpip.
|
||||||
|
func DirectTCPIPHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) {
|
||||||
d := localForwardChannelData{}
|
d := localForwardChannelData{}
|
||||||
if err := gossh.Unmarshal(newChan.ExtraData(), &d); err != nil {
|
if err := gossh.Unmarshal(newChan.ExtraData(), &d); err != nil {
|
||||||
newChan.Reject(gossh.ConnectionFailed, "error parsing forward data: "+err.Error())
|
newChan.Reject(gossh.ConnectionFailed, "error parsing forward data: "+err.Error())
|
||||||
@ -84,12 +86,15 @@ type remoteForwardChannelData struct {
|
|||||||
OriginPort uint32
|
OriginPort uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
type forwardedTCPHandler struct {
|
// ForwardedTCPHandler can be enabled by creating a ForwardedTCPHandler and
|
||||||
|
// adding it to the server's RequestHandlers under tcpip-forward and
|
||||||
|
// cancel-tcpip-forward.
|
||||||
|
type ForwardedTCPHandler struct {
|
||||||
forwards map[string]net.Listener
|
forwards map[string]net.Listener
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h forwardedTCPHandler) HandleRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) {
|
func (h ForwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) {
|
||||||
h.Lock()
|
h.Lock()
|
||||||
if h.forwards == nil {
|
if h.forwards == nil {
|
||||||
h.forwards = make(map[string]net.Listener)
|
h.forwards = make(map[string]net.Listener)
|
||||||
|
Loading…
Reference in New Issue
Block a user