contexts (#29)

* context: working mostly tested context implementation and refactoring to go with it
* _example/ssh-publickey: updating new context based callbacks
* godocs related to public api changes for contexts
* context: converting []bytes to strings before putting into context

Signed-off-by: Jeff Lindsay <progrium@gmail.com>
This commit is contained in:
Jeff Lindsay 2017-03-14 14:13:03 -05:00 committed by GitHub
parent 791cd4b75f
commit 9b56478e13
11 changed files with 343 additions and 73 deletions

@ -16,7 +16,7 @@ func main() {
s.Write(authorizedKey)
})
publicKeyOption := ssh.PublicKeyAuth(func(user string, key ssh.PublicKey) bool {
publicKeyOption := ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool {
return true // allow all keys, or use ssh.KeysEqual() to compare against known keys
})

142
context.go Normal file

@ -0,0 +1,142 @@
package ssh
import (
"context"
"net"
gossh "golang.org/x/crypto/ssh"
)
// contextKey is a value for use with context.WithValue. It's used as
// a pointer so it fits in an interface{} without allocation.
type contextKey struct {
name string
}
var (
// ContextKeyUser is a context key for use with Contexts in this package.
// The associated value will be of type string.
ContextKeyUser = &contextKey{"user"}
// ContextKeySessionID is a context key for use with Contexts in this package.
// The associated value will be of type string.
ContextKeySessionID = &contextKey{"session-id"}
// ContextKeyPermissions is a context key for use with Contexts in this package.
// The associated value will be of type *Permissions.
ContextKeyPermissions = &contextKey{"permissions"}
// ContextKeyClientVersion is a context key for use with Contexts in this package.
// The associated value will be of type string.
ContextKeyClientVersion = &contextKey{"client-version"}
// ContextKeyServerVersion is a context key for use with Contexts in this package.
// The associated value will be of type string.
ContextKeyServerVersion = &contextKey{"server-version"}
// ContextKeyLocalAddr is a context key for use with Contexts in this package.
// The associated value will be of type net.Addr.
ContextKeyLocalAddr = &contextKey{"local-addr"}
// ContextKeyRemoteAddr is a context key for use with Contexts in this package.
// The associated value will be of type net.Addr.
ContextKeyRemoteAddr = &contextKey{"remote-addr"}
// ContextKeyServer is a context key for use with Contexts in this package.
// The associated value will be of type *Server.
ContextKeyServer = &contextKey{"ssh-server"}
// ContextKeyPublicKey is a context key for use with Contexts in this package.
// The associated value will be of type PublicKey.
ContextKeyPublicKey = &contextKey{"public-key"}
)
// Context is a package specific context interface. It exposes connection
// metadata and allows new values to be easily written to it. It's used in
// authentication handlers and callbacks, and its underlying context.Context is
// exposed on Session in the session Handler.
type Context interface {
context.Context
// User returns the username used when establishing the SSH connection.
User() string
// SessionID returns the session hash.
SessionID() string
// ClientVersion returns the version reported by the client.
ClientVersion() string
// ServerVersion returns the version reported by the server.
ServerVersion() string
// RemoteAddr returns the remote address for this connection.
RemoteAddr() net.Addr
// LocalAddr returns the local address for this connection.
LocalAddr() net.Addr
// Permissions returns the Permissions object used for this connection.
Permissions() *Permissions
// SetValue allows you to easily write new values into the underlying context.
SetValue(key, value interface{})
}
type sshContext struct {
context.Context
}
func newContext(srv *Server) *sshContext {
ctx := &sshContext{context.Background()}
ctx.SetValue(ContextKeyServer, srv)
perms := &Permissions{&gossh.Permissions{}}
ctx.SetValue(ContextKeyPermissions, perms)
return ctx
}
// this is separate from newContext because we will get ConnMetadata
// at different points so it needs to be applied separately
func (ctx *sshContext) applyConnMetadata(conn gossh.ConnMetadata) {
if ctx.Value(ContextKeySessionID) != nil {
return
}
ctx.SetValue(ContextKeySessionID, string(conn.SessionID()))
ctx.SetValue(ContextKeyClientVersion, string(conn.ClientVersion()))
ctx.SetValue(ContextKeyServerVersion, string(conn.ServerVersion()))
ctx.SetValue(ContextKeyUser, conn.User())
ctx.SetValue(ContextKeyLocalAddr, conn.LocalAddr())
ctx.SetValue(ContextKeyRemoteAddr, conn.RemoteAddr())
}
func (ctx *sshContext) SetValue(key, value interface{}) {
ctx.Context = context.WithValue(ctx.Context, key, value)
}
func (ctx *sshContext) User() string {
return ctx.Value(ContextKeyUser).(string)
}
func (ctx *sshContext) SessionID() string {
return ctx.Value(ContextKeySessionID).(string)
}
func (ctx *sshContext) ClientVersion() string {
return ctx.Value(ContextKeyClientVersion).(string)
}
func (ctx *sshContext) ServerVersion() string {
return ctx.Value(ContextKeyServerVersion).(string)
}
func (ctx *sshContext) RemoteAddr() net.Addr {
return ctx.Value(ContextKeyRemoteAddr).(net.Addr)
}
func (ctx *sshContext) LocalAddr() net.Addr {
return ctx.Value(ContextKeyLocalAddr).(net.Addr)
}
func (ctx *sshContext) Permissions() *Permissions {
return ctx.Value(ContextKeyPermissions).(*Permissions)
}

47
context_test.go Normal file

@ -0,0 +1,47 @@
package ssh
import "testing"
func TestSetPermissions(t *testing.T) {
t.Parallel()
permsExt := map[string]string{
"foo": "bar",
}
session, cleanup := newTestSessionWithOptions(t, &Server{
Handler: func(s Session) {
if _, ok := s.Permissions().Extensions["foo"]; !ok {
t.Fatalf("got %#v; want %#v", s.Permissions().Extensions, permsExt)
}
},
}, nil, PasswordAuth(func(ctx Context, password string) bool {
ctx.Permissions().Extensions = permsExt
return true
}))
defer cleanup()
if err := session.Run(""); err != nil {
t.Fatal(err)
}
}
func TestSetValue(t *testing.T) {
t.Parallel()
value := map[string]string{
"foo": "bar",
}
key := "testValue"
session, cleanup := newTestSessionWithOptions(t, &Server{
Handler: func(s Session) {
v := s.Context().Value(key).(map[string]string)
if v["foo"] != value["foo"] {
t.Fatalf("got %#v; want %#v", v, value)
}
},
}, nil, PasswordAuth(func(ctx Context, password string) bool {
ctx.SetValue(key, value)
return true
}))
defer cleanup()
if err := session.Run(""); err != nil {
t.Fatal(err)
}
}

@ -15,7 +15,7 @@ func ExampleListenAndServe() {
func ExamplePasswordAuth() {
ssh.ListenAndServe(":2222", nil,
ssh.PasswordAuth(func(user, pass string) bool {
ssh.PasswordAuth(func(ctx ssh.Context, pass string) bool {
return pass == "secret"
}),
)
@ -27,7 +27,7 @@ func ExampleNoPty() {
func ExamplePublicKeyAuth() {
ssh.ListenAndServe(":2222", nil,
ssh.PublicKeyAuth(func(user string, key ssh.PublicKey) bool {
ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool {
data, _ := ioutil.ReadFile("/path/to/allowed/key.pub")
allowed, _, _, _, _ := ssh.ParseAuthorizedKey(data)
return ssh.KeysEqual(key, allowed)

@ -56,7 +56,7 @@ func HostKeyPEM(bytes []byte) Option {
// denying PTY requests.
func NoPty() Option {
return func(srv *Server) error {
srv.PtyCallback = func(user string, permissions *Permissions) bool {
srv.PtyCallback = func(ctx Context, pty Pty) bool {
return false
}
return nil

66
options_test.go Normal file

@ -0,0 +1,66 @@
package ssh
import (
"strings"
"testing"
gossh "golang.org/x/crypto/ssh"
)
func newTestSessionWithOptions(t *testing.T, srv *Server, cfg *gossh.ClientConfig, options ...Option) (*gossh.Session, func()) {
for _, option := range options {
if err := srv.SetOption(option); err != nil {
t.Fatal(err)
}
}
return newTestSession(t, srv, cfg)
}
func TestPasswordAuth(t *testing.T) {
t.Parallel()
testUser := "testuser"
testPass := "testpass"
session, cleanup := newTestSessionWithOptions(t, &Server{
Handler: func(s Session) {
// noop
},
}, &gossh.ClientConfig{
User: testUser,
Auth: []gossh.AuthMethod{
gossh.Password(testPass),
},
}, PasswordAuth(func(ctx Context, password string) bool {
if ctx.User() != testUser {
t.Fatalf("user = %#v; want %#v", ctx.User(), testUser)
}
if password != testPass {
t.Fatalf("user = %#v; want %#v", password, testPass)
}
return true
}))
defer cleanup()
if err := session.Run(""); err != nil {
t.Fatal(err)
}
}
func TestPasswordAuthBadPass(t *testing.T) {
t.Parallel()
l := newLocalListener()
srv := &Server{Handler: func(s Session) {}}
srv.SetOption(PasswordAuth(func(ctx Context, password string) bool {
return false
}))
go srv.serveOnce(l)
_, err := gossh.Dial("tcp", l.Addr().String(), &gossh.ClientConfig{
User: "testuser",
Auth: []gossh.AuthMethod{
gossh.Password("testpass"),
},
})
if err != nil {
if !strings.Contains(err.Error(), "unable to authenticate") {
t.Fatal(err)
}
}
}

@ -17,21 +17,24 @@ type Server struct {
HostSigners []Signer // private keys for the host key, must have at least one
Version string // server version to be sent before the initial handshake
PasswordHandler PasswordHandler // password authentication handler
PublicKeyHandler PublicKeyHandler // public key authentication handler
PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil
PermissionsCallback PermissionsCallback // optional callback for setting up permissions
PasswordHandler PasswordHandler // password authentication handler
PublicKeyHandler PublicKeyHandler // public key authentication handler
PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil
}
func (srv *Server) makeConfig() (*gossh.ServerConfig, error) {
config := &gossh.ServerConfig{}
func (srv *Server) ensureHostSigner() error {
if len(srv.HostSigners) == 0 {
signer, err := generateSigner()
if err != nil {
return nil, err
return err
}
srv.HostSigners = append(srv.HostSigners, signer)
}
return nil
}
func (srv *Server) config(ctx *sshContext) *gossh.ServerConfig {
config := &gossh.ServerConfig{}
for _, signer := range srv.HostSigners {
config.AddHostKey(signer)
}
@ -43,34 +46,24 @@ func (srv *Server) makeConfig() (*gossh.ServerConfig, error) {
}
if srv.PasswordHandler != nil {
config.PasswordCallback = func(conn gossh.ConnMetadata, password []byte) (*gossh.Permissions, error) {
perms := &gossh.Permissions{}
if ok := srv.PasswordHandler(conn.User(), string(password)); !ok {
return perms, fmt.Errorf("permission denied")
ctx.applyConnMetadata(conn)
if ok := srv.PasswordHandler(ctx, string(password)); !ok {
return ctx.Permissions().Permissions, fmt.Errorf("permission denied")
}
if srv.PermissionsCallback != nil {
srv.PermissionsCallback(conn.User(), &Permissions{perms})
}
return perms, nil
return ctx.Permissions().Permissions, nil
}
}
if srv.PublicKeyHandler != nil {
config.PublicKeyCallback = func(conn gossh.ConnMetadata, key gossh.PublicKey) (*gossh.Permissions, error) {
perms := &gossh.Permissions{}
if ok := srv.PublicKeyHandler(conn.User(), key); !ok {
return perms, fmt.Errorf("permission denied")
ctx.applyConnMetadata(conn)
if ok := srv.PublicKeyHandler(ctx, key); !ok {
return ctx.Permissions().Permissions, fmt.Errorf("permission denied")
}
// no other way to pass the key from
// auth handler to session handler
perms.Extensions = map[string]string{
"_publickey": string(key.Marshal()),
}
if srv.PermissionsCallback != nil {
srv.PermissionsCallback(conn.User(), &Permissions{perms})
}
return perms, nil
ctx.SetValue(ContextKeyPublicKey, key)
return ctx.Permissions().Permissions, nil
}
}
return config, nil
return config
}
// Handle sets the Handler for the server.
@ -85,8 +78,7 @@ func (srv *Server) Handle(fn Handler) {
// Serve always returns a non-nil error.
func (srv *Server) Serve(l net.Listener) error {
defer l.Close()
config, err := srv.makeConfig()
if err != nil {
if err := srv.ensureHostSigner(); err != nil {
return err
}
if srv.Handler == nil {
@ -110,41 +102,46 @@ func (srv *Server) Serve(l net.Listener) error {
}
return e
}
go srv.handleConn(conn, config)
go srv.handleConn(conn)
}
}
func (srv *Server) handleConn(conn net.Conn, conf *gossh.ServerConfig) {
func (srv *Server) handleConn(conn net.Conn) {
defer conn.Close()
sshConn, chans, reqs, err := gossh.NewServerConn(conn, conf)
ctx := newContext(srv)
sshConn, chans, reqs, err := gossh.NewServerConn(conn, srv.config(ctx))
if err != nil {
// TODO: trigger event callback
return
}
ctx.applyConnMetadata(sshConn)
go gossh.DiscardRequests(reqs)
for ch := range chans {
if ch.ChannelType() != "session" {
ch.Reject(gossh.UnknownChannelType, "unsupported channel type")
continue
}
go srv.handleChannel(sshConn, ch)
go srv.handleChannel(sshConn, ch, ctx)
}
}
func (srv *Server) handleChannel(conn *gossh.ServerConn, newChan gossh.NewChannel) {
func (srv *Server) handleChannel(conn *gossh.ServerConn, newChan gossh.NewChannel, ctx *sshContext) {
ch, reqs, err := newChan.Accept()
if err != nil {
// TODO: trigger event callback
return
}
sess := srv.newSession(conn, ch)
sess := srv.newSession(conn, ch, ctx)
sess.handleRequests(reqs)
}
func (srv *Server) newSession(conn *gossh.ServerConn, ch gossh.Channel) *session {
func (srv *Server) newSession(conn *gossh.ServerConn, ch gossh.Channel, ctx *sshContext) *session {
sess := &session{
Channel: ch,
conn: conn,
handler: srv.Handler,
ptyCb: srv.PtyCallback,
ctx: ctx,
}
return sess
}

@ -2,6 +2,7 @@ package ssh
import (
"bytes"
"context"
"errors"
"fmt"
"net"
@ -43,6 +44,15 @@ type Session interface {
// used it will return nil.
PublicKey() PublicKey
// Context returns the connection's context. The returned context is always
// non-nil and holds the same data as the Context passed into auth
// handlers and callbacks.
Context() context.Context
// Permissions returns a copy of the Permissions object that was available for
// setup in the auth handlers via the Context.
Permissions() Permissions
// Pty returns PTY information, a channel of window size changes, and a boolean
// of whether or not a PTY was accepted for this session.
Pty() (Pty, <-chan Window, bool)
@ -61,6 +71,7 @@ type session struct {
env []string
ptyCb PtyCallback
cmd []string
ctx *sshContext
}
func (sess *session) Write(p []byte) (n int, err error) {
@ -80,18 +91,18 @@ func (sess *session) Write(p []byte) (n int, err error) {
}
func (sess *session) PublicKey() PublicKey {
if sess.conn.Permissions == nil {
return nil
}
s, ok := sess.conn.Permissions.Extensions["_publickey"]
if !ok {
return nil
}
key, err := ParsePublicKey([]byte(s))
if err != nil {
return nil
}
return key
return sess.ctx.Value(ContextKeyPublicKey).(PublicKey)
}
func (sess *session) Permissions() Permissions {
// use context permissions because its properly
// wrapped and easier to dereference
perms := sess.ctx.Value(ContextKeyPermissions).(*Permissions)
return *perms
}
func (sess *session) Context() context.Context {
return sess.ctx.Context
}
func (sess *session) Exit(code int) error {
@ -163,22 +174,25 @@ func (sess *session) handleRequests(reqs <-chan *gossh.Request) {
req.Reply(false, nil)
continue
}
ptyReq, ok := parsePtyRequest(req.Payload)
if !ok {
req.Reply(false, nil)
continue
}
if sess.ptyCb != nil {
ok := sess.ptyCb(sess.conn.User(), &Permissions{sess.conn.Permissions})
ok := sess.ptyCb(sess.ctx, ptyReq)
if !ok {
req.Reply(false, nil)
continue
}
}
ptyReq, ok := parsePtyRequest(req.Payload)
if ok {
sess.pty = &ptyReq
sess.winch = make(chan Window, 1)
sess.winch <- ptyReq.Window
defer func() {
close(sess.winch)
}()
}
sess.pty = &ptyReq
sess.winch = make(chan Window, 1)
sess.winch <- ptyReq.Window
defer func() {
// when reqs is closed
close(sess.winch)
}()
req.Reply(ok, nil)
case "window-change":
if sess.pty == nil {

@ -11,15 +11,14 @@ import (
)
func (srv *Server) serveOnce(l net.Listener) error {
config, err := srv.makeConfig()
if err != nil {
if err := srv.ensureHostSigner(); err != nil {
return err
}
conn, e := l.Accept()
if e != nil {
return e
}
srv.handleConn(conn, config)
srv.handleConn(conn)
return nil
}
@ -63,6 +62,7 @@ func newTestSession(t *testing.T, srv *Server, cfg *gossh.ClientConfig) (*gossh.
}
func TestStdout(t *testing.T) {
t.Parallel()
testBytes := []byte("Hello world\n")
session, cleanup := newTestSession(t, &Server{
Handler: func(s Session) {
@ -81,6 +81,7 @@ func TestStdout(t *testing.T) {
}
func TestStderr(t *testing.T) {
t.Parallel()
testBytes := []byte("Hello world\n")
session, cleanup := newTestSession(t, &Server{
Handler: func(s Session) {
@ -99,6 +100,7 @@ func TestStderr(t *testing.T) {
}
func TestStdin(t *testing.T) {
t.Parallel()
testBytes := []byte("Hello world\n")
session, cleanup := newTestSession(t, &Server{
Handler: func(s Session) {
@ -118,6 +120,7 @@ func TestStdin(t *testing.T) {
}
func TestUser(t *testing.T) {
t.Parallel()
testUser := []byte("progrium")
session, cleanup := newTestSession(t, &Server{
Handler: func(s Session) {
@ -138,6 +141,7 @@ func TestUser(t *testing.T) {
}
func TestDefaultExitStatusZero(t *testing.T) {
t.Parallel()
session, cleanup := newTestSession(t, &Server{
Handler: func(s Session) {
// noop
@ -151,6 +155,7 @@ func TestDefaultExitStatusZero(t *testing.T) {
}
func TestExplicitExitStatusZero(t *testing.T) {
t.Parallel()
session, cleanup := newTestSession(t, &Server{
Handler: func(s Session) {
s.Exit(0)
@ -164,6 +169,7 @@ func TestExplicitExitStatusZero(t *testing.T) {
}
func TestExitStatusNonZero(t *testing.T) {
t.Parallel()
session, cleanup := newTestSession(t, &Server{
Handler: func(s Session) {
s.Exit(1)
@ -181,6 +187,7 @@ func TestExitStatusNonZero(t *testing.T) {
}
func TestPty(t *testing.T) {
t.Parallel()
term := "xterm"
winWidth := 40
winHeight := 80
@ -214,6 +221,7 @@ func TestPty(t *testing.T) {
}
func TestPtyResize(t *testing.T) {
t.Parallel()
winch0 := Window{40, 80}
winch1 := Window{80, 160}
winch2 := Window{20, 40}

9
ssh.go

@ -34,16 +34,13 @@ type Option func(*Server) error
type Handler func(Session)
// PublicKeyHandler is a callback for performing public key authentication.
type PublicKeyHandler func(user string, key PublicKey) bool
type PublicKeyHandler func(ctx Context, key PublicKey) bool
// PasswordHandler is a callback for performing password authentication.
type PasswordHandler func(user, password string) bool
// PermissionsCallback is a hook for setting up user permissions.
type PermissionsCallback func(user string, permissions *Permissions) error
type PasswordHandler func(ctx Context, password string) bool
// PtyCallback is a hook for allowing PTY sessions.
type PtyCallback func(user string, permissions *Permissions) bool
type PtyCallback func(ctx Context, pty Pty) bool
// Window represents the size of a PTY window.
type Window struct {

@ -10,8 +10,7 @@ type PublicKey interface {
// The Permissions type holds fine-grained permissions that are specific to a
// user or a specific authentication method for a user. Permissions, except for
// "source-address", must be enforced in the server application layer, after
// successful authentication. The Permissions are passed on in ServerConn so a
// server implementation can honor them.
// successful authentication.
type Permissions struct {
*gossh.Permissions
}