Merge pull request #123 from gliderlabs/optimize-add-host-key
Update AddHostKey to avoid always appending
This commit is contained in:
commit
59d6e4540d
38
server.go
38
server.go
@ -58,7 +58,7 @@ type Server struct {
|
||||
RequestHandlers map[string]RequestHandler
|
||||
|
||||
listenerWg sync.WaitGroup
|
||||
mu sync.Mutex
|
||||
mu sync.RWMutex
|
||||
listeners map[net.Listener]struct{}
|
||||
conns map[*gossh.ServerConn]struct{}
|
||||
connWg sync.WaitGroup
|
||||
@ -66,6 +66,9 @@ type Server struct {
|
||||
}
|
||||
|
||||
func (srv *Server) ensureHostSigner() error {
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
|
||||
if len(srv.HostSigners) == 0 {
|
||||
signer, err := generateSigner()
|
||||
if err != nil {
|
||||
@ -79,6 +82,7 @@ func (srv *Server) ensureHostSigner() error {
|
||||
func (srv *Server) ensureHandlers() {
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
|
||||
if srv.RequestHandlers == nil {
|
||||
srv.RequestHandlers = map[string]RequestHandler{}
|
||||
for k, v := range DefaultRequestHandlers {
|
||||
@ -94,6 +98,9 @@ func (srv *Server) ensureHandlers() {
|
||||
}
|
||||
|
||||
func (srv *Server) config(ctx Context) *gossh.ServerConfig {
|
||||
srv.mu.RLock()
|
||||
defer srv.mu.RUnlock()
|
||||
|
||||
var config *gossh.ServerConfig
|
||||
if srv.ServerConfigCallback == nil {
|
||||
config = &gossh.ServerConfig{}
|
||||
@ -142,6 +149,9 @@ func (srv *Server) config(ctx Context) *gossh.ServerConfig {
|
||||
|
||||
// Handle sets the Handler for the server.
|
||||
func (srv *Server) Handle(fn Handler) {
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
|
||||
srv.Handler = fn
|
||||
}
|
||||
|
||||
@ -153,6 +163,7 @@ func (srv *Server) Handle(fn Handler) {
|
||||
func (srv *Server) Close() error {
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
|
||||
srv.closeDoneChanLocked()
|
||||
err := srv.closeListenersLocked()
|
||||
for c := range srv.conns {
|
||||
@ -313,19 +324,42 @@ func (srv *Server) ListenAndServe() error {
|
||||
// with the same algorithm, it is overwritten. Each server config must have at
|
||||
// least one host key.
|
||||
func (srv *Server) AddHostKey(key Signer) {
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
|
||||
// these are later added via AddHostKey on ServerConfig, which performs the
|
||||
// check for one of every algorithm.
|
||||
|
||||
// This check is based on the AddHostKey method from the x/crypto/ssh
|
||||
// library. This allows us to only keep one active key for each type on a
|
||||
// server at once. So, if you're dynamically updating keys at runtime, this
|
||||
// list will not keep growing.
|
||||
for i, k := range srv.HostSigners {
|
||||
if k.PublicKey().Type() == key.PublicKey().Type() {
|
||||
srv.HostSigners[i] = key
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
srv.HostSigners = append(srv.HostSigners, key)
|
||||
}
|
||||
|
||||
// SetOption runs a functional option against the server.
|
||||
func (srv *Server) SetOption(option Option) error {
|
||||
// NOTE: there is a potential race here for any option that doesn't call an
|
||||
// internal method. We can't actually lock here because if something calls
|
||||
// (as an example) AddHostKey, it will deadlock.
|
||||
|
||||
//srv.mu.Lock()
|
||||
//defer srv.mu.Unlock()
|
||||
|
||||
return option(srv)
|
||||
}
|
||||
|
||||
func (srv *Server) getDoneChan() <-chan struct{} {
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
|
||||
return srv.getDoneChanLocked()
|
||||
}
|
||||
|
||||
@ -362,6 +396,7 @@ func (srv *Server) closeListenersLocked() error {
|
||||
func (srv *Server) trackListener(ln net.Listener, add bool) {
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
|
||||
if srv.listeners == nil {
|
||||
srv.listeners = make(map[net.Listener]struct{})
|
||||
}
|
||||
@ -382,6 +417,7 @@ func (srv *Server) trackListener(ln net.Listener, add bool) {
|
||||
func (srv *Server) trackConn(c *gossh.ServerConn, add bool) {
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
|
||||
if srv.conns == nil {
|
||||
srv.conns = make(map[*gossh.ServerConn]struct{})
|
||||
}
|
||||
|
@ -8,6 +8,26 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestAddHostKey(t *testing.T) {
|
||||
s := Server{}
|
||||
signer, err := generateSigner()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s.AddHostKey(signer)
|
||||
if len(s.HostSigners) != 1 {
|
||||
t.Fatal("Key was not properly added")
|
||||
}
|
||||
signer, err = generateSigner()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s.AddHostKey(signer)
|
||||
if len(s.HostSigners) != 1 {
|
||||
t.Fatal("Key was not properly replaced")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerShutdown(t *testing.T) {
|
||||
l := newLocalListener()
|
||||
testBytes := []byte("Hello world\n")
|
||||
|
@ -289,20 +289,40 @@ func TestPtyResize(t *testing.T) {
|
||||
func TestSignals(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// errChan lets us get errors back from the session
|
||||
errChan := make(chan error, 5)
|
||||
|
||||
// doneChan lets us specify that we should exit.
|
||||
doneChan := make(chan interface{})
|
||||
|
||||
session, _, cleanup := newTestSession(t, &Server{
|
||||
Handler: func(s Session) {
|
||||
signals := make(chan Signal)
|
||||
// We need to use a buffered channel here, otherwise it's possible for the
|
||||
// second call to Signal to get discarded.
|
||||
signals := make(chan Signal, 2)
|
||||
s.Signals(signals)
|
||||
if sig := <-signals; sig != SIGINT {
|
||||
t.Fatalf("expected signal %v but got %v", SIGINT, sig)
|
||||
}
|
||||
exiter := make(chan bool)
|
||||
go func() {
|
||||
if sig := <-signals; sig == SIGKILL {
|
||||
close(exiter)
|
||||
|
||||
select {
|
||||
case sig := <-signals:
|
||||
if sig != SIGINT {
|
||||
errChan <- fmt.Errorf("expected signal %v but got %v", SIGINT, sig)
|
||||
return
|
||||
}
|
||||
}()
|
||||
<-exiter
|
||||
case <-doneChan:
|
||||
errChan <- fmt.Errorf("Unexpected done")
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case sig := <-signals:
|
||||
if sig != SIGKILL {
|
||||
errChan <- fmt.Errorf("expected signal %v but got %v", SIGKILL, sig)
|
||||
return
|
||||
}
|
||||
case <-doneChan:
|
||||
errChan <- fmt.Errorf("Unexpected done")
|
||||
return
|
||||
}
|
||||
},
|
||||
}, nil)
|
||||
defer cleanup()
|
||||
@ -312,7 +332,13 @@ func TestSignals(t *testing.T) {
|
||||
session.Signal(gossh.SIGKILL)
|
||||
}()
|
||||
|
||||
err := session.Run("")
|
||||
go func() {
|
||||
errChan <- session.Run("")
|
||||
}()
|
||||
|
||||
err := <-errChan
|
||||
close(doneChan)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil but got %v", err)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user