[proposal] ConnCallback (#36)

ConnCallback lets you wrap connection objects for timeouts and limiting
This commit is contained in:
Jeff Lindsay 2017-07-12 12:27:56 -05:00 committed by GitHub
parent bf3073636e
commit 33ad2fe318
11 changed files with 122 additions and 1 deletions

@ -30,7 +30,7 @@ This package was built after working on nearly a dozen projects at Glider Labs u
## Examples
A bunch of great examples are in the `_example` directory.
A bunch of great examples are in the `_examples` directory.
## Usage

@ -0,0 +1,58 @@
package main
import (
"fmt"
"log"
"net"
"time"
"github.com/gliderlabs/ssh"
)
var (
MaxLifeTimeout = 30 * time.Second
IdleTimeout = 5 * time.Second
)
type timeoutConn struct {
net.Conn
maxlife time.Time
idle time.Time
}
func (c *timeoutConn) Write(p []byte) (n int, err error) {
c.updateDeadline()
return c.Conn.Write(p)
}
func (c *timeoutConn) Read(b []byte) (n int, err error) {
c.idle = time.Now().Add(IdleTimeout)
c.updateDeadline()
return c.Conn.Read(b)
}
func (c *timeoutConn) updateDeadline() {
if c.idle.Unix() < c.maxlife.Unix() {
c.Conn.SetDeadline(c.idle)
} else {
c.Conn.SetDeadline(c.maxlife)
}
}
func main() {
ssh.Handle(func(s ssh.Session) {
i := 0
for {
i += 1
fmt.Fprintln(s, i)
time.Sleep(time.Second)
}
})
log.Println("starting ssh server on port 2222...")
log.Printf("connections will only last %s\n", MaxLifeTimeout)
log.Printf("and timeout after %s of no client activity\n", IdleTimeout)
log.Fatal(ssh.ListenAndServe(":2222", nil, ssh.WrapConn(func(conn net.Conn) net.Conn {
return &timeoutConn{conn, time.Now().Add(MaxLifeTimeout), time.Now().Add(IdleTimeout)}
})))
}

@ -62,3 +62,11 @@ func NoPty() Option {
return nil
}
}
// WrapConn returns a functional option that sets ConnCallback on the server.
func WrapConn(fn ConnCallback) Option {
return func(srv *Server) error {
srv.ConnCallback = fn
return nil
}
}

@ -1,7 +1,9 @@
package ssh
import (
"net"
"strings"
"sync/atomic"
"testing"
gossh "golang.org/x/crypto/ssh"
@ -66,3 +68,42 @@ func TestPasswordAuthBadPass(t *testing.T) {
}
}
}
type wrappedConn struct {
net.Conn
written int32
}
func (c *wrappedConn) Write(p []byte) (n int, err error) {
n, err = c.Conn.Write(p)
atomic.AddInt32(&(c.written), int32(n))
return
}
func TestConnWrapping(t *testing.T) {
t.Parallel()
var wrapped *wrappedConn
session, _, cleanup := newTestSessionWithOptions(t, &Server{
Handler: func(s Session) {
// nothing
},
}, &gossh.ClientConfig{
User: "testuser",
Auth: []gossh.AuthMethod{
gossh.Password("testpass"),
},
HostKeyCallback: gossh.InsecureIgnoreHostKey(),
}, PasswordAuth(func(ctx Context, password string) bool {
return true
}), WrapConn(func(conn net.Conn) net.Conn {
wrapped = &wrappedConn{conn, 0}
return wrapped
}))
defer cleanup()
if err := session.Shell(); err != nil {
t.Fatal(err)
}
if atomic.LoadInt32(&(wrapped.written)) == 0 {
t.Fatal("wrapped conn not written to")
}
}

@ -27,6 +27,7 @@ type Server struct {
PasswordHandler PasswordHandler // password authentication handler
PublicKeyHandler PublicKeyHandler // public key authentication handler
PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil
ConnCallback ConnCallback // optional callback for wrapping net.Conn before handling
LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil
channelHandlers map[string]channelHandler
@ -191,6 +192,14 @@ func (srv *Server) Serve(l net.Listener) error {
}
func (srv *Server) handleConn(conn net.Conn) {
if srv.ConnCallback != nil {
cbConn := srv.ConnCallback(conn)
if cbConn == nil {
conn.Close()
return
}
conn = cbConn
}
defer conn.Close()
ctx := newContext(srv)
sshConn, chans, reqs, err := gossh.NewServerConn(conn, srv.config(ctx))

5
ssh.go

@ -42,6 +42,11 @@ type PasswordHandler func(ctx Context, password string) bool
// PtyCallback is a hook for allowing PTY sessions.
type PtyCallback func(ctx Context, pty Pty) bool
// ConnCallback is a hook for new connections before handling.
// It allows wrapping for timeouts and limiting by returning
// the net.Conn that will be used as the underlying connection.
type ConnCallback func(conn net.Conn) net.Conn
// LocalPortForwardingCallback is a hook for allowing port forwarding
type LocalPortForwardingCallback func(ctx Context, destinationHost string, destinationPort uint32) bool