Support for local port forwarding (#38)
* Support local port forwarding * refactor testSession to return ssh client as well * Tests for local port forwarding
This commit is contained in:
parent
1051a0d154
commit
20a454724d
@ -7,7 +7,7 @@ func TestSetPermissions(t *testing.T) {
|
||||
permsExt := map[string]string{
|
||||
"foo": "bar",
|
||||
}
|
||||
session, cleanup := newTestSessionWithOptions(t, &Server{
|
||||
session, _, cleanup := newTestSessionWithOptions(t, &Server{
|
||||
Handler: func(s Session) {
|
||||
if _, ok := s.Permissions().Extensions["foo"]; !ok {
|
||||
t.Fatalf("got %#v; want %#v", s.Permissions().Extensions, permsExt)
|
||||
@ -29,7 +29,7 @@ func TestSetValue(t *testing.T) {
|
||||
"foo": "bar",
|
||||
}
|
||||
key := "testValue"
|
||||
session, cleanup := newTestSessionWithOptions(t, &Server{
|
||||
session, _, cleanup := newTestSessionWithOptions(t, &Server{
|
||||
Handler: func(s Session) {
|
||||
v := s.Context().Value(key).(map[string]string)
|
||||
if v["foo"] != value["foo"] {
|
||||
|
@ -7,7 +7,7 @@ import (
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func newTestSessionWithOptions(t *testing.T, srv *Server, cfg *gossh.ClientConfig, options ...Option) (*gossh.Session, func()) {
|
||||
func newTestSessionWithOptions(t *testing.T, srv *Server, cfg *gossh.ClientConfig, options ...Option) (*gossh.Session, *gossh.Client, func()) {
|
||||
for _, option := range options {
|
||||
if err := srv.SetOption(option); err != nil {
|
||||
t.Fatal(err)
|
||||
@ -20,7 +20,7 @@ func TestPasswordAuth(t *testing.T) {
|
||||
t.Parallel()
|
||||
testUser := "testuser"
|
||||
testPass := "testpass"
|
||||
session, cleanup := newTestSessionWithOptions(t, &Server{
|
||||
session, _, cleanup := newTestSessionWithOptions(t, &Server{
|
||||
Handler: func(s Session) {
|
||||
// noop
|
||||
},
|
||||
|
10
server.go
10
server.go
@ -17,9 +17,10 @@ type Server struct {
|
||||
HostSigners []Signer // private keys for the host key, must have at least one
|
||||
Version string // server version to be sent before the initial handshake
|
||||
|
||||
PasswordHandler PasswordHandler // password authentication handler
|
||||
PublicKeyHandler PublicKeyHandler // public key authentication handler
|
||||
PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil
|
||||
PasswordHandler PasswordHandler // password authentication handler
|
||||
PublicKeyHandler PublicKeyHandler // public key authentication handler
|
||||
PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil
|
||||
LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil
|
||||
|
||||
channelHandlers map[string]channelHandler
|
||||
}
|
||||
@ -40,7 +41,8 @@ func (srv *Server) ensureHostSigner() error {
|
||||
|
||||
func (srv *Server) config(ctx *sshContext) *gossh.ServerConfig {
|
||||
srv.channelHandlers = map[string]channelHandler{
|
||||
"session": sessionHandler,
|
||||
"session": sessionHandler,
|
||||
"direct-tcpip": directTcpipHandler,
|
||||
}
|
||||
config := &gossh.ServerConfig{}
|
||||
for _, signer := range srv.HostSigners {
|
||||
|
@ -32,7 +32,7 @@ func newLocalListener() net.Listener {
|
||||
return l
|
||||
}
|
||||
|
||||
func newClientSession(t *testing.T, addr string, config *gossh.ClientConfig) (*gossh.Session, func()) {
|
||||
func newClientSession(t *testing.T, addr string, config *gossh.ClientConfig) (*gossh.Session, *gossh.Client, func()) {
|
||||
if config == nil {
|
||||
config = &gossh.ClientConfig{
|
||||
User: "testuser",
|
||||
@ -52,13 +52,13 @@ func newClientSession(t *testing.T, addr string, config *gossh.ClientConfig) (*g
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return session, func() {
|
||||
return session, client, func() {
|
||||
session.Close()
|
||||
client.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func newTestSession(t *testing.T, srv *Server, cfg *gossh.ClientConfig) (*gossh.Session, func()) {
|
||||
func newTestSession(t *testing.T, srv *Server, cfg *gossh.ClientConfig) (*gossh.Session, *gossh.Client, func()) {
|
||||
l := newLocalListener()
|
||||
go srv.serveOnce(l)
|
||||
return newClientSession(t, l.Addr().String(), cfg)
|
||||
@ -67,7 +67,7 @@ func newTestSession(t *testing.T, srv *Server, cfg *gossh.ClientConfig) (*gossh.
|
||||
func TestStdout(t *testing.T) {
|
||||
t.Parallel()
|
||||
testBytes := []byte("Hello world\n")
|
||||
session, cleanup := newTestSession(t, &Server{
|
||||
session, _, cleanup := newTestSession(t, &Server{
|
||||
Handler: func(s Session) {
|
||||
s.Write(testBytes)
|
||||
},
|
||||
@ -86,7 +86,7 @@ func TestStdout(t *testing.T) {
|
||||
func TestStderr(t *testing.T) {
|
||||
t.Parallel()
|
||||
testBytes := []byte("Hello world\n")
|
||||
session, cleanup := newTestSession(t, &Server{
|
||||
session, _, cleanup := newTestSession(t, &Server{
|
||||
Handler: func(s Session) {
|
||||
s.Stderr().Write(testBytes)
|
||||
},
|
||||
@ -105,7 +105,7 @@ func TestStderr(t *testing.T) {
|
||||
func TestStdin(t *testing.T) {
|
||||
t.Parallel()
|
||||
testBytes := []byte("Hello world\n")
|
||||
session, cleanup := newTestSession(t, &Server{
|
||||
session, _, cleanup := newTestSession(t, &Server{
|
||||
Handler: func(s Session) {
|
||||
io.Copy(s, s) // stdin back into stdout
|
||||
},
|
||||
@ -125,7 +125,7 @@ func TestStdin(t *testing.T) {
|
||||
func TestUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
testUser := []byte("progrium")
|
||||
session, cleanup := newTestSession(t, &Server{
|
||||
session, _, cleanup := newTestSession(t, &Server{
|
||||
Handler: func(s Session) {
|
||||
io.WriteString(s, s.User())
|
||||
},
|
||||
@ -145,7 +145,7 @@ func TestUser(t *testing.T) {
|
||||
|
||||
func TestDefaultExitStatusZero(t *testing.T) {
|
||||
t.Parallel()
|
||||
session, cleanup := newTestSession(t, &Server{
|
||||
session, _, cleanup := newTestSession(t, &Server{
|
||||
Handler: func(s Session) {
|
||||
// noop
|
||||
},
|
||||
@ -159,7 +159,7 @@ func TestDefaultExitStatusZero(t *testing.T) {
|
||||
|
||||
func TestExplicitExitStatusZero(t *testing.T) {
|
||||
t.Parallel()
|
||||
session, cleanup := newTestSession(t, &Server{
|
||||
session, _, cleanup := newTestSession(t, &Server{
|
||||
Handler: func(s Session) {
|
||||
s.Exit(0)
|
||||
},
|
||||
@ -173,7 +173,7 @@ func TestExplicitExitStatusZero(t *testing.T) {
|
||||
|
||||
func TestExitStatusNonZero(t *testing.T) {
|
||||
t.Parallel()
|
||||
session, cleanup := newTestSession(t, &Server{
|
||||
session, _, cleanup := newTestSession(t, &Server{
|
||||
Handler: func(s Session) {
|
||||
s.Exit(1)
|
||||
},
|
||||
@ -195,7 +195,7 @@ func TestPty(t *testing.T) {
|
||||
winWidth := 40
|
||||
winHeight := 80
|
||||
done := make(chan bool)
|
||||
session, cleanup := newTestSession(t, &Server{
|
||||
session, _, cleanup := newTestSession(t, &Server{
|
||||
Handler: func(s Session) {
|
||||
ptyReq, _, isPty := s.Pty()
|
||||
if !isPty {
|
||||
@ -230,7 +230,7 @@ func TestPtyResize(t *testing.T) {
|
||||
winch2 := Window{20, 40}
|
||||
winches := make(chan Window)
|
||||
done := make(chan bool)
|
||||
session, cleanup := newTestSession(t, &Server{
|
||||
session, _, cleanup := newTestSession(t, &Server{
|
||||
Handler: func(s Session) {
|
||||
ptyReq, winCh, isPty := s.Pty()
|
||||
if !isPty {
|
||||
|
3
ssh.go
3
ssh.go
@ -42,6 +42,9 @@ type PasswordHandler func(ctx Context, password string) bool
|
||||
// PtyCallback is a hook for allowing PTY sessions.
|
||||
type PtyCallback func(ctx Context, pty Pty) bool
|
||||
|
||||
// LocalPortForwardingCallback is a hook for allowing port forwarding
|
||||
type LocalPortForwardingCallback func(ctx Context, destinationHost string, destinationPort uint32) bool
|
||||
|
||||
// Window represents the size of a PTY window.
|
||||
type Window struct {
|
||||
Width int
|
||||
|
58
tcpip.go
Normal file
58
tcpip.go
Normal file
@ -0,0 +1,58 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// direct-tcpip data struct as specified in RFC4254, Section 7.2
|
||||
type forwardData struct {
|
||||
DestinationHost string
|
||||
DestinationPort uint32
|
||||
|
||||
OriginatorHost string
|
||||
OriginatorPort uint32
|
||||
}
|
||||
|
||||
func directTcpipHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx *sshContext) {
|
||||
d := forwardData{}
|
||||
if err := gossh.Unmarshal(newChan.ExtraData(), &d); err != nil {
|
||||
newChan.Reject(gossh.ConnectionFailed, "error parsing forward data: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if srv.LocalPortForwardingCallback == nil || !srv.LocalPortForwardingCallback(ctx, d.DestinationHost, d.DestinationPort) {
|
||||
newChan.Reject(gossh.Prohibited, "port forwarding is disabled")
|
||||
return
|
||||
}
|
||||
|
||||
dest := fmt.Sprintf("%s:%d", d.DestinationHost, d.DestinationPort)
|
||||
|
||||
var dialer net.Dialer
|
||||
dconn, err := dialer.DialContext(ctx, "tcp", dest)
|
||||
if err != nil {
|
||||
newChan.Reject(gossh.ConnectionFailed, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
ch, reqs, err := newChan.Accept()
|
||||
if err != nil {
|
||||
dconn.Close()
|
||||
return
|
||||
}
|
||||
go gossh.DiscardRequests(reqs)
|
||||
|
||||
go func() {
|
||||
defer ch.Close()
|
||||
defer dconn.Close()
|
||||
io.Copy(ch, dconn)
|
||||
}()
|
||||
go func() {
|
||||
defer ch.Close()
|
||||
defer dconn.Close()
|
||||
io.Copy(dconn, ch)
|
||||
}()
|
||||
}
|
83
tcpip_test.go
Normal file
83
tcpip_test.go
Normal file
@ -0,0 +1,83 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
var sampleServerResponse = []byte("Hello world")
|
||||
|
||||
func sampleSocketServer() net.Listener {
|
||||
l := newLocalListener()
|
||||
|
||||
go func() {
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
conn.Write(sampleServerResponse)
|
||||
conn.Close()
|
||||
}()
|
||||
|
||||
return l
|
||||
}
|
||||
|
||||
func newTestSessionWithForwarding(t *testing.T, forwardingEnabled bool) (net.Listener, *gossh.Client, func()) {
|
||||
l := sampleSocketServer()
|
||||
|
||||
_, client, cleanup := newTestSession(t, &Server{
|
||||
Handler: func(s Session) {},
|
||||
LocalPortForwardingCallback: func(ctx Context, destinationHost string, destinationPort uint32) bool {
|
||||
addr := fmt.Sprintf("%s:%d", destinationHost, destinationPort)
|
||||
if addr != l.Addr().String() {
|
||||
panic("unexpected destinationHost: " + addr)
|
||||
}
|
||||
return forwardingEnabled
|
||||
},
|
||||
}, nil)
|
||||
|
||||
return l, client, func() {
|
||||
cleanup()
|
||||
l.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalPortForwardingWorks(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l, client, cleanup := newTestSessionWithForwarding(t, true)
|
||||
defer cleanup()
|
||||
|
||||
conn, err := client.Dial("tcp", l.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("Error connecting to %v: %v", l.Addr().String(), err)
|
||||
}
|
||||
result, err := ioutil.ReadAll(conn)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(result, sampleServerResponse) {
|
||||
t.Fatalf("result = %#v; want %#v", result, sampleServerResponse)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalPortForwardingRespectsCallback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l, client, cleanup := newTestSessionWithForwarding(t, false)
|
||||
defer cleanup()
|
||||
|
||||
_, err := client.Dial("tcp", l.Addr().String())
|
||||
if err == nil {
|
||||
t.Fatalf("Expected error connecting to %v but it succeeded", l.Addr().String())
|
||||
}
|
||||
if !strings.Contains(err.Error(), "port forwarding is disabled") {
|
||||
t.Fatalf("Expected permission error but got %#v", err)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user