server: timeouts and context canceling on closed connections (#46)

Signed-off-by: Jeff Lindsay <progrium@gmail.com>
This commit is contained in:
Jeff Lindsay 2017-07-24 16:25:45 -05:00 committed by GitHub
parent f892d8d851
commit 48c9603bfc
4 changed files with 90 additions and 45 deletions

@ -1,58 +1,41 @@
package main
import (
"fmt"
"log"
"net"
"time"
"github.com/gliderlabs/ssh"
)
var (
MaxLifeTimeout = 30 * time.Second
IdleTimeout = 5 * time.Second
DeadlineTimeout = 30 * time.Second
IdleTimeout = 10 * time.Second
)
type timeoutConn struct {
net.Conn
maxlife time.Time
idle time.Time
}
func (c *timeoutConn) Write(p []byte) (n int, err error) {
c.updateDeadline()
return c.Conn.Write(p)
}
func (c *timeoutConn) Read(b []byte) (n int, err error) {
c.idle = time.Now().Add(IdleTimeout)
c.updateDeadline()
return c.Conn.Read(b)
}
func (c *timeoutConn) updateDeadline() {
if c.idle.Unix() < c.maxlife.Unix() {
c.Conn.SetDeadline(c.idle)
} else {
c.Conn.SetDeadline(c.maxlife)
}
}
func main() {
ssh.Handle(func(s ssh.Session) {
log.Println("new connection")
i := 0
for {
i += 1
fmt.Fprintln(s, i)
time.Sleep(time.Second)
log.Println("active seconds:", i)
select {
case <-time.After(time.Second):
continue
case <-s.Context().Done():
log.Println("connection closed")
return
}
}
})
log.Println("starting ssh server on port 2222...")
log.Printf("connections will only last %s\n", MaxLifeTimeout)
log.Printf("and timeout after %s of no client activity\n", IdleTimeout)
log.Fatal(ssh.ListenAndServe(":2222", nil, ssh.WrapConn(func(conn net.Conn) net.Conn {
return &timeoutConn{conn, time.Now().Add(MaxLifeTimeout), time.Now().Add(IdleTimeout)}
})))
log.Printf("connections will only last %s\n", DeadlineTimeout)
log.Printf("and timeout after %s of no activity\n", IdleTimeout)
server := &ssh.Server{
Addr: ":2222",
MaxTimeout: DeadlineTimeout,
IdleTimeout: IdleTimeout,
}
log.Fatal(server.ListenAndServe())
}

50
conn.go Normal file

@ -0,0 +1,50 @@
package ssh
import (
"context"
"net"
"time"
)
type serverConn struct {
net.Conn
idleTimeout time.Duration
maxDeadline time.Time
closeCanceler context.CancelFunc
}
func (c *serverConn) Write(p []byte) (n int, err error) {
c.updateDeadline()
n, err = c.Conn.Write(p)
if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil {
c.closeCanceler()
}
return
}
func (c *serverConn) Read(b []byte) (n int, err error) {
c.updateDeadline()
n, err = c.Conn.Read(b)
if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil {
c.closeCanceler()
}
return
}
func (c *serverConn) Close() (err error) {
err = c.Conn.Close()
if c.closeCanceler != nil {
c.closeCanceler()
}
return
}
func (c *serverConn) updateDeadline() {
idleDeadline := time.Now().Add(c.idleTimeout)
if idleDeadline.Unix() < c.maxDeadline.Unix() {
c.Conn.SetDeadline(idleDeadline)
} else {
c.Conn.SetDeadline(c.maxDeadline)
}
}

@ -2,8 +2,8 @@ package ssh
import (
"context"
"net"
"encoding/hex"
"net"
gossh "golang.org/x/crypto/ssh"
)
@ -92,12 +92,13 @@ type sshContext struct {
context.Context
}
func newContext(srv *Server) *sshContext {
ctx := &sshContext{context.Background()}
func newContext(srv *Server) (*sshContext, context.CancelFunc) {
innerCtx, cancel := context.WithCancel(context.Background())
ctx := &sshContext{innerCtx}
ctx.SetValue(ContextKeyServer, srv)
perms := &Permissions{&gossh.Permissions{}}
ctx.SetValue(ContextKeyPermissions, perms)
return ctx
return ctx, cancel
}
// this is separate from newContext because we will get ConnMetadata

@ -30,6 +30,9 @@ type Server struct {
ConnCallback ConnCallback // optional callback for wrapping net.Conn before handling
LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil
IdleTimeout time.Duration // connection timeout when no activity, none if empty
MaxTimeout time.Duration // absolute connection timeout, none if empty
channelHandlers map[string]channelHandler
mu sync.Mutex
@ -191,17 +194,25 @@ func (srv *Server) Serve(l net.Listener) error {
}
}
func (srv *Server) handleConn(conn net.Conn) {
func (srv *Server) handleConn(newConn net.Conn) {
if srv.ConnCallback != nil {
cbConn := srv.ConnCallback(conn)
cbConn := srv.ConnCallback(newConn)
if cbConn == nil {
conn.Close()
newConn.Close()
return
}
conn = cbConn
newConn = cbConn
}
ctx, cancel := newContext(srv)
conn := &serverConn{
Conn: newConn,
idleTimeout: srv.IdleTimeout,
closeCanceler: cancel,
}
if int64(srv.MaxTimeout) > 0 {
conn.maxDeadline = time.Now().Add(srv.MaxTimeout)
}
defer conn.Close()
ctx := newContext(srv)
sshConn, chans, reqs, err := gossh.NewServerConn(conn, srv.config(ctx))
if err != nil {
// TODO: trigger event callback