server: timeouts and context canceling on closed connections (#46)
Signed-off-by: Jeff Lindsay <progrium@gmail.com>
This commit is contained in:
parent
f892d8d851
commit
48c9603bfc
@ -1,58 +1,41 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
)
|
||||
|
||||
var (
|
||||
MaxLifeTimeout = 30 * time.Second
|
||||
IdleTimeout = 5 * time.Second
|
||||
DeadlineTimeout = 30 * time.Second
|
||||
IdleTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
type timeoutConn struct {
|
||||
net.Conn
|
||||
maxlife time.Time
|
||||
idle time.Time
|
||||
}
|
||||
|
||||
func (c *timeoutConn) Write(p []byte) (n int, err error) {
|
||||
c.updateDeadline()
|
||||
return c.Conn.Write(p)
|
||||
}
|
||||
|
||||
func (c *timeoutConn) Read(b []byte) (n int, err error) {
|
||||
c.idle = time.Now().Add(IdleTimeout)
|
||||
c.updateDeadline()
|
||||
return c.Conn.Read(b)
|
||||
}
|
||||
|
||||
func (c *timeoutConn) updateDeadline() {
|
||||
if c.idle.Unix() < c.maxlife.Unix() {
|
||||
c.Conn.SetDeadline(c.idle)
|
||||
} else {
|
||||
c.Conn.SetDeadline(c.maxlife)
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
ssh.Handle(func(s ssh.Session) {
|
||||
log.Println("new connection")
|
||||
i := 0
|
||||
for {
|
||||
i += 1
|
||||
fmt.Fprintln(s, i)
|
||||
time.Sleep(time.Second)
|
||||
log.Println("active seconds:", i)
|
||||
select {
|
||||
case <-time.After(time.Second):
|
||||
continue
|
||||
case <-s.Context().Done():
|
||||
log.Println("connection closed")
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
log.Println("starting ssh server on port 2222...")
|
||||
log.Printf("connections will only last %s\n", MaxLifeTimeout)
|
||||
log.Printf("and timeout after %s of no client activity\n", IdleTimeout)
|
||||
log.Fatal(ssh.ListenAndServe(":2222", nil, ssh.WrapConn(func(conn net.Conn) net.Conn {
|
||||
return &timeoutConn{conn, time.Now().Add(MaxLifeTimeout), time.Now().Add(IdleTimeout)}
|
||||
})))
|
||||
log.Printf("connections will only last %s\n", DeadlineTimeout)
|
||||
log.Printf("and timeout after %s of no activity\n", IdleTimeout)
|
||||
server := &ssh.Server{
|
||||
Addr: ":2222",
|
||||
MaxTimeout: DeadlineTimeout,
|
||||
IdleTimeout: IdleTimeout,
|
||||
}
|
||||
log.Fatal(server.ListenAndServe())
|
||||
}
|
||||
|
50
conn.go
Normal file
50
conn.go
Normal file
@ -0,0 +1,50 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type serverConn struct {
|
||||
net.Conn
|
||||
|
||||
idleTimeout time.Duration
|
||||
maxDeadline time.Time
|
||||
closeCanceler context.CancelFunc
|
||||
}
|
||||
|
||||
func (c *serverConn) Write(p []byte) (n int, err error) {
|
||||
c.updateDeadline()
|
||||
n, err = c.Conn.Write(p)
|
||||
if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil {
|
||||
c.closeCanceler()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *serverConn) Read(b []byte) (n int, err error) {
|
||||
c.updateDeadline()
|
||||
n, err = c.Conn.Read(b)
|
||||
if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil {
|
||||
c.closeCanceler()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *serverConn) Close() (err error) {
|
||||
err = c.Conn.Close()
|
||||
if c.closeCanceler != nil {
|
||||
c.closeCanceler()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *serverConn) updateDeadline() {
|
||||
idleDeadline := time.Now().Add(c.idleTimeout)
|
||||
if idleDeadline.Unix() < c.maxDeadline.Unix() {
|
||||
c.Conn.SetDeadline(idleDeadline)
|
||||
} else {
|
||||
c.Conn.SetDeadline(c.maxDeadline)
|
||||
}
|
||||
}
|
@ -2,8 +2,8 @@ package ssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"encoding/hex"
|
||||
"net"
|
||||
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
)
|
||||
@ -92,12 +92,13 @@ type sshContext struct {
|
||||
context.Context
|
||||
}
|
||||
|
||||
func newContext(srv *Server) *sshContext {
|
||||
ctx := &sshContext{context.Background()}
|
||||
func newContext(srv *Server) (*sshContext, context.CancelFunc) {
|
||||
innerCtx, cancel := context.WithCancel(context.Background())
|
||||
ctx := &sshContext{innerCtx}
|
||||
ctx.SetValue(ContextKeyServer, srv)
|
||||
perms := &Permissions{&gossh.Permissions{}}
|
||||
ctx.SetValue(ContextKeyPermissions, perms)
|
||||
return ctx
|
||||
return ctx, cancel
|
||||
}
|
||||
|
||||
// this is separate from newContext because we will get ConnMetadata
|
||||
|
21
server.go
21
server.go
@ -30,6 +30,9 @@ type Server struct {
|
||||
ConnCallback ConnCallback // optional callback for wrapping net.Conn before handling
|
||||
LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil
|
||||
|
||||
IdleTimeout time.Duration // connection timeout when no activity, none if empty
|
||||
MaxTimeout time.Duration // absolute connection timeout, none if empty
|
||||
|
||||
channelHandlers map[string]channelHandler
|
||||
|
||||
mu sync.Mutex
|
||||
@ -191,17 +194,25 @@ func (srv *Server) Serve(l net.Listener) error {
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *Server) handleConn(conn net.Conn) {
|
||||
func (srv *Server) handleConn(newConn net.Conn) {
|
||||
if srv.ConnCallback != nil {
|
||||
cbConn := srv.ConnCallback(conn)
|
||||
cbConn := srv.ConnCallback(newConn)
|
||||
if cbConn == nil {
|
||||
conn.Close()
|
||||
newConn.Close()
|
||||
return
|
||||
}
|
||||
conn = cbConn
|
||||
newConn = cbConn
|
||||
}
|
||||
ctx, cancel := newContext(srv)
|
||||
conn := &serverConn{
|
||||
Conn: newConn,
|
||||
idleTimeout: srv.IdleTimeout,
|
||||
closeCanceler: cancel,
|
||||
}
|
||||
if int64(srv.MaxTimeout) > 0 {
|
||||
conn.maxDeadline = time.Now().Add(srv.MaxTimeout)
|
||||
}
|
||||
defer conn.Close()
|
||||
ctx := newContext(srv)
|
||||
sshConn, chans, reqs, err := gossh.NewServerConn(conn, srv.config(ctx))
|
||||
if err != nil {
|
||||
// TODO: trigger event callback
|
||||
|
Loading…
Reference in New Issue
Block a user