From f79e6921242741343082ce3bf25020483176ac14 Mon Sep 17 00:00:00 2001 From: Kaleb Elwert Date: Wed, 16 Oct 2019 10:07:55 -0700 Subject: [PATCH 1/4] Update AddHostKey to avoid always appending --- server.go | 12 ++++++++++++ server_test.go | 20 ++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/server.go b/server.go index cad0402..a10722a 100644 --- a/server.go +++ b/server.go @@ -315,6 +315,18 @@ func (srv *Server) ListenAndServe() error { func (srv *Server) AddHostKey(key Signer) { // these are later added via AddHostKey on ServerConfig, which performs the // check for one of every algorithm. + + // This check is based on the AddHostKey method from the x/crypto/ssh + // library. This allows us to only keep one active key for each type on a + // server at once. So, if you're dynamically updating keys at runtime, this + // list will not keep growing. + for i, k := range srv.HostSigners { + if k.PublicKey().Type() == key.PublicKey().Type() { + srv.HostSigners[i] = key + return + } + } + srv.HostSigners = append(srv.HostSigners, key) } diff --git a/server_test.go b/server_test.go index 558f171..8028a3a 100644 --- a/server_test.go +++ b/server_test.go @@ -8,6 +8,26 @@ import ( "time" ) +func TestAddHostKey(t *testing.T) { + s := Server{} + signer, err := generateSigner() + if err != nil { + t.Fatal(err) + } + s.AddHostKey(signer) + if len(s.HostSigners) != 1 { + t.Fatal("Key was not properly added") + } + signer, err = generateSigner() + if err != nil { + t.Fatal(err) + } + s.AddHostKey(signer) + if len(s.HostSigners) != 1 { + t.Fatal("Key was not properly replaced") + } +} + func TestServerShutdown(t *testing.T) { l := newLocalListener() testBytes := []byte("Hello world\n") From 38820366bb675bdf5eb4b6a4936fde23a2bf5e07 Mon Sep 17 00:00:00 2001 From: Kaleb Elwert Date: Wed, 16 Oct 2019 10:27:39 -0700 Subject: [PATCH 2/4] Start cleaning up config to fix race conditions --- server.go | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/server.go b/server.go index a10722a..359f967 100644 --- a/server.go +++ b/server.go @@ -58,7 +58,7 @@ type Server struct { RequestHandlers map[string]RequestHandler listenerWg sync.WaitGroup - mu sync.Mutex + mu sync.RWMutex listeners map[net.Listener]struct{} conns map[*gossh.ServerConn]struct{} connWg sync.WaitGroup @@ -66,6 +66,9 @@ type Server struct { } func (srv *Server) ensureHostSigner() error { + srv.mu.Lock() + defer srv.mu.Unlock() + if len(srv.HostSigners) == 0 { signer, err := generateSigner() if err != nil { @@ -79,6 +82,7 @@ func (srv *Server) ensureHostSigner() error { func (srv *Server) ensureHandlers() { srv.mu.Lock() defer srv.mu.Unlock() + if srv.RequestHandlers == nil { srv.RequestHandlers = map[string]RequestHandler{} for k, v := range DefaultRequestHandlers { @@ -94,6 +98,9 @@ func (srv *Server) ensureHandlers() { } func (srv *Server) config(ctx Context) *gossh.ServerConfig { + srv.mu.RLock() + defer srv.mu.RUnlock() + var config *gossh.ServerConfig if srv.ServerConfigCallback == nil { config = &gossh.ServerConfig{} @@ -142,6 +149,9 @@ func (srv *Server) config(ctx Context) *gossh.ServerConfig { // Handle sets the Handler for the server. func (srv *Server) Handle(fn Handler) { + srv.mu.Lock() + defer srv.mu.Unlock() + srv.Handler = fn } @@ -153,6 +163,7 @@ func (srv *Server) Handle(fn Handler) { func (srv *Server) Close() error { srv.mu.Lock() defer srv.mu.Unlock() + srv.closeDoneChanLocked() err := srv.closeListenersLocked() for c := range srv.conns { @@ -313,6 +324,9 @@ func (srv *Server) ListenAndServe() error { // with the same algorithm, it is overwritten. Each server config must have at // least one host key. func (srv *Server) AddHostKey(key Signer) { + srv.mu.Lock() + defer srv.mu.Unlock() + // these are later added via AddHostKey on ServerConfig, which performs the // check for one of every algorithm. @@ -332,12 +346,20 @@ func (srv *Server) AddHostKey(key Signer) { // SetOption runs a functional option against the server. func (srv *Server) SetOption(option Option) error { + // NOTE: there is a potential race here for any option that doesn't call an + // internal method. We can't actually lock here because if something calls + // (as an example) AddHostKey, it will deadlock. + + //srv.mu.Lock() + //defer srv.mu.Unlock() + return option(srv) } func (srv *Server) getDoneChan() <-chan struct{} { srv.mu.Lock() defer srv.mu.Unlock() + return srv.getDoneChanLocked() } @@ -374,6 +396,7 @@ func (srv *Server) closeListenersLocked() error { func (srv *Server) trackListener(ln net.Listener, add bool) { srv.mu.Lock() defer srv.mu.Unlock() + if srv.listeners == nil { srv.listeners = make(map[net.Listener]struct{}) } @@ -394,6 +417,7 @@ func (srv *Server) trackListener(ln net.Listener, add bool) { func (srv *Server) trackConn(c *gossh.ServerConn, add bool) { srv.mu.Lock() defer srv.mu.Unlock() + if srv.conns == nil { srv.conns = make(map[*gossh.ServerConn]struct{}) } From be3a169b0cb7b749fc5eaee887ef2f6ed2d60ee2 Mon Sep 17 00:00:00 2001 From: Kaleb Elwert Date: Wed, 16 Oct 2019 10:42:43 -0700 Subject: [PATCH 3/4] Fix TestSignals to remove a possible race --- session_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/session_test.go b/session_test.go index f086792..9d3cc50 100644 --- a/session_test.go +++ b/session_test.go @@ -291,7 +291,9 @@ func TestSignals(t *testing.T) { session, _, cleanup := newTestSession(t, &Server{ Handler: func(s Session) { - signals := make(chan Signal) + // We need to use a buffered channel here, otherwise it's possible for the + // second call to Signal to get discarded. + signals := make(chan Signal, 2) s.Signals(signals) if sig := <-signals; sig != SIGINT { t.Fatalf("expected signal %v but got %v", SIGINT, sig) From 1db07d8a37deb5b74169d1aee5638aa23648bdb0 Mon Sep 17 00:00:00 2001 From: Kaleb Elwert Date: Wed, 16 Oct 2019 11:22:57 -0700 Subject: [PATCH 4/4] Make TestSignals a bit more bulletproof --- session_test.go | 44 ++++++++++++++++++++++++++++++++++---------- 1 file changed, 34 insertions(+), 10 deletions(-) diff --git a/session_test.go b/session_test.go index 9d3cc50..786a661 100644 --- a/session_test.go +++ b/session_test.go @@ -289,22 +289,40 @@ func TestPtyResize(t *testing.T) { func TestSignals(t *testing.T) { t.Parallel() + // errChan lets us get errors back from the session + errChan := make(chan error, 5) + + // doneChan lets us specify that we should exit. + doneChan := make(chan interface{}) + session, _, cleanup := newTestSession(t, &Server{ Handler: func(s Session) { // We need to use a buffered channel here, otherwise it's possible for the // second call to Signal to get discarded. signals := make(chan Signal, 2) s.Signals(signals) - if sig := <-signals; sig != SIGINT { - t.Fatalf("expected signal %v but got %v", SIGINT, sig) - } - exiter := make(chan bool) - go func() { - if sig := <-signals; sig == SIGKILL { - close(exiter) + + select { + case sig := <-signals: + if sig != SIGINT { + errChan <- fmt.Errorf("expected signal %v but got %v", SIGINT, sig) + return } - }() - <-exiter + case <-doneChan: + errChan <- fmt.Errorf("Unexpected done") + return + } + + select { + case sig := <-signals: + if sig != SIGKILL { + errChan <- fmt.Errorf("expected signal %v but got %v", SIGKILL, sig) + return + } + case <-doneChan: + errChan <- fmt.Errorf("Unexpected done") + return + } }, }, nil) defer cleanup() @@ -314,7 +332,13 @@ func TestSignals(t *testing.T) { session.Signal(gossh.SIGKILL) }() - err := session.Run("") + go func() { + errChan <- session.Run("") + }() + + err := <-errChan + close(doneChan) + if err != nil { t.Fatalf("expected nil but got %v", err) }