diff --git a/_examples/ssh-timeouts/timeouts.go b/_examples/ssh-timeouts/timeouts.go index 224b1e9..1dba09f 100644 --- a/_examples/ssh-timeouts/timeouts.go +++ b/_examples/ssh-timeouts/timeouts.go @@ -1,58 +1,41 @@ package main import ( - "fmt" "log" - "net" "time" "github.com/gliderlabs/ssh" ) var ( - MaxLifeTimeout = 30 * time.Second - IdleTimeout = 5 * time.Second + DeadlineTimeout = 30 * time.Second + IdleTimeout = 10 * time.Second ) -type timeoutConn struct { - net.Conn - maxlife time.Time - idle time.Time -} - -func (c *timeoutConn) Write(p []byte) (n int, err error) { - c.updateDeadline() - return c.Conn.Write(p) -} - -func (c *timeoutConn) Read(b []byte) (n int, err error) { - c.idle = time.Now().Add(IdleTimeout) - c.updateDeadline() - return c.Conn.Read(b) -} - -func (c *timeoutConn) updateDeadline() { - if c.idle.Unix() < c.maxlife.Unix() { - c.Conn.SetDeadline(c.idle) - } else { - c.Conn.SetDeadline(c.maxlife) - } -} - func main() { ssh.Handle(func(s ssh.Session) { + log.Println("new connection") i := 0 for { i += 1 - fmt.Fprintln(s, i) - time.Sleep(time.Second) + log.Println("active seconds:", i) + select { + case <-time.After(time.Second): + continue + case <-s.Context().Done(): + log.Println("connection closed") + return + } } }) log.Println("starting ssh server on port 2222...") - log.Printf("connections will only last %s\n", MaxLifeTimeout) - log.Printf("and timeout after %s of no client activity\n", IdleTimeout) - log.Fatal(ssh.ListenAndServe(":2222", nil, ssh.WrapConn(func(conn net.Conn) net.Conn { - return &timeoutConn{conn, time.Now().Add(MaxLifeTimeout), time.Now().Add(IdleTimeout)} - }))) + log.Printf("connections will only last %s\n", DeadlineTimeout) + log.Printf("and timeout after %s of no activity\n", IdleTimeout) + server := &ssh.Server{ + Addr: ":2222", + MaxTimeout: DeadlineTimeout, + IdleTimeout: IdleTimeout, + } + log.Fatal(server.ListenAndServe()) } diff --git a/conn.go b/conn.go new file mode 100644 index 0000000..da84976 --- /dev/null +++ b/conn.go @@ -0,0 +1,50 @@ +package ssh + +import ( + "context" + "net" + "time" +) + +type serverConn struct { + net.Conn + + idleTimeout time.Duration + maxDeadline time.Time + closeCanceler context.CancelFunc +} + +func (c *serverConn) Write(p []byte) (n int, err error) { + c.updateDeadline() + n, err = c.Conn.Write(p) + if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil { + c.closeCanceler() + } + return +} + +func (c *serverConn) Read(b []byte) (n int, err error) { + c.updateDeadline() + n, err = c.Conn.Read(b) + if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil { + c.closeCanceler() + } + return +} + +func (c *serverConn) Close() (err error) { + err = c.Conn.Close() + if c.closeCanceler != nil { + c.closeCanceler() + } + return +} + +func (c *serverConn) updateDeadline() { + idleDeadline := time.Now().Add(c.idleTimeout) + if idleDeadline.Unix() < c.maxDeadline.Unix() { + c.Conn.SetDeadline(idleDeadline) + } else { + c.Conn.SetDeadline(c.maxDeadline) + } +} diff --git a/context.go b/context.go index 9a693b3..7d9c7ae 100644 --- a/context.go +++ b/context.go @@ -2,8 +2,8 @@ package ssh import ( "context" - "net" "encoding/hex" + "net" gossh "golang.org/x/crypto/ssh" ) @@ -92,12 +92,13 @@ type sshContext struct { context.Context } -func newContext(srv *Server) *sshContext { - ctx := &sshContext{context.Background()} +func newContext(srv *Server) (*sshContext, context.CancelFunc) { + innerCtx, cancel := context.WithCancel(context.Background()) + ctx := &sshContext{innerCtx} ctx.SetValue(ContextKeyServer, srv) perms := &Permissions{&gossh.Permissions{}} ctx.SetValue(ContextKeyPermissions, perms) - return ctx + return ctx, cancel } // this is separate from newContext because we will get ConnMetadata diff --git a/server.go b/server.go index 54809f9..63a380c 100644 --- a/server.go +++ b/server.go @@ -30,6 +30,9 @@ type Server struct { ConnCallback ConnCallback // optional callback for wrapping net.Conn before handling LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil + IdleTimeout time.Duration // connection timeout when no activity, none if empty + MaxTimeout time.Duration // absolute connection timeout, none if empty + channelHandlers map[string]channelHandler mu sync.Mutex @@ -191,17 +194,25 @@ func (srv *Server) Serve(l net.Listener) error { } } -func (srv *Server) handleConn(conn net.Conn) { +func (srv *Server) handleConn(newConn net.Conn) { if srv.ConnCallback != nil { - cbConn := srv.ConnCallback(conn) + cbConn := srv.ConnCallback(newConn) if cbConn == nil { - conn.Close() + newConn.Close() return } - conn = cbConn + newConn = cbConn + } + ctx, cancel := newContext(srv) + conn := &serverConn{ + Conn: newConn, + idleTimeout: srv.IdleTimeout, + closeCanceler: cancel, + } + if int64(srv.MaxTimeout) > 0 { + conn.maxDeadline = time.Now().Add(srv.MaxTimeout) } defer conn.Close() - ctx := newContext(srv) sshConn, chans, reqs, err := gossh.NewServerConn(conn, srv.config(ctx)) if err != nil { // TODO: trigger event callback