diff --git a/session_test.go b/session_test.go index 9d3cc50..786a661 100644 --- a/session_test.go +++ b/session_test.go @@ -289,22 +289,40 @@ func TestPtyResize(t *testing.T) { func TestSignals(t *testing.T) { t.Parallel() + // errChan lets us get errors back from the session + errChan := make(chan error, 5) + + // doneChan lets us specify that we should exit. + doneChan := make(chan interface{}) + session, _, cleanup := newTestSession(t, &Server{ Handler: func(s Session) { // We need to use a buffered channel here, otherwise it's possible for the // second call to Signal to get discarded. signals := make(chan Signal, 2) s.Signals(signals) - if sig := <-signals; sig != SIGINT { - t.Fatalf("expected signal %v but got %v", SIGINT, sig) - } - exiter := make(chan bool) - go func() { - if sig := <-signals; sig == SIGKILL { - close(exiter) + + select { + case sig := <-signals: + if sig != SIGINT { + errChan <- fmt.Errorf("expected signal %v but got %v", SIGINT, sig) + return } - }() - <-exiter + case <-doneChan: + errChan <- fmt.Errorf("Unexpected done") + return + } + + select { + case sig := <-signals: + if sig != SIGKILL { + errChan <- fmt.Errorf("expected signal %v but got %v", SIGKILL, sig) + return + } + case <-doneChan: + errChan <- fmt.Errorf("Unexpected done") + return + } }, }, nil) defer cleanup() @@ -314,7 +332,13 @@ func TestSignals(t *testing.T) { session.Signal(gossh.SIGKILL) }() - err := session.Run("") + go func() { + errChan <- session.Run("") + }() + + err := <-errChan + close(doneChan) + if err != nil { t.Fatalf("expected nil but got %v", err) }