Add SessionPolicyCallback (#80)

* Add SessionPolicyCallback

Closes #7

* Update docs related to the embedded sync.Locker in the Context

* Fix mutex in context
This commit is contained in:
Kaleb Elwert 2019-02-22 18:11:43 -08:00 committed by Jeff Lindsay
parent 4b72c663cf
commit e5ece1489c
4 changed files with 43 additions and 22 deletions

@ -4,6 +4,7 @@ import (
"context"
"encoding/hex"
"net"
"sync"
gossh "golang.org/x/crypto/ssh"
)
@ -59,9 +60,11 @@ var (
// Context is a package specific context interface. It exposes connection
// metadata and allows new values to be easily written to it. It's used in
// authentication handlers and callbacks, and its underlying context.Context is
// exposed on Session in the session Handler.
// exposed on Session in the session Handler. A connection-scoped lock is also
// embedded in the context to make it easier to limit operations per-connection.
type Context interface {
context.Context
sync.Locker
// User returns the username used when establishing the SSH connection.
User() string
@ -90,11 +93,12 @@ type Context interface {
type sshContext struct {
context.Context
*sync.Mutex
}
func newContext(srv *Server) (*sshContext, context.CancelFunc) {
innerCtx, cancel := context.WithCancel(context.Background())
ctx := &sshContext{innerCtx}
ctx := &sshContext{innerCtx, &sync.Mutex{}}
ctx.SetValue(ContextKeyServer, srv)
perms := &Permissions{&gossh.Permissions{}}
ctx.SetValue(ContextKeyPermissions, perms)

@ -32,6 +32,7 @@ type Server struct {
LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil
ReversePortForwardingCallback ReversePortForwardingCallback // callback for allowing reverse port forwarding, denies all if nil
DefaultServerConfigCallback DefaultServerConfigCallback // callback for configuring detailed SSH options
SessionRequestCallback SessionRequestCallback // callback for allowing or denying SSH sessions
IdleTimeout time.Duration // connection timeout when no activity, none if empty
MaxTimeout time.Duration // absolute connection timeout, none if empty

@ -84,11 +84,12 @@ func sessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChanne
return
}
sess := &session{
Channel: ch,
conn: conn,
handler: srv.Handler,
ptyCb: srv.PtyCallback,
ctx: ctx,
Channel: ch,
conn: conn,
handler: srv.Handler,
ptyCb: srv.PtyCallback,
sessReqCb: srv.SessionRequestCallback,
ctx: ctx,
}
sess.handleRequests(reqs)
}
@ -96,18 +97,19 @@ func sessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChanne
type session struct {
sync.Mutex
gossh.Channel
conn *gossh.ServerConn
handler Handler
handled bool
exited bool
pty *Pty
winch chan Window
env []string
ptyCb PtyCallback
cmd []string
ctx Context
sigCh chan<- Signal
sigBuf []Signal
conn *gossh.ServerConn
handler Handler
handled bool
exited bool
pty *Pty
winch chan Window
env []string
ptyCb PtyCallback
sessReqCb SessionRequestCallback
cmd []string
ctx Context
sigCh chan<- Signal
sigBuf []Signal
}
func (sess *session) Write(p []byte) (n int, err error) {
@ -209,12 +211,22 @@ func (sess *session) handleRequests(reqs <-chan *gossh.Request) {
req.Reply(false, nil)
continue
}
sess.handled = true
req.Reply(true, nil)
var payload = struct{ Value string }{}
gossh.Unmarshal(req.Payload, &payload)
sess.cmd, _ = shlex.Split(payload.Value, true)
// If there's a session policy callback, we need to confirm before
// accepting the session.
if sess.sessReqCb != nil && !sess.sessReqCb(sess, req.Type) {
sess.cmd = nil
req.Reply(false, nil)
continue
}
sess.handled = true
req.Reply(true, nil)
go func() {
sess.handler(sess)
sess.Exit(0)

6
ssh.go

@ -2,8 +2,9 @@ package ssh
import (
"crypto/subtle"
gossh "golang.org/x/crypto/ssh"
"net"
gossh "golang.org/x/crypto/ssh"
)
type Signal string
@ -46,6 +47,9 @@ type KeyboardInteractiveHandler func(ctx Context, challenger gossh.KeyboardInter
// PtyCallback is a hook for allowing PTY sessions.
type PtyCallback func(ctx Context, pty Pty) bool
// SessionRequestCallback is a callback for allowing or denying SSH sessions.
type SessionRequestCallback func(sess Session, requestType string) bool
// ConnCallback is a hook for new connections before handling.
// It allows wrapping for timeouts and limiting by returning
// the net.Conn that will be used as the underlying connection.