session: adding signal handling support (#44)
This commit is contained in:
parent
4a4de396c4
commit
3eeacb7850
45
session.go
45
session.go
@ -6,6 +6,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/anmitsu/go-shlex"
|
"github.com/anmitsu/go-shlex"
|
||||||
gossh "golang.org/x/crypto/ssh"
|
gossh "golang.org/x/crypto/ssh"
|
||||||
@ -63,9 +64,19 @@ type Session interface {
|
|||||||
// of whether or not a PTY was accepted for this session.
|
// of whether or not a PTY was accepted for this session.
|
||||||
Pty() (Pty, <-chan Window, bool)
|
Pty() (Pty, <-chan Window, bool)
|
||||||
|
|
||||||
// TODO: Signals(c chan<- Signal)
|
// Signals registers a channel to receive signals sent from the client. The
|
||||||
|
// channel must handle signal sends or it will block the SSH request loop.
|
||||||
|
// Registering nil will unregister the channel from signal sends. During the
|
||||||
|
// time no channel is registered signals are buffered up to a reasonable amount.
|
||||||
|
// If there are buffered signals when a channel is registered, they will be
|
||||||
|
// sent in order on the channel immediately after registering.
|
||||||
|
Signals(c chan<- Signal)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// maxSigBufSize is how many signals will be buffered
|
||||||
|
// when there is no signal channel specified
|
||||||
|
const maxSigBufSize = 128
|
||||||
|
|
||||||
func sessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx *sshContext) {
|
func sessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx *sshContext) {
|
||||||
ch, reqs, err := newChan.Accept()
|
ch, reqs, err := newChan.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -83,6 +94,7 @@ func sessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChanne
|
|||||||
}
|
}
|
||||||
|
|
||||||
type session struct {
|
type session struct {
|
||||||
|
sync.Mutex
|
||||||
gossh.Channel
|
gossh.Channel
|
||||||
conn *gossh.ServerConn
|
conn *gossh.ServerConn
|
||||||
handler Handler
|
handler Handler
|
||||||
@ -94,6 +106,8 @@ type session struct {
|
|||||||
ptyCb PtyCallback
|
ptyCb PtyCallback
|
||||||
cmd []string
|
cmd []string
|
||||||
ctx *sshContext
|
ctx *sshContext
|
||||||
|
sigCh chan<- Signal
|
||||||
|
sigBuf []Signal
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sess *session) Write(p []byte) (n int, err error) {
|
func (sess *session) Write(p []byte) (n int, err error) {
|
||||||
@ -132,6 +146,8 @@ func (sess *session) Context() context.Context {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (sess *session) Exit(code int) error {
|
func (sess *session) Exit(code int) error {
|
||||||
|
sess.Lock()
|
||||||
|
defer sess.Unlock()
|
||||||
if sess.exited {
|
if sess.exited {
|
||||||
return errors.New("Session.Exit called multiple times")
|
return errors.New("Session.Exit called multiple times")
|
||||||
}
|
}
|
||||||
@ -172,6 +188,19 @@ func (sess *session) Pty() (Pty, <-chan Window, bool) {
|
|||||||
return Pty{}, sess.winch, false
|
return Pty{}, sess.winch, false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (sess *session) Signals(c chan<- Signal) {
|
||||||
|
sess.Lock()
|
||||||
|
defer sess.Unlock()
|
||||||
|
sess.sigCh = c
|
||||||
|
if len(sess.sigBuf) > 0 {
|
||||||
|
go func() {
|
||||||
|
for _, sig := range sess.sigBuf {
|
||||||
|
sess.sigCh <- sig
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (sess *session) handleRequests(reqs <-chan *gossh.Request) {
|
func (sess *session) handleRequests(reqs <-chan *gossh.Request) {
|
||||||
for req := range reqs {
|
for req := range reqs {
|
||||||
switch req.Type {
|
switch req.Type {
|
||||||
@ -195,10 +224,22 @@ func (sess *session) handleRequests(reqs <-chan *gossh.Request) {
|
|||||||
req.Reply(false, nil)
|
req.Reply(false, nil)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
var kv = struct{ Key, Value string }{}
|
var kv struct{ Key, Value string }
|
||||||
gossh.Unmarshal(req.Payload, &kv)
|
gossh.Unmarshal(req.Payload, &kv)
|
||||||
sess.env = append(sess.env, fmt.Sprintf("%s=%s", kv.Key, kv.Value))
|
sess.env = append(sess.env, fmt.Sprintf("%s=%s", kv.Key, kv.Value))
|
||||||
req.Reply(true, nil)
|
req.Reply(true, nil)
|
||||||
|
case "signal":
|
||||||
|
var payload struct{ Signal string }
|
||||||
|
gossh.Unmarshal(req.Payload, &payload)
|
||||||
|
sess.Lock()
|
||||||
|
if sess.sigCh != nil {
|
||||||
|
sess.sigCh <- Signal(payload.Signal)
|
||||||
|
} else {
|
||||||
|
if len(sess.sigBuf) < maxSigBufSize {
|
||||||
|
sess.sigBuf = append(sess.sigBuf, Signal(payload.Signal))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sess.Unlock()
|
||||||
case "pty-req":
|
case "pty-req":
|
||||||
if sess.handled || sess.pty != nil {
|
if sess.handled || sess.pty != nil {
|
||||||
req.Reply(false, nil)
|
req.Reply(false, nil)
|
||||||
|
@ -280,3 +280,35 @@ func TestPtyResize(t *testing.T) {
|
|||||||
session.Close()
|
session.Close()
|
||||||
<-done
|
<-done
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSignals(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
session, _, cleanup := newTestSession(t, &Server{
|
||||||
|
Handler: func(s Session) {
|
||||||
|
signals := make(chan Signal)
|
||||||
|
s.Signals(signals)
|
||||||
|
if sig := <-signals; sig != SIGINT {
|
||||||
|
t.Fatalf("expected signal %v but got %v", SIGINT, sig)
|
||||||
|
}
|
||||||
|
exiter := make(chan bool)
|
||||||
|
go func() {
|
||||||
|
if sig := <-signals; sig == SIGKILL {
|
||||||
|
close(exiter)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
<-exiter
|
||||||
|
},
|
||||||
|
}, nil)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
session.Signal(gossh.SIGINT)
|
||||||
|
session.Signal(gossh.SIGKILL)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := session.Run("")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected nil but got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user