Register chan to Session to listen for break requests (#141)
Co-authored-by: Jacob Meisler <meislerj@amazon.com>
This commit is contained in:
parent
76cadaa318
commit
fb34512070
22
session.go
22
session.go
@ -77,6 +77,12 @@ type Session interface {
|
||||
// If there are buffered signals when a channel is registered, they will be
|
||||
// sent in order on the channel immediately after registering.
|
||||
Signals(c chan<- Signal)
|
||||
|
||||
// Break regisers a channel to receive notifications of break requests sent
|
||||
// from the client. The channel must handle break requests, or it will block
|
||||
// the request handling loop. Registering nil will unregister the channel.
|
||||
// During the time that no channel is registered, breaks are ignored.
|
||||
Break(c chan<- bool)
|
||||
}
|
||||
|
||||
// maxSigBufSize is how many signals will be buffered
|
||||
@ -119,6 +125,7 @@ type session struct {
|
||||
ctx Context
|
||||
sigCh chan<- Signal
|
||||
sigBuf []Signal
|
||||
breakCh chan<- bool
|
||||
}
|
||||
|
||||
func (sess *session) Write(p []byte) (n int, err error) {
|
||||
@ -221,6 +228,12 @@ func (sess *session) Signals(c chan<- Signal) {
|
||||
}
|
||||
}
|
||||
|
||||
func (sess *session) Break(c chan<- bool) {
|
||||
sess.Lock()
|
||||
defer sess.Unlock()
|
||||
sess.breakCh = c
|
||||
}
|
||||
|
||||
func (sess *session) handleRequests(reqs <-chan *gossh.Request) {
|
||||
for req := range reqs {
|
||||
switch req.Type {
|
||||
@ -344,6 +357,15 @@ func (sess *session) handleRequests(reqs <-chan *gossh.Request) {
|
||||
// TODO: option/callback to allow agent forwarding
|
||||
SetAgentRequested(sess.ctx)
|
||||
req.Reply(true, nil)
|
||||
case "break":
|
||||
ok := false
|
||||
sess.Lock()
|
||||
if sess.breakCh != nil {
|
||||
sess.breakCh <- true
|
||||
ok = true
|
||||
}
|
||||
req.Reply(ok, nil)
|
||||
sess.Unlock()
|
||||
default:
|
||||
// TODO: debug log
|
||||
req.Reply(false, nil)
|
||||
|
@ -343,3 +343,96 @@ func TestSignals(t *testing.T) {
|
||||
t.Fatalf("expected nil but got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBreakWithChanRegistered(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// errChan lets us get errors back from the session
|
||||
errChan := make(chan error, 5)
|
||||
|
||||
// doneChan lets us specify that we should exit.
|
||||
doneChan := make(chan interface{})
|
||||
|
||||
breakChan := make(chan bool)
|
||||
|
||||
readyToReceiveBreak := make(chan bool)
|
||||
|
||||
session, _, cleanup := newTestSession(t, &Server{
|
||||
Handler: func(s Session) {
|
||||
s.Break(breakChan) // register a break channel with the session
|
||||
readyToReceiveBreak <- true
|
||||
|
||||
select {
|
||||
case <-breakChan:
|
||||
io.WriteString(s, "break")
|
||||
case <-doneChan:
|
||||
errChan <- fmt.Errorf("Unexpected done")
|
||||
return
|
||||
}
|
||||
},
|
||||
}, nil)
|
||||
defer cleanup()
|
||||
var stdout bytes.Buffer
|
||||
session.Stdout = &stdout
|
||||
go func() {
|
||||
errChan <- session.Run("")
|
||||
}()
|
||||
|
||||
<-readyToReceiveBreak
|
||||
ok, err := session.SendRequest("break", true, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil but got %v", err)
|
||||
}
|
||||
if ok != true {
|
||||
t.Fatalf("expected true but got %v", ok)
|
||||
}
|
||||
|
||||
err = <-errChan
|
||||
close(doneChan)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil but got %v", err)
|
||||
}
|
||||
if !bytes.Equal(stdout.Bytes(), []byte("break")) {
|
||||
t.Fatalf("stdout = %#v, expected 'break'", stdout.Bytes())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBreakWithoutChanRegistered(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// errChan lets us get errors back from the session
|
||||
errChan := make(chan error, 5)
|
||||
|
||||
// doneChan lets us specify that we should exit.
|
||||
doneChan := make(chan interface{})
|
||||
|
||||
waitUntilAfterBreakSent := make(chan bool)
|
||||
|
||||
session, _, cleanup := newTestSession(t, &Server{
|
||||
Handler: func(s Session) {
|
||||
<-waitUntilAfterBreakSent
|
||||
},
|
||||
}, nil)
|
||||
defer cleanup()
|
||||
var stdout bytes.Buffer
|
||||
session.Stdout = &stdout
|
||||
go func() {
|
||||
errChan <- session.Run("")
|
||||
}()
|
||||
|
||||
ok, err := session.SendRequest("break", true, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil but got %v", err)
|
||||
}
|
||||
if ok != false {
|
||||
t.Fatalf("expected false but got %v", ok)
|
||||
}
|
||||
waitUntilAfterBreakSent <- true
|
||||
|
||||
err = <-errChan
|
||||
close(doneChan)
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil but got %v", err)
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user