diff --git a/server.go b/server.go index 1c5f290..a18e695 100644 --- a/server.go +++ b/server.go @@ -1,13 +1,20 @@ package ssh import ( + "context" + "errors" "fmt" "net" + "sync" "time" gossh "golang.org/x/crypto/ssh" ) +// ErrServerClosed is returned by the Server's Serve, ListenAndServe, +// and ListenAndServeTLS methods after a call to Shutdown or Close. +var ErrServerClosed = errors.New("ssh: Server closed") + // Server defines parameters for running an SSH server. The zero value for // Server is a valid configuration. When both PasswordHandler and // PublicKeyHandler are nil, no client authentication is performed. @@ -23,6 +30,11 @@ type Server struct { LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil channelHandlers map[string]channelHandler + + mu sync.Mutex + listeners map[net.Listener]struct{} + conns map[*gossh.ServerConn]struct{} + doneChan chan struct{} } // internal for now @@ -81,6 +93,60 @@ func (srv *Server) Handle(fn Handler) { srv.Handler = fn } +// Close immediately closes all active listeners and all active +// connections. +// +// Close returns any error returned from closing the Server's +// underlying Listener(s). +func (srv *Server) Close() error { + srv.mu.Lock() + defer srv.mu.Unlock() + srv.closeDoneChanLocked() + err := srv.closeListenersLocked() + for c := range srv.conns { + c.Close() + delete(srv.conns, c) + } + return err +} + +// shutdownPollInterval is how often we poll for quiescence +// during Server.Shutdown. This is lower during tests, to +// speed up tests. +// Ideally we could find a solution that doesn't involve polling, +// but which also doesn't have a high runtime cost (and doesn't +// involve any contentious mutexes), but that is left as an +// exercise for the reader. +var shutdownPollInterval = 500 * time.Millisecond + +// Shutdown gracefully shuts down the server without interrupting any +// active connections. Shutdown works by first closing all open +// listeners, and then waiting indefinitely for connections to close. +// If the provided context expires before the shutdown is complete, +// then the context's error is returned. +func (srv *Server) Shutdown(ctx context.Context) error { + srv.mu.Lock() + lnerr := srv.closeListenersLocked() + srv.closeDoneChanLocked() + srv.mu.Unlock() + ticker := time.NewTicker(shutdownPollInterval) + defer ticker.Stop() + for { + srv.mu.Lock() + conns := len(srv.conns) + srv.mu.Unlock() + if conns == 0 { + return lnerr + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + } + } + +} + // Serve accepts incoming connections on the Listener l, creating a new // connection goroutine for each. The connection goroutines read requests and then // calls srv.Handler to handle sessions. @@ -95,9 +161,17 @@ func (srv *Server) Serve(l net.Listener) error { srv.Handler = DefaultHandler } var tempDelay time.Duration + + srv.trackListener(l, true) + defer srv.trackListener(l, false) for { conn, e := l.Accept() if e != nil { + select { + case <-srv.getDoneChan(): + return ErrServerClosed + default: + } if ne, ok := e.(net.Error); ok && ne.Temporary() { if tempDelay == 0 { tempDelay = 5 * time.Millisecond @@ -124,6 +198,10 @@ func (srv *Server) handleConn(conn net.Conn) { // TODO: trigger event callback return } + + srv.trackConn(sshConn, true) + defer srv.trackConn(sshConn, false) + ctx.SetValue(ContextKeyConn, sshConn) ctx.applyConnMetadata(sshConn) go gossh.DiscardRequests(reqs) @@ -165,3 +243,70 @@ func (srv *Server) AddHostKey(key Signer) { func (srv *Server) SetOption(option Option) error { return option(srv) } + +func (srv *Server) getDoneChan() <-chan struct{} { + srv.mu.Lock() + defer srv.mu.Unlock() + return srv.getDoneChanLocked() +} + +func (srv *Server) getDoneChanLocked() chan struct{} { + if srv.doneChan == nil { + srv.doneChan = make(chan struct{}) + } + return srv.doneChan +} + +func (srv *Server) closeDoneChanLocked() { + ch := srv.getDoneChanLocked() + select { + case <-ch: + // Already closed. Don't close again. + default: + // Safe to close here. We're the only closer, guarded + // by srv.mu. + close(ch) + } +} + +func (srv *Server) closeListenersLocked() error { + var err error + for ln := range srv.listeners { + if cerr := ln.Close(); cerr != nil && err == nil { + err = cerr + } + delete(srv.listeners, ln) + } + return err +} + +func (srv *Server) trackListener(ln net.Listener, add bool) { + srv.mu.Lock() + defer srv.mu.Unlock() + if srv.listeners == nil { + srv.listeners = make(map[net.Listener]struct{}) + } + if add { + // If the *Server is being reused after a previous + // Close or Shutdown, reset its doneChan: + if len(srv.listeners) == 0 && len(srv.conns) == 0 { + srv.doneChan = nil + } + srv.listeners[ln] = struct{}{} + } else { + delete(srv.listeners, ln) + } +} + +func (srv *Server) trackConn(c *gossh.ServerConn, add bool) { + srv.mu.Lock() + defer srv.mu.Unlock() + if srv.conns == nil { + srv.conns = make(map[*gossh.ServerConn]struct{}) + } + if add { + srv.conns[c] = struct{}{} + } else { + delete(srv.conns, c) + } +} diff --git a/server_test.go b/server_test.go new file mode 100644 index 0000000..838f3f8 --- /dev/null +++ b/server_test.go @@ -0,0 +1,102 @@ +package ssh + +import ( + "bytes" + "context" + "io" + "testing" + "time" +) + +func TestServerShutdown(t *testing.T) { + l := newLocalListener() + testBytes := []byte("Hello world\n") + s := &Server{ + Handler: func(s Session) { + s.Write(testBytes) + time.Sleep(50 * time.Millisecond) + }, + } + go func() { + err := s.Serve(l) + if err != nil && err != ErrServerClosed { + t.Fatal(err) + } + }() + sessDone := make(chan struct{}) + sess, cleanup := newClientSession(t, l.Addr().String(), nil) + go func() { + defer cleanup() + defer close(sessDone) + var stdout bytes.Buffer + sess.Stdout = &stdout + if err := sess.Run(""); err != nil { + t.Fatal(err) + } + if !bytes.Equal(stdout.Bytes(), testBytes) { + t.Fatalf("expected = %s; got %s", testBytes, stdout.Bytes()) + } + }() + + srvDone := make(chan struct{}) + go func() { + defer close(srvDone) + err := s.Shutdown(context.Background()) + if err != nil { + t.Fatal(err) + } + }() + + timeout := time.After(2 * time.Second) + select { + case <-timeout: + t.Fatal("timeout") + return + case <-srvDone: + // TODO: add timeout for sessDone + <-sessDone + return + } +} + +func TestServerClose(t *testing.T) { + l := newLocalListener() + s := &Server{ + Handler: func(s Session) { + time.Sleep(5 * time.Second) + }, + } + go func() { + err := s.Serve(l) + if err != nil && err != ErrServerClosed { + t.Fatal(err) + } + }() + + doneCh := make(chan struct{}) + sess, cleanup := newClientSession(t, l.Addr().String(), nil) + go func() { + defer cleanup() + defer close(doneCh) + if err := sess.Run(""); err != nil && err != io.EOF { + t.Fatal(err) + } + }() + + go func() { + err := s.Close() + if err != nil { + t.Fatal(err) + } + }() + + timeout := time.After(100 * time.Millisecond) + select { + case <-timeout: + t.Error("timeout") + return + case <-s.getDoneChan(): + <-doneCh + return + } +}