Support for local port forwarding (#38)

* Support local port forwarding
* refactor testSession to return ssh client as well
* Tests for local port forwarding
This commit is contained in:
Mahmood Ali 2017-04-28 18:54:12 -04:00 committed by Jeff Lindsay
parent 1051a0d154
commit 20a454724d
7 changed files with 166 additions and 20 deletions

@ -7,7 +7,7 @@ func TestSetPermissions(t *testing.T) {
permsExt := map[string]string{
"foo": "bar",
}
session, cleanup := newTestSessionWithOptions(t, &Server{
session, _, cleanup := newTestSessionWithOptions(t, &Server{
Handler: func(s Session) {
if _, ok := s.Permissions().Extensions["foo"]; !ok {
t.Fatalf("got %#v; want %#v", s.Permissions().Extensions, permsExt)
@ -29,7 +29,7 @@ func TestSetValue(t *testing.T) {
"foo": "bar",
}
key := "testValue"
session, cleanup := newTestSessionWithOptions(t, &Server{
session, _, cleanup := newTestSessionWithOptions(t, &Server{
Handler: func(s Session) {
v := s.Context().Value(key).(map[string]string)
if v["foo"] != value["foo"] {

@ -7,7 +7,7 @@ import (
gossh "golang.org/x/crypto/ssh"
)
func newTestSessionWithOptions(t *testing.T, srv *Server, cfg *gossh.ClientConfig, options ...Option) (*gossh.Session, func()) {
func newTestSessionWithOptions(t *testing.T, srv *Server, cfg *gossh.ClientConfig, options ...Option) (*gossh.Session, *gossh.Client, func()) {
for _, option := range options {
if err := srv.SetOption(option); err != nil {
t.Fatal(err)
@ -20,7 +20,7 @@ func TestPasswordAuth(t *testing.T) {
t.Parallel()
testUser := "testuser"
testPass := "testpass"
session, cleanup := newTestSessionWithOptions(t, &Server{
session, _, cleanup := newTestSessionWithOptions(t, &Server{
Handler: func(s Session) {
// noop
},

@ -17,9 +17,10 @@ type Server struct {
HostSigners []Signer // private keys for the host key, must have at least one
Version string // server version to be sent before the initial handshake
PasswordHandler PasswordHandler // password authentication handler
PublicKeyHandler PublicKeyHandler // public key authentication handler
PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil
PasswordHandler PasswordHandler // password authentication handler
PublicKeyHandler PublicKeyHandler // public key authentication handler
PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil
LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil
channelHandlers map[string]channelHandler
}
@ -40,7 +41,8 @@ func (srv *Server) ensureHostSigner() error {
func (srv *Server) config(ctx *sshContext) *gossh.ServerConfig {
srv.channelHandlers = map[string]channelHandler{
"session": sessionHandler,
"session": sessionHandler,
"direct-tcpip": directTcpipHandler,
}
config := &gossh.ServerConfig{}
for _, signer := range srv.HostSigners {

@ -32,7 +32,7 @@ func newLocalListener() net.Listener {
return l
}
func newClientSession(t *testing.T, addr string, config *gossh.ClientConfig) (*gossh.Session, func()) {
func newClientSession(t *testing.T, addr string, config *gossh.ClientConfig) (*gossh.Session, *gossh.Client, func()) {
if config == nil {
config = &gossh.ClientConfig{
User: "testuser",
@ -52,13 +52,13 @@ func newClientSession(t *testing.T, addr string, config *gossh.ClientConfig) (*g
if err != nil {
t.Fatal(err)
}
return session, func() {
return session, client, func() {
session.Close()
client.Close()
}
}
func newTestSession(t *testing.T, srv *Server, cfg *gossh.ClientConfig) (*gossh.Session, func()) {
func newTestSession(t *testing.T, srv *Server, cfg *gossh.ClientConfig) (*gossh.Session, *gossh.Client, func()) {
l := newLocalListener()
go srv.serveOnce(l)
return newClientSession(t, l.Addr().String(), cfg)
@ -67,7 +67,7 @@ func newTestSession(t *testing.T, srv *Server, cfg *gossh.ClientConfig) (*gossh.
func TestStdout(t *testing.T) {
t.Parallel()
testBytes := []byte("Hello world\n")
session, cleanup := newTestSession(t, &Server{
session, _, cleanup := newTestSession(t, &Server{
Handler: func(s Session) {
s.Write(testBytes)
},
@ -86,7 +86,7 @@ func TestStdout(t *testing.T) {
func TestStderr(t *testing.T) {
t.Parallel()
testBytes := []byte("Hello world\n")
session, cleanup := newTestSession(t, &Server{
session, _, cleanup := newTestSession(t, &Server{
Handler: func(s Session) {
s.Stderr().Write(testBytes)
},
@ -105,7 +105,7 @@ func TestStderr(t *testing.T) {
func TestStdin(t *testing.T) {
t.Parallel()
testBytes := []byte("Hello world\n")
session, cleanup := newTestSession(t, &Server{
session, _, cleanup := newTestSession(t, &Server{
Handler: func(s Session) {
io.Copy(s, s) // stdin back into stdout
},
@ -125,7 +125,7 @@ func TestStdin(t *testing.T) {
func TestUser(t *testing.T) {
t.Parallel()
testUser := []byte("progrium")
session, cleanup := newTestSession(t, &Server{
session, _, cleanup := newTestSession(t, &Server{
Handler: func(s Session) {
io.WriteString(s, s.User())
},
@ -145,7 +145,7 @@ func TestUser(t *testing.T) {
func TestDefaultExitStatusZero(t *testing.T) {
t.Parallel()
session, cleanup := newTestSession(t, &Server{
session, _, cleanup := newTestSession(t, &Server{
Handler: func(s Session) {
// noop
},
@ -159,7 +159,7 @@ func TestDefaultExitStatusZero(t *testing.T) {
func TestExplicitExitStatusZero(t *testing.T) {
t.Parallel()
session, cleanup := newTestSession(t, &Server{
session, _, cleanup := newTestSession(t, &Server{
Handler: func(s Session) {
s.Exit(0)
},
@ -173,7 +173,7 @@ func TestExplicitExitStatusZero(t *testing.T) {
func TestExitStatusNonZero(t *testing.T) {
t.Parallel()
session, cleanup := newTestSession(t, &Server{
session, _, cleanup := newTestSession(t, &Server{
Handler: func(s Session) {
s.Exit(1)
},
@ -195,7 +195,7 @@ func TestPty(t *testing.T) {
winWidth := 40
winHeight := 80
done := make(chan bool)
session, cleanup := newTestSession(t, &Server{
session, _, cleanup := newTestSession(t, &Server{
Handler: func(s Session) {
ptyReq, _, isPty := s.Pty()
if !isPty {
@ -230,7 +230,7 @@ func TestPtyResize(t *testing.T) {
winch2 := Window{20, 40}
winches := make(chan Window)
done := make(chan bool)
session, cleanup := newTestSession(t, &Server{
session, _, cleanup := newTestSession(t, &Server{
Handler: func(s Session) {
ptyReq, winCh, isPty := s.Pty()
if !isPty {

3
ssh.go

@ -42,6 +42,9 @@ type PasswordHandler func(ctx Context, password string) bool
// PtyCallback is a hook for allowing PTY sessions.
type PtyCallback func(ctx Context, pty Pty) bool
// LocalPortForwardingCallback is a hook for allowing port forwarding
type LocalPortForwardingCallback func(ctx Context, destinationHost string, destinationPort uint32) bool
// Window represents the size of a PTY window.
type Window struct {
Width int

58
tcpip.go Normal file

@ -0,0 +1,58 @@
package ssh
import (
"fmt"
"io"
"net"
gossh "golang.org/x/crypto/ssh"
)
// direct-tcpip data struct as specified in RFC4254, Section 7.2
type forwardData struct {
DestinationHost string
DestinationPort uint32
OriginatorHost string
OriginatorPort uint32
}
func directTcpipHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx *sshContext) {
d := forwardData{}
if err := gossh.Unmarshal(newChan.ExtraData(), &d); err != nil {
newChan.Reject(gossh.ConnectionFailed, "error parsing forward data: "+err.Error())
return
}
if srv.LocalPortForwardingCallback == nil || !srv.LocalPortForwardingCallback(ctx, d.DestinationHost, d.DestinationPort) {
newChan.Reject(gossh.Prohibited, "port forwarding is disabled")
return
}
dest := fmt.Sprintf("%s:%d", d.DestinationHost, d.DestinationPort)
var dialer net.Dialer
dconn, err := dialer.DialContext(ctx, "tcp", dest)
if err != nil {
newChan.Reject(gossh.ConnectionFailed, err.Error())
return
}
ch, reqs, err := newChan.Accept()
if err != nil {
dconn.Close()
return
}
go gossh.DiscardRequests(reqs)
go func() {
defer ch.Close()
defer dconn.Close()
io.Copy(ch, dconn)
}()
go func() {
defer ch.Close()
defer dconn.Close()
io.Copy(dconn, ch)
}()
}

83
tcpip_test.go Normal file

@ -0,0 +1,83 @@
package ssh
import (
"bytes"
"fmt"
"io/ioutil"
"net"
"strings"
"testing"
gossh "golang.org/x/crypto/ssh"
)
var sampleServerResponse = []byte("Hello world")
func sampleSocketServer() net.Listener {
l := newLocalListener()
go func() {
conn, err := l.Accept()
if err != nil {
return
}
conn.Write(sampleServerResponse)
conn.Close()
}()
return l
}
func newTestSessionWithForwarding(t *testing.T, forwardingEnabled bool) (net.Listener, *gossh.Client, func()) {
l := sampleSocketServer()
_, client, cleanup := newTestSession(t, &Server{
Handler: func(s Session) {},
LocalPortForwardingCallback: func(ctx Context, destinationHost string, destinationPort uint32) bool {
addr := fmt.Sprintf("%s:%d", destinationHost, destinationPort)
if addr != l.Addr().String() {
panic("unexpected destinationHost: " + addr)
}
return forwardingEnabled
},
}, nil)
return l, client, func() {
cleanup()
l.Close()
}
}
func TestLocalPortForwardingWorks(t *testing.T) {
t.Parallel()
l, client, cleanup := newTestSessionWithForwarding(t, true)
defer cleanup()
conn, err := client.Dial("tcp", l.Addr().String())
if err != nil {
t.Fatalf("Error connecting to %v: %v", l.Addr().String(), err)
}
result, err := ioutil.ReadAll(conn)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(result, sampleServerResponse) {
t.Fatalf("result = %#v; want %#v", result, sampleServerResponse)
}
}
func TestLocalPortForwardingRespectsCallback(t *testing.T) {
t.Parallel()
l, client, cleanup := newTestSessionWithForwarding(t, false)
defer cleanup()
_, err := client.Dial("tcp", l.Addr().String())
if err == nil {
t.Fatalf("Expected error connecting to %v but it succeeded", l.Addr().String())
}
if !strings.Contains(err.Error(), "port forwarding is disabled") {
t.Fatalf("Expected permission error but got %#v", err)
}
}