fixed/finished basic pty support, added an example, and included tests (#25)

* fixed/finished basic pty support, added an example, and included tests
* session: make the window channel have buffer of 1 and send initial window size on it
* _example/docker: added an ssh to docker-run example
* changes from review: let Reply handle WantReply, only allow setting sess.pty once
* circle: hopefully a working circleci config
This commit is contained in:
Jeff Lindsay 2017-02-15 18:08:25 -06:00 committed by GitHub
parent a307f226ad
commit a2a474964c
7 changed files with 476 additions and 29 deletions

112
_example/docker/docker.go Normal file

@ -0,0 +1,112 @@
package main
import (
"context"
"fmt"
"io"
"log"
"github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/client"
"github.com/docker/docker/pkg/stdcopy"
"github.com/gliderlabs/ssh"
)
func main() {
ssh.Handle(func(sess ssh.Session) {
_, _, isTty := sess.Pty()
cfg := &container.Config{
Image: sess.User(),
Cmd: sess.Command(),
Env: sess.Environ(),
Tty: isTty,
OpenStdin: true,
AttachStderr: true,
AttachStdin: true,
AttachStdout: true,
StdinOnce: true,
Volumes: make(map[string]struct{}),
}
err, status, cleanup := dockerRun(cfg, sess)
defer cleanup()
if err != nil {
fmt.Fprintln(sess, err)
log.Println(err)
}
sess.Exit(int(status))
})
log.Println("starting ssh server on port 2222...")
log.Fatal(ssh.ListenAndServe(":2222", nil))
}
func dockerRun(cfg *container.Config, sess ssh.Session) (err error, status int64, cleanup func()) {
docker, err := client.NewEnvClient()
if err != nil {
panic(err)
}
status = 255
cleanup = func() {}
ctx := context.Background()
res, err := docker.ContainerCreate(ctx, cfg, nil, nil, "")
if err != nil {
return
}
cleanup = func() {
docker.ContainerRemove(ctx, res.ID, types.ContainerRemoveOptions{})
}
opts := types.ContainerAttachOptions{
Stdin: cfg.AttachStdin,
Stdout: cfg.AttachStdout,
Stderr: cfg.AttachStderr,
Stream: true,
}
stream, err := docker.ContainerAttach(ctx, res.ID, opts)
if err != nil {
return
}
cleanup = func() {
docker.ContainerRemove(ctx, res.ID, types.ContainerRemoveOptions{})
stream.Close()
}
outputErr := make(chan error)
go func() {
var err error
if cfg.Tty {
_, err = io.Copy(sess, stream.Conn)
} else {
_, err = stdcopy.StdCopy(sess, sess.Stderr(), stream.Reader)
}
outputErr <- err
}()
go func() {
defer stream.CloseWrite()
io.Copy(stream.Conn, sess)
}()
err = docker.ContainerStart(ctx, res.ID, types.ContainerStartOptions{})
if err != nil {
return
}
if cfg.Tty {
_, winCh, _ := sess.Pty()
go func() {
for win := range winCh {
err := docker.ContainerResize(ctx, res.ID, types.ResizeOptions{
Height: uint(win.Height),
Width: uint(win.Width),
})
if err != nil {
log.Println(err)
break
}
}
}()
}
status, err = docker.ContainerWait(ctx, res.ID)
if err != nil {
return
}
err = <-outputErr
return
}

48
_example/pty/pty.go Normal file

@ -0,0 +1,48 @@
package main
import (
"fmt"
"io"
"log"
"os"
"os/exec"
"syscall"
"unsafe"
"github.com/gliderlabs/ssh"
"github.com/kr/pty"
)
func setWinsize(f *os.File, w, h int) {
syscall.Syscall(syscall.SYS_IOCTL, f.Fd(), uintptr(syscall.TIOCSWINSZ),
uintptr(unsafe.Pointer(&struct{ h, w, x, y uint16 }{uint16(h), uint16(w), 0, 0})))
}
func main() {
ssh.Handle(func(s ssh.Session) {
cmd := exec.Command("top")
ptyReq, winCh, isPty := s.Pty()
if isPty {
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
f, err := pty.Start(cmd)
if err != nil {
panic(err)
}
go func() {
for win := range winCh {
setWinsize(f, win.Width, win.Height)
}
}()
go func() {
io.Copy(f, s) // stdin
}()
io.Copy(s, f) // stdout
} else {
io.WriteString(s, "No PTY requested.\n")
s.Exit(1)
}
})
log.Println("starting ssh server on port 2222...")
log.Fatal(ssh.ListenAndServe(":2222", nil))
}

3
circle.yml Normal file

@ -0,0 +1,3 @@
test:
override:
- go test -v -race

@ -126,8 +126,6 @@ func (sess *session) Pty() (Pty, <-chan Window, bool) {
func (sess *session) handleRequests(reqs <-chan *gossh.Request) {
for req := range reqs {
var width, height int
var ok bool
switch req.Type {
case "shell", "exec":
if sess.handled {
@ -152,8 +150,9 @@ func (sess *session) handleRequests(reqs <-chan *gossh.Request) {
var kv = struct{ Key, Value string }{}
gossh.Unmarshal(req.Payload, &kv)
sess.env = append(sess.env, fmt.Sprintf("%s=%s", kv.Key, kv.Value))
req.Reply(true, nil)
case "pty-req":
if sess.handled {
if sess.handled || sess.pty != nil {
req.Reply(false, nil)
continue
}
@ -164,23 +163,27 @@ func (sess *session) handleRequests(reqs <-chan *gossh.Request) {
continue
}
}
width, height, ok = parsePtyRequest(req.Payload)
ptyReq, ok := parsePtyRequest(req.Payload)
if ok {
sess.pty = &Pty{Window{width, height}}
sess.winch = make(chan Window)
sess.pty = &ptyReq
sess.winch = make(chan Window, 1)
sess.winch <- ptyReq.Window
defer func() {
close(sess.winch)
}()
}
req.Reply(ok, nil)
case "window-change":
if sess.pty == nil {
req.Reply(false, nil)
continue
}
width, height, ok = parseWinchRequest(req.Payload)
win, ok := parseWinchRequest(req.Payload)
if ok {
sess.pty.Window = Window{width, height}
sess.winch <- sess.pty.Window
sess.pty.Window = win
sess.winch <- win
}
req.Reply(ok, nil)
}
}
}

271
session_test.go Normal file

@ -0,0 +1,271 @@
package ssh
import (
"bytes"
"fmt"
"io"
"net"
"testing"
gossh "golang.org/x/crypto/ssh"
)
func (srv *Server) serveOnce(l net.Listener) error {
config, err := srv.makeConfig()
if err != nil {
return err
}
conn, e := l.Accept()
if e != nil {
return e
}
srv.handleConn(conn, config)
return nil
}
func newLocalListener() net.Listener {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
if l, err = net.Listen("tcp6", "[::1]:0"); err != nil {
panic(fmt.Sprintf("failed to listen on a port: %v", err))
}
}
return l
}
func newClientSession(t *testing.T, addr string, config *gossh.ClientConfig) (*gossh.Session, func()) {
if config == nil {
config = &gossh.ClientConfig{
User: "testuser",
Auth: []gossh.AuthMethod{
gossh.Password("testpass"),
},
}
}
client, err := gossh.Dial("tcp", addr, config)
if err != nil {
t.Fatal(err)
}
session, err := client.NewSession()
if err != nil {
t.Fatal(err)
}
return session, func() {
session.Close()
client.Close()
}
}
func newTestSession(t *testing.T, srv *Server, cfg *gossh.ClientConfig) (*gossh.Session, func()) {
l := newLocalListener()
go srv.serveOnce(l)
return newClientSession(t, l.Addr().String(), cfg)
}
func TestStdout(t *testing.T) {
testBytes := []byte("Hello world\n")
session, cleanup := newTestSession(t, &Server{
Handler: func(s Session) {
s.Write(testBytes)
},
}, nil)
defer cleanup()
var stdout bytes.Buffer
session.Stdout = &stdout
if err := session.Run(""); err != nil {
t.Fatal(err)
}
if !bytes.Equal(stdout.Bytes(), testBytes) {
t.Fatalf("stdout = %#v; want %#v", stdout.Bytes(), testBytes)
}
}
func TestStderr(t *testing.T) {
testBytes := []byte("Hello world\n")
session, cleanup := newTestSession(t, &Server{
Handler: func(s Session) {
s.Stderr().Write(testBytes)
},
}, nil)
defer cleanup()
var stderr bytes.Buffer
session.Stderr = &stderr
if err := session.Run(""); err != nil {
t.Fatal(err)
}
if !bytes.Equal(stderr.Bytes(), testBytes) {
t.Fatalf("stderr = %#v; want %#v", stderr.Bytes(), testBytes)
}
}
func TestStdin(t *testing.T) {
testBytes := []byte("Hello world\n")
session, cleanup := newTestSession(t, &Server{
Handler: func(s Session) {
io.Copy(s, s) // stdin back into stdout
},
}, nil)
defer cleanup()
var stdout bytes.Buffer
session.Stdout = &stdout
session.Stdin = bytes.NewBuffer(testBytes)
if err := session.Run(""); err != nil {
t.Fatal(err)
}
if !bytes.Equal(stdout.Bytes(), testBytes) {
t.Fatalf("stdout = %#v; want %#v given stdin = %#v", stdout.Bytes(), testBytes, testBytes)
}
}
func TestUser(t *testing.T) {
testUser := []byte("progrium")
session, cleanup := newTestSession(t, &Server{
Handler: func(s Session) {
io.WriteString(s, s.User())
},
}, &gossh.ClientConfig{
User: string(testUser),
})
defer cleanup()
var stdout bytes.Buffer
session.Stdout = &stdout
if err := session.Run(""); err != nil {
t.Fatal(err)
}
if !bytes.Equal(stdout.Bytes(), testUser) {
t.Fatalf("stdout = %#v; want %#v given user = %#v", stdout.Bytes(), testUser, string(testUser))
}
}
func TestDefaultExitStatusZero(t *testing.T) {
session, cleanup := newTestSession(t, &Server{
Handler: func(s Session) {
// noop
},
}, nil)
defer cleanup()
err := session.Run("")
if err != nil {
t.Fatalf("expected nil but got %v", err)
}
}
func TestExplicitExitStatusZero(t *testing.T) {
session, cleanup := newTestSession(t, &Server{
Handler: func(s Session) {
s.Exit(0)
},
}, nil)
defer cleanup()
err := session.Run("")
if err != nil {
t.Fatalf("expected nil but got %v", err)
}
}
func TestExitStatusNonZero(t *testing.T) {
session, cleanup := newTestSession(t, &Server{
Handler: func(s Session) {
s.Exit(1)
},
}, nil)
defer cleanup()
err := session.Run("")
e, ok := err.(*gossh.ExitError)
if !ok {
t.Fatalf("expected ExitError but got %T", err)
}
if e.ExitStatus() != 1 {
t.Fatalf("exit-status = %#v; want %#v", e.ExitStatus(), 1)
}
}
func TestPty(t *testing.T) {
term := "xterm"
winWidth := 40
winHeight := 80
done := make(chan bool)
session, cleanup := newTestSession(t, &Server{
Handler: func(s Session) {
ptyReq, _, isPty := s.Pty()
if !isPty {
t.Fatalf("expected pty but none requested")
}
if ptyReq.Term != term {
t.Fatalf("expected term %#v but got %#v", term, ptyReq.Term)
}
if ptyReq.Window.Width != winWidth {
t.Fatalf("expected window width %#v but got %#v", winWidth, ptyReq.Window.Width)
}
if ptyReq.Window.Height != winHeight {
t.Fatalf("expected window height %#v but got %#v", winHeight, ptyReq.Window.Height)
}
close(done)
},
}, nil)
defer cleanup()
if err := session.RequestPty(term, winHeight, winWidth, gossh.TerminalModes{}); err != nil {
t.Fatalf("unexpected error requesting PTY", err)
}
if err := session.Shell(); err != nil {
t.Fatalf("expected nil but got %v", err)
}
<-done
}
func TestPtyResize(t *testing.T) {
winch0 := Window{40, 80}
winch1 := Window{80, 160}
winch2 := Window{20, 40}
winches := make(chan Window)
done := make(chan bool)
session, cleanup := newTestSession(t, &Server{
Handler: func(s Session) {
ptyReq, winCh, isPty := s.Pty()
if !isPty {
t.Fatalf("expected pty but none requested")
}
if ptyReq.Window != winch0 {
t.Fatalf("expected window %#v but got %#v", winch0, ptyReq.Window)
}
for win := range winCh {
winches <- win
}
close(done)
},
}, nil)
defer cleanup()
// winch0
if err := session.RequestPty("xterm", winch0.Height, winch0.Width, gossh.TerminalModes{}); err != nil {
t.Fatalf("unexpected error requesting PTY", err)
}
if err := session.Shell(); err != nil {
t.Fatalf("expected nil but got %v", err)
}
gotWinch := <-winches
if gotWinch != winch0 {
t.Fatalf("expected window %#v but got %#v", winch0, gotWinch)
}
// winch1
winchMsg := struct{ w, h uint32 }{uint32(winch1.Width), uint32(winch1.Height)}
ok, err := session.SendRequest("window-change", true, gossh.Marshal(&winchMsg))
if err == nil && !ok {
t.Fatalf("unexpected error or bad reply on send request")
}
gotWinch = <-winches
if gotWinch != winch1 {
t.Fatalf("expected window %#v but got %#v", winch1, gotWinch)
}
// winch2
winchMsg = struct{ w, h uint32 }{uint32(winch2.Width), uint32(winch2.Height)}
ok, err = session.SendRequest("window-change", true, gossh.Marshal(&winchMsg))
if err == nil && !ok {
t.Fatalf("unexpected error or bad reply on send request")
}
gotWinch = <-winches
if gotWinch != winch2 {
t.Fatalf("expected window %#v but got %#v", winch2, gotWinch)
}
session.Close()
<-done
}

5
ssh.go

@ -5,7 +5,6 @@ import (
"net"
)
// Signal as in RFC 4254 Section 6.10.
type Signal string
// POSIX signals as listed in RFC 4254 Section 6.10.
@ -52,9 +51,11 @@ type Window struct {
Height int
}
// Pty represents PTY configuration.
// Pty represents a PTY request and configuration.
type Pty struct {
Term string
Window Window
// HELP WANTED: terminal modes!
}
// Serve accepts incoming SSH connections on the listener l, creating a new

43
util.go

@ -54,44 +54,53 @@ func generateSigner() (ssh.Signer, error) {
return ssh.NewSignerFromKey(key)
}
func parsePtyRequest(s []byte) (width, height int, ok bool) {
_, s, ok = parseString(s)
func parsePtyRequest(s []byte) (pty Pty, ok bool) {
term, s, ok := parseString(s)
if !ok {
return
}
width32, s, ok := parseUint32(s)
if width32 < 1 {
ok = false
}
if !ok {
return
}
height32, _, ok := parseUint32(s)
width = int(width32)
height = int(height32)
if width < 1 {
if height32 < 1 {
ok = false
}
if height < 1 {
ok = false
if !ok {
return
}
pty = Pty{
Term: term,
Window: Window{
Width: int(width32),
Height: int(height32),
},
}
return
}
func parseWinchRequest(s []byte) (width, height int, ok bool) {
width32, _, ok := parseUint32(s)
func parseWinchRequest(s []byte) (win Window, ok bool) {
width32, s, ok := parseUint32(s)
if width32 < 1 {
ok = false
}
if !ok {
return
}
height32, _, ok := parseUint32(s)
if height32 < 1 {
ok = false
}
if !ok {
return
}
width = int(width32)
height = int(height32)
if width < 1 {
ok = false
}
if height < 1 {
ok = false
win = Window{
Width: int(width32),
Height: int(height32),
}
return
}