session: adding signal handling support (#44)

This commit is contained in:
Jeff Lindsay 2017-11-01 18:03:54 -05:00 committed by GitHub
parent 4a4de396c4
commit 3eeacb7850
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 75 additions and 2 deletions

@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"net"
"sync"
"github.com/anmitsu/go-shlex"
gossh "golang.org/x/crypto/ssh"
@ -63,9 +64,19 @@ type Session interface {
// of whether or not a PTY was accepted for this session.
Pty() (Pty, <-chan Window, bool)
// TODO: Signals(c chan<- Signal)
// Signals registers a channel to receive signals sent from the client. The
// channel must handle signal sends or it will block the SSH request loop.
// Registering nil will unregister the channel from signal sends. During the
// time no channel is registered signals are buffered up to a reasonable amount.
// If there are buffered signals when a channel is registered, they will be
// sent in order on the channel immediately after registering.
Signals(c chan<- Signal)
}
// maxSigBufSize is how many signals will be buffered
// when there is no signal channel specified
const maxSigBufSize = 128
func sessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx *sshContext) {
ch, reqs, err := newChan.Accept()
if err != nil {
@ -83,6 +94,7 @@ func sessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChanne
}
type session struct {
sync.Mutex
gossh.Channel
conn *gossh.ServerConn
handler Handler
@ -94,6 +106,8 @@ type session struct {
ptyCb PtyCallback
cmd []string
ctx *sshContext
sigCh chan<- Signal
sigBuf []Signal
}
func (sess *session) Write(p []byte) (n int, err error) {
@ -132,6 +146,8 @@ func (sess *session) Context() context.Context {
}
func (sess *session) Exit(code int) error {
sess.Lock()
defer sess.Unlock()
if sess.exited {
return errors.New("Session.Exit called multiple times")
}
@ -172,6 +188,19 @@ func (sess *session) Pty() (Pty, <-chan Window, bool) {
return Pty{}, sess.winch, false
}
func (sess *session) Signals(c chan<- Signal) {
sess.Lock()
defer sess.Unlock()
sess.sigCh = c
if len(sess.sigBuf) > 0 {
go func() {
for _, sig := range sess.sigBuf {
sess.sigCh <- sig
}
}()
}
}
func (sess *session) handleRequests(reqs <-chan *gossh.Request) {
for req := range reqs {
switch req.Type {
@ -195,10 +224,22 @@ func (sess *session) handleRequests(reqs <-chan *gossh.Request) {
req.Reply(false, nil)
continue
}
var kv = struct{ Key, Value string }{}
var kv struct{ Key, Value string }
gossh.Unmarshal(req.Payload, &kv)
sess.env = append(sess.env, fmt.Sprintf("%s=%s", kv.Key, kv.Value))
req.Reply(true, nil)
case "signal":
var payload struct{ Signal string }
gossh.Unmarshal(req.Payload, &payload)
sess.Lock()
if sess.sigCh != nil {
sess.sigCh <- Signal(payload.Signal)
} else {
if len(sess.sigBuf) < maxSigBufSize {
sess.sigBuf = append(sess.sigBuf, Signal(payload.Signal))
}
}
sess.Unlock()
case "pty-req":
if sess.handled || sess.pty != nil {
req.Reply(false, nil)

@ -280,3 +280,35 @@ func TestPtyResize(t *testing.T) {
session.Close()
<-done
}
func TestSignals(t *testing.T) {
t.Parallel()
session, _, cleanup := newTestSession(t, &Server{
Handler: func(s Session) {
signals := make(chan Signal)
s.Signals(signals)
if sig := <-signals; sig != SIGINT {
t.Fatalf("expected signal %v but got %v", SIGINT, sig)
}
exiter := make(chan bool)
go func() {
if sig := <-signals; sig == SIGKILL {
close(exiter)
}
}()
<-exiter
},
}, nil)
defer cleanup()
go func() {
session.Signal(gossh.SIGINT)
session.Signal(gossh.SIGKILL)
}()
err := session.Run("")
if err != nil {
t.Fatalf("expected nil but got %v", err)
}
}