diff --git a/context.go b/context.go index 31531f4..2f61a40 100644 --- a/context.go +++ b/context.go @@ -4,6 +4,7 @@ import ( "context" "encoding/hex" "net" + "sync" gossh "golang.org/x/crypto/ssh" ) @@ -59,9 +60,11 @@ var ( // Context is a package specific context interface. It exposes connection // metadata and allows new values to be easily written to it. It's used in // authentication handlers and callbacks, and its underlying context.Context is -// exposed on Session in the session Handler. +// exposed on Session in the session Handler. A connection-scoped lock is also +// embedded in the context to make it easier to limit operations per-connection. type Context interface { context.Context + sync.Locker // User returns the username used when establishing the SSH connection. User() string @@ -90,11 +93,12 @@ type Context interface { type sshContext struct { context.Context + *sync.Mutex } func newContext(srv *Server) (*sshContext, context.CancelFunc) { innerCtx, cancel := context.WithCancel(context.Background()) - ctx := &sshContext{innerCtx} + ctx := &sshContext{innerCtx, &sync.Mutex{}} ctx.SetValue(ContextKeyServer, srv) perms := &Permissions{&gossh.Permissions{}} ctx.SetValue(ContextKeyPermissions, perms) diff --git a/server.go b/server.go index dc27f69..fb0c489 100644 --- a/server.go +++ b/server.go @@ -32,6 +32,7 @@ type Server struct { LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil ReversePortForwardingCallback ReversePortForwardingCallback // callback for allowing reverse port forwarding, denies all if nil DefaultServerConfigCallback DefaultServerConfigCallback // callback for configuring detailed SSH options + SessionRequestCallback SessionRequestCallback // callback for allowing or denying SSH sessions IdleTimeout time.Duration // connection timeout when no activity, none if empty MaxTimeout time.Duration // absolute connection timeout, none if empty diff --git a/session.go b/session.go index c3db354..19ddda6 100644 --- a/session.go +++ b/session.go @@ -84,11 +84,12 @@ func sessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChanne return } sess := &session{ - Channel: ch, - conn: conn, - handler: srv.Handler, - ptyCb: srv.PtyCallback, - ctx: ctx, + Channel: ch, + conn: conn, + handler: srv.Handler, + ptyCb: srv.PtyCallback, + sessReqCb: srv.SessionRequestCallback, + ctx: ctx, } sess.handleRequests(reqs) } @@ -96,18 +97,19 @@ func sessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChanne type session struct { sync.Mutex gossh.Channel - conn *gossh.ServerConn - handler Handler - handled bool - exited bool - pty *Pty - winch chan Window - env []string - ptyCb PtyCallback - cmd []string - ctx Context - sigCh chan<- Signal - sigBuf []Signal + conn *gossh.ServerConn + handler Handler + handled bool + exited bool + pty *Pty + winch chan Window + env []string + ptyCb PtyCallback + sessReqCb SessionRequestCallback + cmd []string + ctx Context + sigCh chan<- Signal + sigBuf []Signal } func (sess *session) Write(p []byte) (n int, err error) { @@ -209,12 +211,22 @@ func (sess *session) handleRequests(reqs <-chan *gossh.Request) { req.Reply(false, nil) continue } - sess.handled = true - req.Reply(true, nil) var payload = struct{ Value string }{} gossh.Unmarshal(req.Payload, &payload) sess.cmd, _ = shlex.Split(payload.Value, true) + + // If there's a session policy callback, we need to confirm before + // accepting the session. + if sess.sessReqCb != nil && !sess.sessReqCb(sess, req.Type) { + sess.cmd = nil + req.Reply(false, nil) + continue + } + + sess.handled = true + req.Reply(true, nil) + go func() { sess.handler(sess) sess.Exit(0) diff --git a/ssh.go b/ssh.go index 334566d..cb0d77f 100644 --- a/ssh.go +++ b/ssh.go @@ -2,8 +2,9 @@ package ssh import ( "crypto/subtle" - gossh "golang.org/x/crypto/ssh" "net" + + gossh "golang.org/x/crypto/ssh" ) type Signal string @@ -46,6 +47,9 @@ type KeyboardInteractiveHandler func(ctx Context, challenger gossh.KeyboardInter // PtyCallback is a hook for allowing PTY sessions. type PtyCallback func(ctx Context, pty Pty) bool +// SessionRequestCallback is a callback for allowing or denying SSH sessions. +type SessionRequestCallback func(sess Session, requestType string) bool + // ConnCallback is a hook for new connections before handling. // It allows wrapping for timeouts and limiting by returning // the net.Conn that will be used as the underlying connection.