[proposal] ConnCallback (#36)
ConnCallback lets you wrap connection objects for timeouts and limiting
This commit is contained in:
parent
bf3073636e
commit
33ad2fe318
@ -30,7 +30,7 @@ This package was built after working on nearly a dozen projects at Glider Labs u
|
||||
|
||||
## Examples
|
||||
|
||||
A bunch of great examples are in the `_example` directory.
|
||||
A bunch of great examples are in the `_examples` directory.
|
||||
|
||||
## Usage
|
||||
|
||||
|
58
_examples/ssh-timeouts/timeouts.go
Normal file
58
_examples/ssh-timeouts/timeouts.go
Normal file
@ -0,0 +1,58 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
)
|
||||
|
||||
var (
|
||||
MaxLifeTimeout = 30 * time.Second
|
||||
IdleTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
type timeoutConn struct {
|
||||
net.Conn
|
||||
maxlife time.Time
|
||||
idle time.Time
|
||||
}
|
||||
|
||||
func (c *timeoutConn) Write(p []byte) (n int, err error) {
|
||||
c.updateDeadline()
|
||||
return c.Conn.Write(p)
|
||||
}
|
||||
|
||||
func (c *timeoutConn) Read(b []byte) (n int, err error) {
|
||||
c.idle = time.Now().Add(IdleTimeout)
|
||||
c.updateDeadline()
|
||||
return c.Conn.Read(b)
|
||||
}
|
||||
|
||||
func (c *timeoutConn) updateDeadline() {
|
||||
if c.idle.Unix() < c.maxlife.Unix() {
|
||||
c.Conn.SetDeadline(c.idle)
|
||||
} else {
|
||||
c.Conn.SetDeadline(c.maxlife)
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
ssh.Handle(func(s ssh.Session) {
|
||||
i := 0
|
||||
for {
|
||||
i += 1
|
||||
fmt.Fprintln(s, i)
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
})
|
||||
|
||||
log.Println("starting ssh server on port 2222...")
|
||||
log.Printf("connections will only last %s\n", MaxLifeTimeout)
|
||||
log.Printf("and timeout after %s of no client activity\n", IdleTimeout)
|
||||
log.Fatal(ssh.ListenAndServe(":2222", nil, ssh.WrapConn(func(conn net.Conn) net.Conn {
|
||||
return &timeoutConn{conn, time.Now().Add(MaxLifeTimeout), time.Now().Add(IdleTimeout)}
|
||||
})))
|
||||
}
|
@ -62,3 +62,11 @@ func NoPty() Option {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WrapConn returns a functional option that sets ConnCallback on the server.
|
||||
func WrapConn(fn ConnCallback) Option {
|
||||
return func(srv *Server) error {
|
||||
srv.ConnCallback = fn
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
@ -1,7 +1,9 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
@ -66,3 +68,42 @@ func TestPasswordAuthBadPass(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type wrappedConn struct {
|
||||
net.Conn
|
||||
written int32
|
||||
}
|
||||
|
||||
func (c *wrappedConn) Write(p []byte) (n int, err error) {
|
||||
n, err = c.Conn.Write(p)
|
||||
atomic.AddInt32(&(c.written), int32(n))
|
||||
return
|
||||
}
|
||||
|
||||
func TestConnWrapping(t *testing.T) {
|
||||
t.Parallel()
|
||||
var wrapped *wrappedConn
|
||||
session, _, cleanup := newTestSessionWithOptions(t, &Server{
|
||||
Handler: func(s Session) {
|
||||
// nothing
|
||||
},
|
||||
}, &gossh.ClientConfig{
|
||||
User: "testuser",
|
||||
Auth: []gossh.AuthMethod{
|
||||
gossh.Password("testpass"),
|
||||
},
|
||||
HostKeyCallback: gossh.InsecureIgnoreHostKey(),
|
||||
}, PasswordAuth(func(ctx Context, password string) bool {
|
||||
return true
|
||||
}), WrapConn(func(conn net.Conn) net.Conn {
|
||||
wrapped = &wrappedConn{conn, 0}
|
||||
return wrapped
|
||||
}))
|
||||
defer cleanup()
|
||||
if err := session.Shell(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if atomic.LoadInt32(&(wrapped.written)) == 0 {
|
||||
t.Fatal("wrapped conn not written to")
|
||||
}
|
||||
}
|
||||
|
@ -27,6 +27,7 @@ type Server struct {
|
||||
PasswordHandler PasswordHandler // password authentication handler
|
||||
PublicKeyHandler PublicKeyHandler // public key authentication handler
|
||||
PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil
|
||||
ConnCallback ConnCallback // optional callback for wrapping net.Conn before handling
|
||||
LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil
|
||||
|
||||
channelHandlers map[string]channelHandler
|
||||
@ -191,6 +192,14 @@ func (srv *Server) Serve(l net.Listener) error {
|
||||
}
|
||||
|
||||
func (srv *Server) handleConn(conn net.Conn) {
|
||||
if srv.ConnCallback != nil {
|
||||
cbConn := srv.ConnCallback(conn)
|
||||
if cbConn == nil {
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
conn = cbConn
|
||||
}
|
||||
defer conn.Close()
|
||||
ctx := newContext(srv)
|
||||
sshConn, chans, reqs, err := gossh.NewServerConn(conn, srv.config(ctx))
|
||||
|
5
ssh.go
5
ssh.go
@ -42,6 +42,11 @@ type PasswordHandler func(ctx Context, password string) bool
|
||||
// PtyCallback is a hook for allowing PTY sessions.
|
||||
type PtyCallback func(ctx Context, pty Pty) bool
|
||||
|
||||
// ConnCallback is a hook for new connections before handling.
|
||||
// It allows wrapping for timeouts and limiting by returning
|
||||
// the net.Conn that will be used as the underlying connection.
|
||||
type ConnCallback func(conn net.Conn) net.Conn
|
||||
|
||||
// LocalPortForwardingCallback is a hook for allowing port forwarding
|
||||
type LocalPortForwardingCallback func(ctx Context, destinationHost string, destinationPort uint32) bool
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user