diff --git a/_example/docker/docker.go b/_example/docker/docker.go new file mode 100644 index 0000000..1306965 --- /dev/null +++ b/_example/docker/docker.go @@ -0,0 +1,112 @@ +package main + +import ( + "context" + "fmt" + "io" + "log" + + "github.com/docker/docker/api/types" + "github.com/docker/docker/api/types/container" + "github.com/docker/docker/client" + "github.com/docker/docker/pkg/stdcopy" + "github.com/gliderlabs/ssh" +) + +func main() { + ssh.Handle(func(sess ssh.Session) { + _, _, isTty := sess.Pty() + cfg := &container.Config{ + Image: sess.User(), + Cmd: sess.Command(), + Env: sess.Environ(), + Tty: isTty, + OpenStdin: true, + AttachStderr: true, + AttachStdin: true, + AttachStdout: true, + StdinOnce: true, + Volumes: make(map[string]struct{}), + } + err, status, cleanup := dockerRun(cfg, sess) + defer cleanup() + if err != nil { + fmt.Fprintln(sess, err) + log.Println(err) + } + sess.Exit(int(status)) + }) + + log.Println("starting ssh server on port 2222...") + log.Fatal(ssh.ListenAndServe(":2222", nil)) +} + +func dockerRun(cfg *container.Config, sess ssh.Session) (err error, status int64, cleanup func()) { + docker, err := client.NewEnvClient() + if err != nil { + panic(err) + } + status = 255 + cleanup = func() {} + ctx := context.Background() + res, err := docker.ContainerCreate(ctx, cfg, nil, nil, "") + if err != nil { + return + } + cleanup = func() { + docker.ContainerRemove(ctx, res.ID, types.ContainerRemoveOptions{}) + } + opts := types.ContainerAttachOptions{ + Stdin: cfg.AttachStdin, + Stdout: cfg.AttachStdout, + Stderr: cfg.AttachStderr, + Stream: true, + } + stream, err := docker.ContainerAttach(ctx, res.ID, opts) + if err != nil { + return + } + cleanup = func() { + docker.ContainerRemove(ctx, res.ID, types.ContainerRemoveOptions{}) + stream.Close() + } + outputErr := make(chan error) + go func() { + var err error + if cfg.Tty { + _, err = io.Copy(sess, stream.Conn) + } else { + _, err = stdcopy.StdCopy(sess, sess.Stderr(), stream.Reader) + } + outputErr <- err + }() + go func() { + defer stream.CloseWrite() + io.Copy(stream.Conn, sess) + }() + err = docker.ContainerStart(ctx, res.ID, types.ContainerStartOptions{}) + if err != nil { + return + } + if cfg.Tty { + _, winCh, _ := sess.Pty() + go func() { + for win := range winCh { + err := docker.ContainerResize(ctx, res.ID, types.ResizeOptions{ + Height: uint(win.Height), + Width: uint(win.Width), + }) + if err != nil { + log.Println(err) + break + } + } + }() + } + status, err = docker.ContainerWait(ctx, res.ID) + if err != nil { + return + } + err = <-outputErr + return +} diff --git a/_example/pty/pty.go b/_example/pty/pty.go new file mode 100644 index 0000000..7967deb --- /dev/null +++ b/_example/pty/pty.go @@ -0,0 +1,48 @@ +package main + +import ( + "fmt" + "io" + "log" + "os" + "os/exec" + "syscall" + "unsafe" + + "github.com/gliderlabs/ssh" + "github.com/kr/pty" +) + +func setWinsize(f *os.File, w, h int) { + syscall.Syscall(syscall.SYS_IOCTL, f.Fd(), uintptr(syscall.TIOCSWINSZ), + uintptr(unsafe.Pointer(&struct{ h, w, x, y uint16 }{uint16(h), uint16(w), 0, 0}))) +} + +func main() { + ssh.Handle(func(s ssh.Session) { + cmd := exec.Command("top") + ptyReq, winCh, isPty := s.Pty() + if isPty { + cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term)) + f, err := pty.Start(cmd) + if err != nil { + panic(err) + } + go func() { + for win := range winCh { + setWinsize(f, win.Width, win.Height) + } + }() + go func() { + io.Copy(f, s) // stdin + }() + io.Copy(s, f) // stdout + } else { + io.WriteString(s, "No PTY requested.\n") + s.Exit(1) + } + }) + + log.Println("starting ssh server on port 2222...") + log.Fatal(ssh.ListenAndServe(":2222", nil)) +} diff --git a/circle.yml b/circle.yml new file mode 100644 index 0000000..4a2bf3b --- /dev/null +++ b/circle.yml @@ -0,0 +1,3 @@ +test: + override: + - go test -v -race diff --git a/session.go b/session.go index 694c23a..65c8b6a 100644 --- a/session.go +++ b/session.go @@ -126,8 +126,6 @@ func (sess *session) Pty() (Pty, <-chan Window, bool) { func (sess *session) handleRequests(reqs <-chan *gossh.Request) { for req := range reqs { - var width, height int - var ok bool switch req.Type { case "shell", "exec": if sess.handled { @@ -152,8 +150,9 @@ func (sess *session) handleRequests(reqs <-chan *gossh.Request) { var kv = struct{ Key, Value string }{} gossh.Unmarshal(req.Payload, &kv) sess.env = append(sess.env, fmt.Sprintf("%s=%s", kv.Key, kv.Value)) + req.Reply(true, nil) case "pty-req": - if sess.handled { + if sess.handled || sess.pty != nil { req.Reply(false, nil) continue } @@ -164,23 +163,27 @@ func (sess *session) handleRequests(reqs <-chan *gossh.Request) { continue } } - width, height, ok = parsePtyRequest(req.Payload) + ptyReq, ok := parsePtyRequest(req.Payload) if ok { - sess.pty = &Pty{Window{width, height}} - sess.winch = make(chan Window) + sess.pty = &ptyReq + sess.winch = make(chan Window, 1) + sess.winch <- ptyReq.Window + defer func() { + close(sess.winch) + }() } - req.Reply(ok, nil) case "window-change": if sess.pty == nil { req.Reply(false, nil) continue } - width, height, ok = parseWinchRequest(req.Payload) + win, ok := parseWinchRequest(req.Payload) if ok { - sess.pty.Window = Window{width, height} - sess.winch <- sess.pty.Window + sess.pty.Window = win + sess.winch <- win } + req.Reply(ok, nil) } } } diff --git a/session_test.go b/session_test.go new file mode 100644 index 0000000..83705b1 --- /dev/null +++ b/session_test.go @@ -0,0 +1,271 @@ +package ssh + +import ( + "bytes" + "fmt" + "io" + "net" + "testing" + + gossh "golang.org/x/crypto/ssh" +) + +func (srv *Server) serveOnce(l net.Listener) error { + config, err := srv.makeConfig() + if err != nil { + return err + } + conn, e := l.Accept() + if e != nil { + return e + } + srv.handleConn(conn, config) + return nil +} + +func newLocalListener() net.Listener { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + if l, err = net.Listen("tcp6", "[::1]:0"); err != nil { + panic(fmt.Sprintf("failed to listen on a port: %v", err)) + } + } + return l +} + +func newClientSession(t *testing.T, addr string, config *gossh.ClientConfig) (*gossh.Session, func()) { + if config == nil { + config = &gossh.ClientConfig{ + User: "testuser", + Auth: []gossh.AuthMethod{ + gossh.Password("testpass"), + }, + } + } + client, err := gossh.Dial("tcp", addr, config) + if err != nil { + t.Fatal(err) + } + session, err := client.NewSession() + if err != nil { + t.Fatal(err) + } + return session, func() { + session.Close() + client.Close() + } +} + +func newTestSession(t *testing.T, srv *Server, cfg *gossh.ClientConfig) (*gossh.Session, func()) { + l := newLocalListener() + go srv.serveOnce(l) + return newClientSession(t, l.Addr().String(), cfg) +} + +func TestStdout(t *testing.T) { + testBytes := []byte("Hello world\n") + session, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + s.Write(testBytes) + }, + }, nil) + defer cleanup() + var stdout bytes.Buffer + session.Stdout = &stdout + if err := session.Run(""); err != nil { + t.Fatal(err) + } + if !bytes.Equal(stdout.Bytes(), testBytes) { + t.Fatalf("stdout = %#v; want %#v", stdout.Bytes(), testBytes) + } +} + +func TestStderr(t *testing.T) { + testBytes := []byte("Hello world\n") + session, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + s.Stderr().Write(testBytes) + }, + }, nil) + defer cleanup() + var stderr bytes.Buffer + session.Stderr = &stderr + if err := session.Run(""); err != nil { + t.Fatal(err) + } + if !bytes.Equal(stderr.Bytes(), testBytes) { + t.Fatalf("stderr = %#v; want %#v", stderr.Bytes(), testBytes) + } +} + +func TestStdin(t *testing.T) { + testBytes := []byte("Hello world\n") + session, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + io.Copy(s, s) // stdin back into stdout + }, + }, nil) + defer cleanup() + var stdout bytes.Buffer + session.Stdout = &stdout + session.Stdin = bytes.NewBuffer(testBytes) + if err := session.Run(""); err != nil { + t.Fatal(err) + } + if !bytes.Equal(stdout.Bytes(), testBytes) { + t.Fatalf("stdout = %#v; want %#v given stdin = %#v", stdout.Bytes(), testBytes, testBytes) + } +} + +func TestUser(t *testing.T) { + testUser := []byte("progrium") + session, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + io.WriteString(s, s.User()) + }, + }, &gossh.ClientConfig{ + User: string(testUser), + }) + defer cleanup() + var stdout bytes.Buffer + session.Stdout = &stdout + if err := session.Run(""); err != nil { + t.Fatal(err) + } + if !bytes.Equal(stdout.Bytes(), testUser) { + t.Fatalf("stdout = %#v; want %#v given user = %#v", stdout.Bytes(), testUser, string(testUser)) + } +} + +func TestDefaultExitStatusZero(t *testing.T) { + session, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + // noop + }, + }, nil) + defer cleanup() + err := session.Run("") + if err != nil { + t.Fatalf("expected nil but got %v", err) + } +} + +func TestExplicitExitStatusZero(t *testing.T) { + session, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + s.Exit(0) + }, + }, nil) + defer cleanup() + err := session.Run("") + if err != nil { + t.Fatalf("expected nil but got %v", err) + } +} + +func TestExitStatusNonZero(t *testing.T) { + session, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + s.Exit(1) + }, + }, nil) + defer cleanup() + err := session.Run("") + e, ok := err.(*gossh.ExitError) + if !ok { + t.Fatalf("expected ExitError but got %T", err) + } + if e.ExitStatus() != 1 { + t.Fatalf("exit-status = %#v; want %#v", e.ExitStatus(), 1) + } +} + +func TestPty(t *testing.T) { + term := "xterm" + winWidth := 40 + winHeight := 80 + done := make(chan bool) + session, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + ptyReq, _, isPty := s.Pty() + if !isPty { + t.Fatalf("expected pty but none requested") + } + if ptyReq.Term != term { + t.Fatalf("expected term %#v but got %#v", term, ptyReq.Term) + } + if ptyReq.Window.Width != winWidth { + t.Fatalf("expected window width %#v but got %#v", winWidth, ptyReq.Window.Width) + } + if ptyReq.Window.Height != winHeight { + t.Fatalf("expected window height %#v but got %#v", winHeight, ptyReq.Window.Height) + } + close(done) + }, + }, nil) + defer cleanup() + if err := session.RequestPty(term, winHeight, winWidth, gossh.TerminalModes{}); err != nil { + t.Fatalf("unexpected error requesting PTY", err) + } + if err := session.Shell(); err != nil { + t.Fatalf("expected nil but got %v", err) + } + <-done +} + +func TestPtyResize(t *testing.T) { + winch0 := Window{40, 80} + winch1 := Window{80, 160} + winch2 := Window{20, 40} + winches := make(chan Window) + done := make(chan bool) + session, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + ptyReq, winCh, isPty := s.Pty() + if !isPty { + t.Fatalf("expected pty but none requested") + } + if ptyReq.Window != winch0 { + t.Fatalf("expected window %#v but got %#v", winch0, ptyReq.Window) + } + for win := range winCh { + winches <- win + } + close(done) + }, + }, nil) + defer cleanup() + // winch0 + if err := session.RequestPty("xterm", winch0.Height, winch0.Width, gossh.TerminalModes{}); err != nil { + t.Fatalf("unexpected error requesting PTY", err) + } + if err := session.Shell(); err != nil { + t.Fatalf("expected nil but got %v", err) + } + gotWinch := <-winches + if gotWinch != winch0 { + t.Fatalf("expected window %#v but got %#v", winch0, gotWinch) + } + // winch1 + winchMsg := struct{ w, h uint32 }{uint32(winch1.Width), uint32(winch1.Height)} + ok, err := session.SendRequest("window-change", true, gossh.Marshal(&winchMsg)) + if err == nil && !ok { + t.Fatalf("unexpected error or bad reply on send request") + } + gotWinch = <-winches + if gotWinch != winch1 { + t.Fatalf("expected window %#v but got %#v", winch1, gotWinch) + } + // winch2 + winchMsg = struct{ w, h uint32 }{uint32(winch2.Width), uint32(winch2.Height)} + ok, err = session.SendRequest("window-change", true, gossh.Marshal(&winchMsg)) + if err == nil && !ok { + t.Fatalf("unexpected error or bad reply on send request") + } + gotWinch = <-winches + if gotWinch != winch2 { + t.Fatalf("expected window %#v but got %#v", winch2, gotWinch) + } + session.Close() + <-done +} diff --git a/ssh.go b/ssh.go index acaf297..0f5f43d 100644 --- a/ssh.go +++ b/ssh.go @@ -5,7 +5,6 @@ import ( "net" ) -// Signal as in RFC 4254 Section 6.10. type Signal string // POSIX signals as listed in RFC 4254 Section 6.10. @@ -52,9 +51,11 @@ type Window struct { Height int } -// Pty represents PTY configuration. +// Pty represents a PTY request and configuration. type Pty struct { + Term string Window Window + // HELP WANTED: terminal modes! } // Serve accepts incoming SSH connections on the listener l, creating a new diff --git a/util.go b/util.go index b713b53..ca775f5 100644 --- a/util.go +++ b/util.go @@ -54,44 +54,53 @@ func generateSigner() (ssh.Signer, error) { return ssh.NewSignerFromKey(key) } -func parsePtyRequest(s []byte) (width, height int, ok bool) { - _, s, ok = parseString(s) +func parsePtyRequest(s []byte) (pty Pty, ok bool) { + term, s, ok := parseString(s) if !ok { return } width32, s, ok := parseUint32(s) + if width32 < 1 { + ok = false + } if !ok { return } height32, _, ok := parseUint32(s) - width = int(width32) - height = int(height32) - if width < 1 { + if height32 < 1 { ok = false } - if height < 1 { - ok = false + if !ok { + return + } + pty = Pty{ + Term: term, + Window: Window{ + Width: int(width32), + Height: int(height32), + }, } return } -func parseWinchRequest(s []byte) (width, height int, ok bool) { - width32, _, ok := parseUint32(s) +func parseWinchRequest(s []byte) (win Window, ok bool) { + width32, s, ok := parseUint32(s) + if width32 < 1 { + ok = false + } if !ok { return } height32, _, ok := parseUint32(s) + if height32 < 1 { + ok = false + } if !ok { return } - - width = int(width32) - height = int(height32) - if width < 1 { - ok = false - } - if height < 1 { - ok = false + win = Window{ + Width: int(width32), + Height: int(height32), } return }