agent forwarding support (#31)

* agent: added agent forwarding support with an example
* context: encode session id to hex string
* agent: ensure conn doesn't change in closure as loop iterates
* tests: use HostKeyCallback in ClientConfig
* README: noting examples in _example
* agent: documented exported names, added constants for temp file creation

Signed-off-by: Jeff Lindsay <progrium@gmail.com>
This commit is contained in:
Jeff Lindsay 2017-04-14 14:47:40 -05:00 committed by GitHub
parent 9b56478e13
commit 1051a0d154
8 changed files with 169 additions and 29 deletions

@ -9,25 +9,28 @@ building SSH servers. The goal of the API was to make it as simple as using
```
package main
import (
"github.com/gliderlabs/ssh"
"io"
"log"
)
func main() {
ssh.Handle(func(s ssh.Session) {
io.WriteString(s, "Hello world\n")
})
log.Fatal(ssh.ListenAndServe(":2222", nil))
}
```
This package was built after working on nearly a dozen projects using SSH and
collaborating with [@shazow](https://twitter.com/shazow) (known for [ssh-chat](https://github.com/shazow/ssh-chat)).
This package was built after working on nearly a dozen projects at Glider Labs using SSH and collaborating with [@shazow](https://twitter.com/shazow) (known for [ssh-chat](https://github.com/shazow/ssh-chat)).
## Examples
A bunch of great examples are in the `_example` directory.
## Usage

@ -0,0 +1,35 @@
package main
import (
"fmt"
"log"
"os/exec"
"github.com/gliderlabs/ssh"
)
func main() {
ssh.Handle(func(s ssh.Session) {
cmd := exec.Command("ssh-add", "-l")
if ssh.AgentRequested(s) {
l, err := ssh.NewAgentListener()
if err != nil {
log.Fatal(err)
}
defer l.Close()
go ssh.ForwardAgentConnections(l, s)
cmd.Env = append(s.Environ(), fmt.Sprintf("%s=%s", "SSH_AUTH_SOCK", l.Addr().String()))
} else {
cmd.Env = s.Environ()
}
cmd.Stdout = s
cmd.Stderr = s.Stderr()
if err := cmd.Run(); err != nil {
log.Println(err)
return
}
})
log.Println("starting ssh server on port 2222...")
log.Fatal(ssh.ListenAndServe(":2222", nil))
}

81
agent.go Normal file

@ -0,0 +1,81 @@
package ssh
import (
"io"
"io/ioutil"
"net"
"path"
"sync"
gossh "golang.org/x/crypto/ssh"
)
const (
agentRequestType = "auth-agent-req@openssh.com"
agentChannelType = "auth-agent@openssh.com"
agentTempDir = "auth-agent"
agentListenFile = "listener.sock"
)
// contextKeyAgentRequest is an internal context key for storing if the
// client requested agent forwarding
var contextKeyAgentRequest = &contextKey{"auth-agent-req"}
func setAgentRequested(sess *session) {
sess.ctx.SetValue(contextKeyAgentRequest, true)
}
// AgentRequested returns true if the client requested agent forwarding.
func AgentRequested(sess Session) bool {
return sess.Context().Value(contextKeyAgentRequest) == true
}
// NewAgentListener sets up a temporary Unix socket that can be communicated
// to the session environment and used for forwarding connections.
func NewAgentListener() (net.Listener, error) {
dir, err := ioutil.TempDir("", agentTempDir)
if err != nil {
return nil, err
}
l, err := net.Listen("unix", path.Join(dir, agentListenFile))
if err != nil {
return nil, err
}
return l, nil
}
// ForwardAgentConnections takes connections from a listener to proxy into the
// session on the OpenSSH channel for agent connections. It blocks and services
// connections until the listener stop accepting.
func ForwardAgentConnections(l net.Listener, s Session) {
sshConn := s.Context().Value(ContextKeyConn).(gossh.Conn)
for {
conn, err := l.Accept()
if err != nil {
return
}
go func(conn net.Conn) {
defer conn.Close()
channel, reqs, err := sshConn.OpenChannel(agentChannelType, nil)
if err != nil {
return
}
defer channel.Close()
go gossh.DiscardRequests(reqs)
var wg sync.WaitGroup
wg.Add(2)
go func() {
io.Copy(conn, channel)
conn.(*net.UnixConn).CloseWrite()
wg.Done()
}()
go func() {
io.Copy(channel, conn)
channel.CloseWrite()
wg.Done()
}()
wg.Wait()
}(conn)
}
}

@ -3,6 +3,7 @@ package ssh
import (
"context"
"net"
"encoding/hex"
gossh "golang.org/x/crypto/ssh"
)
@ -46,6 +47,10 @@ var (
// The associated value will be of type *Server.
ContextKeyServer = &contextKey{"ssh-server"}
// ContextKeyConn is a context key for use with Contexts in this package.
// The associated value will be of type gossh.Conn.
ContextKeyConn = &contextKey{"ssh-conn"}
// ContextKeyPublicKey is a context key for use with Contexts in this package.
// The associated value will be of type PublicKey.
ContextKeyPublicKey = &contextKey{"public-key"}
@ -101,7 +106,7 @@ func (ctx *sshContext) applyConnMetadata(conn gossh.ConnMetadata) {
if ctx.Value(ContextKeySessionID) != nil {
return
}
ctx.SetValue(ContextKeySessionID, string(conn.SessionID()))
ctx.SetValue(ContextKeySessionID, hex.EncodeToString(conn.SessionID()))
ctx.SetValue(ContextKeyClientVersion, string(conn.ClientVersion()))
ctx.SetValue(ContextKeyServerVersion, string(conn.ServerVersion()))
ctx.SetValue(ContextKeyUser, conn.User())

@ -29,6 +29,7 @@ func TestPasswordAuth(t *testing.T) {
Auth: []gossh.AuthMethod{
gossh.Password(testPass),
},
HostKeyCallback: gossh.InsecureIgnoreHostKey(),
}, PasswordAuth(func(ctx Context, password string) bool {
if ctx.User() != testUser {
t.Fatalf("user = %#v; want %#v", ctx.User(), testUser)
@ -57,6 +58,7 @@ func TestPasswordAuthBadPass(t *testing.T) {
Auth: []gossh.AuthMethod{
gossh.Password("testpass"),
},
HostKeyCallback: gossh.InsecureIgnoreHostKey(),
})
if err != nil {
if !strings.Contains(err.Error(), "unable to authenticate") {

@ -20,8 +20,13 @@ type Server struct {
PasswordHandler PasswordHandler // password authentication handler
PublicKeyHandler PublicKeyHandler // public key authentication handler
PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil
channelHandlers map[string]channelHandler
}
// internal for now
type channelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx *sshContext)
func (srv *Server) ensureHostSigner() error {
if len(srv.HostSigners) == 0 {
signer, err := generateSigner()
@ -34,6 +39,9 @@ func (srv *Server) ensureHostSigner() error {
}
func (srv *Server) config(ctx *sshContext) *gossh.ServerConfig {
srv.channelHandlers = map[string]channelHandler{
"session": sessionHandler,
}
config := &gossh.ServerConfig{}
for _, signer := range srv.HostSigners {
config.AddHostKey(signer)
@ -114,38 +122,19 @@ func (srv *Server) handleConn(conn net.Conn) {
// TODO: trigger event callback
return
}
ctx.SetValue(ContextKeyConn, sshConn)
ctx.applyConnMetadata(sshConn)
go gossh.DiscardRequests(reqs)
for ch := range chans {
if ch.ChannelType() != "session" {
handler, found := srv.channelHandlers[ch.ChannelType()]
if !found {
ch.Reject(gossh.UnknownChannelType, "unsupported channel type")
continue
}
go srv.handleChannel(sshConn, ch, ctx)
go handler(srv, sshConn, ch, ctx)
}
}
func (srv *Server) handleChannel(conn *gossh.ServerConn, newChan gossh.NewChannel, ctx *sshContext) {
ch, reqs, err := newChan.Accept()
if err != nil {
// TODO: trigger event callback
return
}
sess := srv.newSession(conn, ch, ctx)
sess.handleRequests(reqs)
}
func (srv *Server) newSession(conn *gossh.ServerConn, ch gossh.Channel, ctx *sshContext) *session {
sess := &session{
Channel: ch,
conn: conn,
handler: srv.Handler,
ptyCb: srv.PtyCallback,
ctx: ctx,
}
return sess
}
// ListenAndServe listens on the TCP network address srv.Addr and then calls
// Serve to handle incoming connections. If srv.Addr is blank, ":22" is used.
// ListenAndServe always returns a non-nil error.

@ -60,6 +60,22 @@ type Session interface {
// TODO: Signals(c chan<- Signal)
}
func sessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx *sshContext) {
ch, reqs, err := newChan.Accept()
if err != nil {
// TODO: trigger event callback
return
}
sess := &session{
Channel: ch,
conn: conn,
handler: srv.Handler,
ptyCb: srv.PtyCallback,
ctx: ctx,
}
sess.handleRequests(reqs)
}
type session struct {
gossh.Channel
conn *gossh.ServerConn
@ -205,6 +221,12 @@ func (sess *session) handleRequests(reqs <-chan *gossh.Request) {
sess.winch <- win
}
req.Reply(ok, nil)
case agentRequestType:
// TODO: option/callback to allow agent forwarding
setAgentRequested(sess)
req.Reply(true, nil)
default:
// TODO: debug log
}
}
}

@ -41,6 +41,9 @@ func newClientSession(t *testing.T, addr string, config *gossh.ClientConfig) (*g
},
}
}
if config.HostKeyCallback == nil {
config.HostKeyCallback = gossh.InsecureIgnoreHostKey()
}
client, err := gossh.Dial("tcp", addr, config)
if err != nil {
t.Fatal(err)