2017-04-28 22:54:12 +00:00
|
|
|
package ssh
|
|
|
|
|
|
|
|
import (
|
|
|
|
"bytes"
|
|
|
|
"io/ioutil"
|
|
|
|
"net"
|
Support port forwarding of literal IPv6 addresses (#85)
* Support port forwarding of literal IPv6 addresses
To disambiguate between colons used as host:port separators and colons inside IPv6 addresses, literal IPv6 addresses are wrapped in square brackets (https://en.wikipedia.org/wiki/IPv6_address#Literal_IPv6_addresses_in_network_resource_identifiers). For example, host ::1 with port 22 is written as [::1]:22, so simply concatenating host, colon, and port does not work. Fortunately, net.JoinHostPort already implements this, so after converting the numeric port to a string we can build dest in an IPv6-safe way.
2018-09-24 00:41:38 +00:00
|
|
|
"strconv"
|
2017-04-28 22:54:12 +00:00
|
|
|
"strings"
|
|
|
|
"testing"
|
|
|
|
|
|
|
|
gossh "golang.org/x/crypto/ssh"
|
|
|
|
)
|
|
|
|
|
|
|
|
var sampleServerResponse = []byte("Hello world")
|
|
|
|
|
|
|
|
func sampleSocketServer() net.Listener {
|
|
|
|
l := newLocalListener()
|
|
|
|
|
|
|
|
go func() {
|
|
|
|
conn, err := l.Accept()
|
|
|
|
if err != nil {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
conn.Write(sampleServerResponse)
|
|
|
|
conn.Close()
|
|
|
|
}()
|
|
|
|
|
|
|
|
return l
|
|
|
|
}
|
|
|
|
|
|
|
|
func newTestSessionWithForwarding(t *testing.T, forwardingEnabled bool) (net.Listener, *gossh.Client, func()) {
|
|
|
|
l := sampleSocketServer()
|
|
|
|
|
|
|
|
_, client, cleanup := newTestSession(t, &Server{
|
|
|
|
Handler: func(s Session) {},
|
|
|
|
LocalPortForwardingCallback: func(ctx Context, destinationHost string, destinationPort uint32) bool {
|
Support port forwarding of literal IPv6 addresses (#85)
* Support port forwarding of literal IPv6 addresses
To disambiguate between colons as host:port separators and as IPv6 address separators, literal IPv6 addresses use square brackets around the address (https://en.wikipedia.org/wiki/IPv6_address#Literal_IPv6_addresses_in_network_resource_identifiers). So host ::1, port 22 is written as [::1]:22, and therefore a simple concatenation of host, colon, and port doesn't work. Fortunately net.JoinHostPort already implements this functionality, so with a bit of type gymnastics we can generate dest in an IPv6-safe way.
* Support port forwarding of literal IPv6 addresses
To disambiguate between colons as host:port separators and as IPv6 address separators, literal IPv6 addresses use square brackets around the address (https://en.wikipedia.org/wiki/IPv6_address#Literal_IPv6_addresses_in_network_resource_identifiers). So host ::1, port 22 is written as [::1]:22, and therefore a simple concatenation of host, colon, and port doesn't work. Fortunately net.JoinHostPort already implements this functionality, so with a bit of type gymnastics we can generate dest in an IPv6-safe way.
2018-09-24 00:41:38 +00:00
|
|
|
addr := net.JoinHostPort(destinationHost, strconv.FormatInt(int64(destinationPort), 10))
|
2017-04-28 22:54:12 +00:00
|
|
|
if addr != l.Addr().String() {
|
|
|
|
panic("unexpected destinationHost: " + addr)
|
|
|
|
}
|
|
|
|
return forwardingEnabled
|
|
|
|
},
|
|
|
|
}, nil)
|
|
|
|
|
|
|
|
return l, client, func() {
|
|
|
|
cleanup()
|
|
|
|
l.Close()
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestLocalPortForwardingWorks(t *testing.T) {
|
|
|
|
t.Parallel()
|
|
|
|
|
|
|
|
l, client, cleanup := newTestSessionWithForwarding(t, true)
|
|
|
|
defer cleanup()
|
|
|
|
|
|
|
|
conn, err := client.Dial("tcp", l.Addr().String())
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Error connecting to %v: %v", l.Addr().String(), err)
|
|
|
|
}
|
|
|
|
result, err := ioutil.ReadAll(conn)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
if !bytes.Equal(result, sampleServerResponse) {
|
|
|
|
t.Fatalf("result = %#v; want %#v", result, sampleServerResponse)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestLocalPortForwardingRespectsCallback(t *testing.T) {
|
|
|
|
t.Parallel()
|
|
|
|
|
|
|
|
l, client, cleanup := newTestSessionWithForwarding(t, false)
|
|
|
|
defer cleanup()
|
|
|
|
|
|
|
|
_, err := client.Dial("tcp", l.Addr().String())
|
|
|
|
if err == nil {
|
|
|
|
t.Fatalf("Expected error connecting to %v but it succeeded", l.Addr().String())
|
|
|
|
}
|
|
|
|
if !strings.Contains(err.Error(), "port forwarding is disabled") {
|
|
|
|
t.Fatalf("Expected permission error but got %#v", err)
|
|
|
|
}
|
|
|
|
}
|