session: adding signal handling support (#44)
This commit is contained in:
parent
4a4de396c4
commit
3eeacb7850
45
session.go
45
session.go
@ -6,6 +6,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/anmitsu/go-shlex"
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
@ -63,9 +64,19 @@ type Session interface {
|
||||
// of whether or not a PTY was accepted for this session.
|
||||
Pty() (Pty, <-chan Window, bool)
|
||||
|
||||
// TODO: Signals(c chan<- Signal)
|
||||
// Signals registers a channel to receive signals sent from the client. The
|
||||
// channel must handle signal sends or it will block the SSH request loop.
|
||||
// Registering nil will unregister the channel from signal sends. During the
|
||||
// time no channel is registered signals are buffered up to a reasonable amount.
|
||||
// If there are buffered signals when a channel is registered, they will be
|
||||
// sent in order on the channel immediately after registering.
|
||||
Signals(c chan<- Signal)
|
||||
}
|
||||
|
||||
// maxSigBufSize is how many signals will be buffered
|
||||
// when there is no signal channel specified
|
||||
const maxSigBufSize = 128
|
||||
|
||||
func sessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx *sshContext) {
|
||||
ch, reqs, err := newChan.Accept()
|
||||
if err != nil {
|
||||
@ -83,6 +94,7 @@ func sessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChanne
|
||||
}
|
||||
|
||||
type session struct {
|
||||
sync.Mutex
|
||||
gossh.Channel
|
||||
conn *gossh.ServerConn
|
||||
handler Handler
|
||||
@ -94,6 +106,8 @@ type session struct {
|
||||
ptyCb PtyCallback
|
||||
cmd []string
|
||||
ctx *sshContext
|
||||
sigCh chan<- Signal
|
||||
sigBuf []Signal
|
||||
}
|
||||
|
||||
func (sess *session) Write(p []byte) (n int, err error) {
|
||||
@ -132,6 +146,8 @@ func (sess *session) Context() context.Context {
|
||||
}
|
||||
|
||||
func (sess *session) Exit(code int) error {
|
||||
sess.Lock()
|
||||
defer sess.Unlock()
|
||||
if sess.exited {
|
||||
return errors.New("Session.Exit called multiple times")
|
||||
}
|
||||
@ -172,6 +188,19 @@ func (sess *session) Pty() (Pty, <-chan Window, bool) {
|
||||
return Pty{}, sess.winch, false
|
||||
}
|
||||
|
||||
func (sess *session) Signals(c chan<- Signal) {
|
||||
sess.Lock()
|
||||
defer sess.Unlock()
|
||||
sess.sigCh = c
|
||||
if len(sess.sigBuf) > 0 {
|
||||
go func() {
|
||||
for _, sig := range sess.sigBuf {
|
||||
sess.sigCh <- sig
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (sess *session) handleRequests(reqs <-chan *gossh.Request) {
|
||||
for req := range reqs {
|
||||
switch req.Type {
|
||||
@ -195,10 +224,22 @@ func (sess *session) handleRequests(reqs <-chan *gossh.Request) {
|
||||
req.Reply(false, nil)
|
||||
continue
|
||||
}
|
||||
var kv = struct{ Key, Value string }{}
|
||||
var kv struct{ Key, Value string }
|
||||
gossh.Unmarshal(req.Payload, &kv)
|
||||
sess.env = append(sess.env, fmt.Sprintf("%s=%s", kv.Key, kv.Value))
|
||||
req.Reply(true, nil)
|
||||
case "signal":
|
||||
var payload struct{ Signal string }
|
||||
gossh.Unmarshal(req.Payload, &payload)
|
||||
sess.Lock()
|
||||
if sess.sigCh != nil {
|
||||
sess.sigCh <- Signal(payload.Signal)
|
||||
} else {
|
||||
if len(sess.sigBuf) < maxSigBufSize {
|
||||
sess.sigBuf = append(sess.sigBuf, Signal(payload.Signal))
|
||||
}
|
||||
}
|
||||
sess.Unlock()
|
||||
case "pty-req":
|
||||
if sess.handled || sess.pty != nil {
|
||||
req.Reply(false, nil)
|
||||
|
@ -280,3 +280,35 @@ func TestPtyResize(t *testing.T) {
|
||||
session.Close()
|
||||
<-done
|
||||
}
|
||||
|
||||
func TestSignals(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
session, _, cleanup := newTestSession(t, &Server{
|
||||
Handler: func(s Session) {
|
||||
signals := make(chan Signal)
|
||||
s.Signals(signals)
|
||||
if sig := <-signals; sig != SIGINT {
|
||||
t.Fatalf("expected signal %v but got %v", SIGINT, sig)
|
||||
}
|
||||
exiter := make(chan bool)
|
||||
go func() {
|
||||
if sig := <-signals; sig == SIGKILL {
|
||||
close(exiter)
|
||||
}
|
||||
}()
|
||||
<-exiter
|
||||
},
|
||||
}, nil)
|
||||
defer cleanup()
|
||||
|
||||
go func() {
|
||||
session.Signal(gossh.SIGINT)
|
||||
session.Signal(gossh.SIGKILL)
|
||||
}()
|
||||
|
||||
err := session.Run("")
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil but got %v", err)
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user