fixed/finished basic pty support, added an example, and included tests (#25)
* fixed/finished basic pty support, added an example, and included tests * session: make the window channel have buffer of 1 and send initial window size on it * _example/docker: added an ssh to docker-run example * changes from review: let Reply handle WantReply, only allow setting sess.pty once * circle: hopefully a working circleci config
This commit is contained in:
parent
a307f226ad
commit
a2a474964c
112
_example/docker/docker.go
Normal file
112
_example/docker/docker.go
Normal file
@ -0,0 +1,112 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
|
||||
"github.com/docker/docker/api/types"
|
||||
"github.com/docker/docker/api/types/container"
|
||||
"github.com/docker/docker/client"
|
||||
"github.com/docker/docker/pkg/stdcopy"
|
||||
"github.com/gliderlabs/ssh"
|
||||
)
|
||||
|
||||
func main() {
|
||||
ssh.Handle(func(sess ssh.Session) {
|
||||
_, _, isTty := sess.Pty()
|
||||
cfg := &container.Config{
|
||||
Image: sess.User(),
|
||||
Cmd: sess.Command(),
|
||||
Env: sess.Environ(),
|
||||
Tty: isTty,
|
||||
OpenStdin: true,
|
||||
AttachStderr: true,
|
||||
AttachStdin: true,
|
||||
AttachStdout: true,
|
||||
StdinOnce: true,
|
||||
Volumes: make(map[string]struct{}),
|
||||
}
|
||||
err, status, cleanup := dockerRun(cfg, sess)
|
||||
defer cleanup()
|
||||
if err != nil {
|
||||
fmt.Fprintln(sess, err)
|
||||
log.Println(err)
|
||||
}
|
||||
sess.Exit(int(status))
|
||||
})
|
||||
|
||||
log.Println("starting ssh server on port 2222...")
|
||||
log.Fatal(ssh.ListenAndServe(":2222", nil))
|
||||
}
|
||||
|
||||
func dockerRun(cfg *container.Config, sess ssh.Session) (err error, status int64, cleanup func()) {
|
||||
docker, err := client.NewEnvClient()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
status = 255
|
||||
cleanup = func() {}
|
||||
ctx := context.Background()
|
||||
res, err := docker.ContainerCreate(ctx, cfg, nil, nil, "")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
cleanup = func() {
|
||||
docker.ContainerRemove(ctx, res.ID, types.ContainerRemoveOptions{})
|
||||
}
|
||||
opts := types.ContainerAttachOptions{
|
||||
Stdin: cfg.AttachStdin,
|
||||
Stdout: cfg.AttachStdout,
|
||||
Stderr: cfg.AttachStderr,
|
||||
Stream: true,
|
||||
}
|
||||
stream, err := docker.ContainerAttach(ctx, res.ID, opts)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
cleanup = func() {
|
||||
docker.ContainerRemove(ctx, res.ID, types.ContainerRemoveOptions{})
|
||||
stream.Close()
|
||||
}
|
||||
outputErr := make(chan error)
|
||||
go func() {
|
||||
var err error
|
||||
if cfg.Tty {
|
||||
_, err = io.Copy(sess, stream.Conn)
|
||||
} else {
|
||||
_, err = stdcopy.StdCopy(sess, sess.Stderr(), stream.Reader)
|
||||
}
|
||||
outputErr <- err
|
||||
}()
|
||||
go func() {
|
||||
defer stream.CloseWrite()
|
||||
io.Copy(stream.Conn, sess)
|
||||
}()
|
||||
err = docker.ContainerStart(ctx, res.ID, types.ContainerStartOptions{})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if cfg.Tty {
|
||||
_, winCh, _ := sess.Pty()
|
||||
go func() {
|
||||
for win := range winCh {
|
||||
err := docker.ContainerResize(ctx, res.ID, types.ResizeOptions{
|
||||
Height: uint(win.Height),
|
||||
Width: uint(win.Width),
|
||||
})
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
status, err = docker.ContainerWait(ctx, res.ID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = <-outputErr
|
||||
return
|
||||
}
|
48
_example/pty/pty.go
Normal file
48
_example/pty/pty.go
Normal file
@ -0,0 +1,48 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/kr/pty"
|
||||
)
|
||||
|
||||
func setWinsize(f *os.File, w, h int) {
|
||||
syscall.Syscall(syscall.SYS_IOCTL, f.Fd(), uintptr(syscall.TIOCSWINSZ),
|
||||
uintptr(unsafe.Pointer(&struct{ h, w, x, y uint16 }{uint16(h), uint16(w), 0, 0})))
|
||||
}
|
||||
|
||||
func main() {
|
||||
ssh.Handle(func(s ssh.Session) {
|
||||
cmd := exec.Command("top")
|
||||
ptyReq, winCh, isPty := s.Pty()
|
||||
if isPty {
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
|
||||
f, err := pty.Start(cmd)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
go func() {
|
||||
for win := range winCh {
|
||||
setWinsize(f, win.Width, win.Height)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
io.Copy(f, s) // stdin
|
||||
}()
|
||||
io.Copy(s, f) // stdout
|
||||
} else {
|
||||
io.WriteString(s, "No PTY requested.\n")
|
||||
s.Exit(1)
|
||||
}
|
||||
})
|
||||
|
||||
log.Println("starting ssh server on port 2222...")
|
||||
log.Fatal(ssh.ListenAndServe(":2222", nil))
|
||||
}
|
3
circle.yml
Normal file
3
circle.yml
Normal file
@ -0,0 +1,3 @@
|
||||
test:
|
||||
override:
|
||||
- go test -v -race
|
23
session.go
23
session.go
@ -126,8 +126,6 @@ func (sess *session) Pty() (Pty, <-chan Window, bool) {
|
||||
|
||||
func (sess *session) handleRequests(reqs <-chan *gossh.Request) {
|
||||
for req := range reqs {
|
||||
var width, height int
|
||||
var ok bool
|
||||
switch req.Type {
|
||||
case "shell", "exec":
|
||||
if sess.handled {
|
||||
@ -152,8 +150,9 @@ func (sess *session) handleRequests(reqs <-chan *gossh.Request) {
|
||||
var kv = struct{ Key, Value string }{}
|
||||
gossh.Unmarshal(req.Payload, &kv)
|
||||
sess.env = append(sess.env, fmt.Sprintf("%s=%s", kv.Key, kv.Value))
|
||||
req.Reply(true, nil)
|
||||
case "pty-req":
|
||||
if sess.handled {
|
||||
if sess.handled || sess.pty != nil {
|
||||
req.Reply(false, nil)
|
||||
continue
|
||||
}
|
||||
@ -164,23 +163,27 @@ func (sess *session) handleRequests(reqs <-chan *gossh.Request) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
width, height, ok = parsePtyRequest(req.Payload)
|
||||
ptyReq, ok := parsePtyRequest(req.Payload)
|
||||
if ok {
|
||||
sess.pty = &Pty{Window{width, height}}
|
||||
sess.winch = make(chan Window)
|
||||
sess.pty = &ptyReq
|
||||
sess.winch = make(chan Window, 1)
|
||||
sess.winch <- ptyReq.Window
|
||||
defer func() {
|
||||
close(sess.winch)
|
||||
}()
|
||||
}
|
||||
|
||||
req.Reply(ok, nil)
|
||||
case "window-change":
|
||||
if sess.pty == nil {
|
||||
req.Reply(false, nil)
|
||||
continue
|
||||
}
|
||||
width, height, ok = parseWinchRequest(req.Payload)
|
||||
win, ok := parseWinchRequest(req.Payload)
|
||||
if ok {
|
||||
sess.pty.Window = Window{width, height}
|
||||
sess.winch <- sess.pty.Window
|
||||
sess.pty.Window = win
|
||||
sess.winch <- win
|
||||
}
|
||||
req.Reply(ok, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
271
session_test.go
Normal file
271
session_test.go
Normal file
@ -0,0 +1,271 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func (srv *Server) serveOnce(l net.Listener) error {
|
||||
config, err := srv.makeConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conn, e := l.Accept()
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
srv.handleConn(conn, config)
|
||||
return nil
|
||||
}
|
||||
|
||||
func newLocalListener() net.Listener {
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
if l, err = net.Listen("tcp6", "[::1]:0"); err != nil {
|
||||
panic(fmt.Sprintf("failed to listen on a port: %v", err))
|
||||
}
|
||||
}
|
||||
return l
|
||||
}
|
||||
|
||||
func newClientSession(t *testing.T, addr string, config *gossh.ClientConfig) (*gossh.Session, func()) {
|
||||
if config == nil {
|
||||
config = &gossh.ClientConfig{
|
||||
User: "testuser",
|
||||
Auth: []gossh.AuthMethod{
|
||||
gossh.Password("testpass"),
|
||||
},
|
||||
}
|
||||
}
|
||||
client, err := gossh.Dial("tcp", addr, config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
session, err := client.NewSession()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return session, func() {
|
||||
session.Close()
|
||||
client.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func newTestSession(t *testing.T, srv *Server, cfg *gossh.ClientConfig) (*gossh.Session, func()) {
|
||||
l := newLocalListener()
|
||||
go srv.serveOnce(l)
|
||||
return newClientSession(t, l.Addr().String(), cfg)
|
||||
}
|
||||
|
||||
func TestStdout(t *testing.T) {
|
||||
testBytes := []byte("Hello world\n")
|
||||
session, cleanup := newTestSession(t, &Server{
|
||||
Handler: func(s Session) {
|
||||
s.Write(testBytes)
|
||||
},
|
||||
}, nil)
|
||||
defer cleanup()
|
||||
var stdout bytes.Buffer
|
||||
session.Stdout = &stdout
|
||||
if err := session.Run(""); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(stdout.Bytes(), testBytes) {
|
||||
t.Fatalf("stdout = %#v; want %#v", stdout.Bytes(), testBytes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStderr(t *testing.T) {
|
||||
testBytes := []byte("Hello world\n")
|
||||
session, cleanup := newTestSession(t, &Server{
|
||||
Handler: func(s Session) {
|
||||
s.Stderr().Write(testBytes)
|
||||
},
|
||||
}, nil)
|
||||
defer cleanup()
|
||||
var stderr bytes.Buffer
|
||||
session.Stderr = &stderr
|
||||
if err := session.Run(""); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(stderr.Bytes(), testBytes) {
|
||||
t.Fatalf("stderr = %#v; want %#v", stderr.Bytes(), testBytes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStdin(t *testing.T) {
|
||||
testBytes := []byte("Hello world\n")
|
||||
session, cleanup := newTestSession(t, &Server{
|
||||
Handler: func(s Session) {
|
||||
io.Copy(s, s) // stdin back into stdout
|
||||
},
|
||||
}, nil)
|
||||
defer cleanup()
|
||||
var stdout bytes.Buffer
|
||||
session.Stdout = &stdout
|
||||
session.Stdin = bytes.NewBuffer(testBytes)
|
||||
if err := session.Run(""); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(stdout.Bytes(), testBytes) {
|
||||
t.Fatalf("stdout = %#v; want %#v given stdin = %#v", stdout.Bytes(), testBytes, testBytes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUser(t *testing.T) {
|
||||
testUser := []byte("progrium")
|
||||
session, cleanup := newTestSession(t, &Server{
|
||||
Handler: func(s Session) {
|
||||
io.WriteString(s, s.User())
|
||||
},
|
||||
}, &gossh.ClientConfig{
|
||||
User: string(testUser),
|
||||
})
|
||||
defer cleanup()
|
||||
var stdout bytes.Buffer
|
||||
session.Stdout = &stdout
|
||||
if err := session.Run(""); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(stdout.Bytes(), testUser) {
|
||||
t.Fatalf("stdout = %#v; want %#v given user = %#v", stdout.Bytes(), testUser, string(testUser))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultExitStatusZero(t *testing.T) {
|
||||
session, cleanup := newTestSession(t, &Server{
|
||||
Handler: func(s Session) {
|
||||
// noop
|
||||
},
|
||||
}, nil)
|
||||
defer cleanup()
|
||||
err := session.Run("")
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil but got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExplicitExitStatusZero(t *testing.T) {
|
||||
session, cleanup := newTestSession(t, &Server{
|
||||
Handler: func(s Session) {
|
||||
s.Exit(0)
|
||||
},
|
||||
}, nil)
|
||||
defer cleanup()
|
||||
err := session.Run("")
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil but got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExitStatusNonZero(t *testing.T) {
|
||||
session, cleanup := newTestSession(t, &Server{
|
||||
Handler: func(s Session) {
|
||||
s.Exit(1)
|
||||
},
|
||||
}, nil)
|
||||
defer cleanup()
|
||||
err := session.Run("")
|
||||
e, ok := err.(*gossh.ExitError)
|
||||
if !ok {
|
||||
t.Fatalf("expected ExitError but got %T", err)
|
||||
}
|
||||
if e.ExitStatus() != 1 {
|
||||
t.Fatalf("exit-status = %#v; want %#v", e.ExitStatus(), 1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPty(t *testing.T) {
|
||||
term := "xterm"
|
||||
winWidth := 40
|
||||
winHeight := 80
|
||||
done := make(chan bool)
|
||||
session, cleanup := newTestSession(t, &Server{
|
||||
Handler: func(s Session) {
|
||||
ptyReq, _, isPty := s.Pty()
|
||||
if !isPty {
|
||||
t.Fatalf("expected pty but none requested")
|
||||
}
|
||||
if ptyReq.Term != term {
|
||||
t.Fatalf("expected term %#v but got %#v", term, ptyReq.Term)
|
||||
}
|
||||
if ptyReq.Window.Width != winWidth {
|
||||
t.Fatalf("expected window width %#v but got %#v", winWidth, ptyReq.Window.Width)
|
||||
}
|
||||
if ptyReq.Window.Height != winHeight {
|
||||
t.Fatalf("expected window height %#v but got %#v", winHeight, ptyReq.Window.Height)
|
||||
}
|
||||
close(done)
|
||||
},
|
||||
}, nil)
|
||||
defer cleanup()
|
||||
if err := session.RequestPty(term, winHeight, winWidth, gossh.TerminalModes{}); err != nil {
|
||||
t.Fatalf("unexpected error requesting PTY", err)
|
||||
}
|
||||
if err := session.Shell(); err != nil {
|
||||
t.Fatalf("expected nil but got %v", err)
|
||||
}
|
||||
<-done
|
||||
}
|
||||
|
||||
func TestPtyResize(t *testing.T) {
|
||||
winch0 := Window{40, 80}
|
||||
winch1 := Window{80, 160}
|
||||
winch2 := Window{20, 40}
|
||||
winches := make(chan Window)
|
||||
done := make(chan bool)
|
||||
session, cleanup := newTestSession(t, &Server{
|
||||
Handler: func(s Session) {
|
||||
ptyReq, winCh, isPty := s.Pty()
|
||||
if !isPty {
|
||||
t.Fatalf("expected pty but none requested")
|
||||
}
|
||||
if ptyReq.Window != winch0 {
|
||||
t.Fatalf("expected window %#v but got %#v", winch0, ptyReq.Window)
|
||||
}
|
||||
for win := range winCh {
|
||||
winches <- win
|
||||
}
|
||||
close(done)
|
||||
},
|
||||
}, nil)
|
||||
defer cleanup()
|
||||
// winch0
|
||||
if err := session.RequestPty("xterm", winch0.Height, winch0.Width, gossh.TerminalModes{}); err != nil {
|
||||
t.Fatalf("unexpected error requesting PTY", err)
|
||||
}
|
||||
if err := session.Shell(); err != nil {
|
||||
t.Fatalf("expected nil but got %v", err)
|
||||
}
|
||||
gotWinch := <-winches
|
||||
if gotWinch != winch0 {
|
||||
t.Fatalf("expected window %#v but got %#v", winch0, gotWinch)
|
||||
}
|
||||
// winch1
|
||||
winchMsg := struct{ w, h uint32 }{uint32(winch1.Width), uint32(winch1.Height)}
|
||||
ok, err := session.SendRequest("window-change", true, gossh.Marshal(&winchMsg))
|
||||
if err == nil && !ok {
|
||||
t.Fatalf("unexpected error or bad reply on send request")
|
||||
}
|
||||
gotWinch = <-winches
|
||||
if gotWinch != winch1 {
|
||||
t.Fatalf("expected window %#v but got %#v", winch1, gotWinch)
|
||||
}
|
||||
// winch2
|
||||
winchMsg = struct{ w, h uint32 }{uint32(winch2.Width), uint32(winch2.Height)}
|
||||
ok, err = session.SendRequest("window-change", true, gossh.Marshal(&winchMsg))
|
||||
if err == nil && !ok {
|
||||
t.Fatalf("unexpected error or bad reply on send request")
|
||||
}
|
||||
gotWinch = <-winches
|
||||
if gotWinch != winch2 {
|
||||
t.Fatalf("expected window %#v but got %#v", winch2, gotWinch)
|
||||
}
|
||||
session.Close()
|
||||
<-done
|
||||
}
|
5
ssh.go
5
ssh.go
@ -5,7 +5,6 @@ import (
|
||||
"net"
|
||||
)
|
||||
|
||||
// Signal as in RFC 4254 Section 6.10.
|
||||
type Signal string
|
||||
|
||||
// POSIX signals as listed in RFC 4254 Section 6.10.
|
||||
@ -52,9 +51,11 @@ type Window struct {
|
||||
Height int
|
||||
}
|
||||
|
||||
// Pty represents PTY configuration.
|
||||
// Pty represents a PTY request and configuration.
|
||||
type Pty struct {
|
||||
Term string
|
||||
Window Window
|
||||
// HELP WANTED: terminal modes!
|
||||
}
|
||||
|
||||
// Serve accepts incoming SSH connections on the listener l, creating a new
|
||||
|
43
util.go
43
util.go
@ -54,44 +54,53 @@ func generateSigner() (ssh.Signer, error) {
|
||||
return ssh.NewSignerFromKey(key)
|
||||
}
|
||||
|
||||
func parsePtyRequest(s []byte) (width, height int, ok bool) {
|
||||
_, s, ok = parseString(s)
|
||||
func parsePtyRequest(s []byte) (pty Pty, ok bool) {
|
||||
term, s, ok := parseString(s)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
width32, s, ok := parseUint32(s)
|
||||
if width32 < 1 {
|
||||
ok = false
|
||||
}
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
height32, _, ok := parseUint32(s)
|
||||
width = int(width32)
|
||||
height = int(height32)
|
||||
if width < 1 {
|
||||
if height32 < 1 {
|
||||
ok = false
|
||||
}
|
||||
if height < 1 {
|
||||
ok = false
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
pty = Pty{
|
||||
Term: term,
|
||||
Window: Window{
|
||||
Width: int(width32),
|
||||
Height: int(height32),
|
||||
},
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func parseWinchRequest(s []byte) (width, height int, ok bool) {
|
||||
width32, _, ok := parseUint32(s)
|
||||
func parseWinchRequest(s []byte) (win Window, ok bool) {
|
||||
width32, s, ok := parseUint32(s)
|
||||
if width32 < 1 {
|
||||
ok = false
|
||||
}
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
height32, _, ok := parseUint32(s)
|
||||
if height32 < 1 {
|
||||
ok = false
|
||||
}
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
width = int(width32)
|
||||
height = int(height32)
|
||||
if width < 1 {
|
||||
ok = false
|
||||
}
|
||||
if height < 1 {
|
||||
ok = false
|
||||
win = Window{
|
||||
Width: int(width32),
|
||||
Height: int(height32),
|
||||
}
|
||||
return
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user