diff --git a/options_test.go b/options_test.go index 2301aeb..23fca5a 100644 --- a/options_test.go +++ b/options_test.go @@ -95,7 +95,7 @@ func TestConnWrapping(t *testing.T) { HostKeyCallback: gossh.InsecureIgnoreHostKey(), }, PasswordAuth(func(ctx Context, password string) bool { return true - }), WrapConn(func(conn net.Conn) net.Conn { + }), WrapConn(func(ctx Context, conn net.Conn) net.Conn { wrapped = &wrappedConn{conn, 0} return wrapped })) diff --git a/server.go b/server.go index 249d7ce..cad0402 100644 --- a/server.go +++ b/server.go @@ -233,15 +233,15 @@ func (srv *Server) Serve(l net.Listener) error { } func (srv *Server) HandleConn(newConn net.Conn) { + ctx, cancel := newContext(srv) if srv.ConnCallback != nil { - cbConn := srv.ConnCallback(newConn) + cbConn := srv.ConnCallback(ctx, newConn) if cbConn == nil { newConn.Close() return } newConn = cbConn } - ctx, cancel := newContext(srv) conn := &serverConn{ Conn: newConn, idleTimeout: srv.IdleTimeout, diff --git a/ssh.go b/ssh.go index f5a935a..9673ac3 100644 --- a/ssh.go +++ b/ssh.go @@ -53,7 +53,7 @@ type SessionRequestCallback func(sess Session, requestType string) bool // ConnCallback is a hook for new connections before handling. // It allows wrapping for timeouts and limiting by returning // the net.Conn that will be used as the underlying connection. -type ConnCallback func(conn net.Conn) net.Conn +type ConnCallback func(ctx Context, conn net.Conn) net.Conn // LocalPortForwardingCallback is a hook for allowing port forwarding type LocalPortForwardingCallback func(ctx Context, destinationHost string, destinationPort uint32) bool