add header test

This commit is contained in:
mo 2020-08-05 10:34:29 +08:00
parent af6ce456f9
commit ceaf26cca1
10 changed files with 249 additions and 66 deletions

@ -48,4 +48,5 @@ if err := server.ListenAndServe("tcp", "127.0.0.1:8000"); err != nil {
```
# Reference
original armon go-sock5 [go-sock5](https://github.com/armon/go-socks5)
- [rfc1928](https://www.ietf.org/rfc/rfc1928.txt)
- original armon go-sock5 [go-sock5](https://github.com/armon/go-socks5)

@ -15,4 +15,5 @@ func main() {
if err := server.ListenAndServe("tcp", "127.0.0.1:1080"); err != nil {
panic(err)
}
}

5
go.mod

@ -2,4 +2,7 @@ module github.com/thinkgos/go-socks5
go 1.14
require golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e
require (
github.com/stretchr/testify v1.6.1
golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e
)

9
go.sum

@ -1,3 +1,9 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e h1:3G+cUijn7XD+S4eJFddp53Pv7+slrESplyjG25HgL+k=
golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
@ -5,3 +11,6 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd h1:xhmwyvizuTgC2qz7ZlMluP20uW+C3Rm0FD/WLDX8884=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

105
header.go

@ -18,8 +18,8 @@ const (
CommandAssociate = uint8(3)
// address type
ATYPIPv4 = uint8(1)
ATYPDomain = uint8(3) // domain
ATYPIPV6 = uint8(4)
ATYPDomain = uint8(3)
ATYPIPv6 = uint8(4)
)
// reply status
@ -36,13 +36,13 @@ const (
// 0x09 - 0xff unassigned
)
// head len defined
// Header represents the SOCKS4/SOCKS5 head len defined
const (
headVERLen = 1
headCMDLen = 1
headRSVLen = 1
headATYPLen = 1
headPORTLen = 2
headerVERLen = 1
headerCMDLen = 1
headerRSVLen = 1 // only socks5 support
headerATYPLen = 1
headerPORTLen = 2
)
// AddrSpec is used to return the target AddrSpec
@ -63,9 +63,9 @@ func (a *AddrSpec) String() string {
}
// Address returns a string which may be specified
// if IPv4,IPv6 will return ip:port
// if FQDN will return domain ip:port
// note: do not used to dial
// if IPv4/IPv6 will return < ip:port >
// if FQDN will return < domain ip:port >
// Note: do not used to dial, Please use String
func (a AddrSpec) Address() string {
if a.FQDN != "" {
return fmt.Sprintf("%s (%s):%d", a.FQDN, a.IP, a.Port)
@ -73,7 +73,19 @@ func (a AddrSpec) Address() string {
return fmt.Sprintf("%s:%d", a.IP, a.Port)
}
// Header represents the SOCKS5/SOCKS4 header, it contains everything that is not payload
// Header represents the SOCKS4/SOCKS5 header, it contains everything that is not payload
// The SOCKS4 request/response is formed as follows:
// +-----+-----+------+------+
// | VER | CMD | PORT | IPV4 |
// +-----+-----+------+------+
// | 1 | 1 | 2 | 2 |
// +-----+-----+------+------+
// The SOCKS5 request/response is formed as follows:
// +-----+-----+-------+------+----------------+----------------+
// | VER | CMD | RSV | ATYP | [DST/BND].ADDR | [DST/BND].PORT |
// +-----+-----+-------+------+----------------+----------------+
// | 1 | 1 | X'00' | 1 | Variable | 2 |
// +-----+-----+-------+------+----------------+----------------+
type Header struct {
// Version of socks protocol for message
Version uint8
@ -87,10 +99,10 @@ type Header struct {
addrType uint8
}
// Parse to header
func Parse(r io.Reader) (hd Header, err error) {
// ParseHeader to header from io.Reader
func ParseHeader(r io.Reader) (hd Header, err error) {
// Read the version and command
tmp := make([]byte, headVERLen+headCMDLen)
tmp := make([]byte, headerVERLen+headerCMDLen)
if _, err = io.ReadFull(r, tmp); err != nil {
return hd, fmt.Errorf("failed to get header version and command, %v", err)
}
@ -102,19 +114,11 @@ func Parse(r io.Reader) (hd Header, err error) {
}
if hd.Version == VersionSocks4 && hd.Command == CommandAssociate {
return hd, fmt.Errorf("wrong version for command")
return hd, fmt.Errorf("SOCKS4 version not support command: associate")
}
if hd.Version == VersionSocks4 {
// read port and ipv4 ip
tmp = make([]byte, headPORTLen+net.IPv4len)
if _, err = io.ReadFull(r, tmp); err != nil {
return hd, fmt.Errorf("failed to get socks4 header port and ip, %v", err)
}
hd.Address.Port = buildPort(tmp[0], tmp[1])
hd.Address.IP = tmp[2:]
} else if hd.Version == VersionSocks5 {
tmp = make([]byte, headRSVLen+headATYPLen)
if hd.Version == VersionSocks5 {
tmp = make([]byte, headerRSVLen+headerATYPLen)
if _, err = io.ReadFull(r, tmp); err != nil {
return hd, fmt.Errorf("failed to get header RSV and address type, %v", err)
}
@ -125,21 +129,21 @@ func Parse(r io.Reader) (hd Header, err error) {
if _, err = io.ReadFull(r, tmp[:1]); err != nil {
return hd, fmt.Errorf("failed to get header, %v", err)
}
addrLen := int(tmp[0])
addr := make([]byte, addrLen+2)
domainLen := int(tmp[0])
addr := make([]byte, domainLen+headerPORTLen)
if _, err = io.ReadFull(r, addr); err != nil {
return hd, fmt.Errorf("failed to get header, %v", err)
}
hd.Address.FQDN = string(addr[:addrLen])
hd.Address.Port = buildPort(addr[addrLen], addr[addrLen+1])
hd.Address.FQDN = string(addr[:domainLen])
hd.Address.Port = buildPort(addr[domainLen], addr[domainLen+1])
case ATYPIPv4:
addr := make([]byte, net.IPv4len+2)
if _, err = io.ReadFull(r, addr); err != nil {
return hd, fmt.Errorf("failed to get header, %v", err)
}
hd.Address.IP = addr[:net.IPv4len]
hd.Address.IP = net.IPv4(addr[0], addr[1], addr[2], addr[3])
hd.Address.Port = buildPort(addr[net.IPv4len], addr[net.IPv4len+1])
case ATYPIPV6:
case ATYPIPv6:
addr := make([]byte, net.IPv6len+2)
if _, err = io.ReadFull(r, addr); err != nil {
return hd, fmt.Errorf("failed to get header, %v", err)
@ -149,19 +153,39 @@ func Parse(r io.Reader) (hd Header, err error) {
default:
return hd, errUnrecognizedAddrType
}
} else { // Socks4
// read port and ipv4 ip
tmp = make([]byte, headerPORTLen+net.IPv4len)
if _, err = io.ReadFull(r, tmp); err != nil {
return hd, fmt.Errorf("failed to get socks4 header port and ip, %v", err)
}
hd.Address.Port = buildPort(tmp[0], tmp[1])
hd.Address.IP = net.IPv4(tmp[2], tmp[3], tmp[4], tmp[5])
}
return hd, nil
}
// Bytes returns a slice of header
func (h Header) Bytes() (b []byte) {
b = append(b, h.Version)
b = append(b, h.Command)
hiPort, loPort := breakPort(h.Address.Port)
if h.Version == VersionSocks4 {
b = make([]byte, 0, headerVERLen+headerCMDLen+headerPORTLen+net.IPv4len)
b = append(b, h.Version)
b = append(b, h.Command)
b = append(b, hiPort, loPort)
b = append(b, h.Address.IP...)
b = append(b, h.Address.IP.To4()...)
} else if h.Version == VersionSocks5 {
length := headerVERLen + headerCMDLen + headerRSVLen + headerATYPLen + headerPORTLen
if h.addrType == ATYPDomain {
length += 1 + len(h.Address.FQDN)
} else if h.addrType == ATYPIPv4 {
length += net.IPv4len
} else if h.addrType == ATYPIPv6 {
length += net.IPv6len
}
b = make([]byte, 0, length)
b = append(b, h.Version)
b = append(b, h.Command)
b = append(b, h.Reserved)
b = append(b, h.addrType)
if h.addrType == ATYPDomain {
@ -169,7 +193,7 @@ func (h Header) Bytes() (b []byte) {
b = append(b, []byte(h.Address.FQDN)...)
} else if h.addrType == ATYPIPv4 {
b = append(b, h.Address.IP.To4()...)
} else if h.addrType == ATYPIPV6 {
} else if h.addrType == ATYPIPv6 {
b = append(b, h.Address.IP.To16()...)
}
b = append(b, hiPort, loPort)
@ -177,10 +201,5 @@ func (h Header) Bytes() (b []byte) {
return b
}
func buildPort(hi, lo byte) int {
return (int(hi) << 8) | int(lo)
}
func breakPort(port int) (hi, lo byte) {
return byte(port >> 8), byte(port)
}
func buildPort(hi, lo byte) int { return (int(hi) << 8) | int(lo) }
func breakPort(port int) (hi, lo byte) { return byte(port >> 8), byte(port) }

146
header_test.go Normal file

@ -0,0 +1,146 @@
package socks5
import (
"bytes"
"io"
"net"
"reflect"
"testing"
"github.com/stretchr/testify/assert"
)
func TestAddrSpecAddr(t *testing.T) {
addr1 := AddrSpec{
IP: net.IPv4(127, 0, 0, 1),
Port: 8080,
}
assert.Equal(t, "127.0.0.1:8080", addr1.String())
assert.Equal(t, "127.0.0.1:8080", addr1.Address())
addr2 := AddrSpec{
FQDN: "localhost",
IP: net.IPv4(127, 0, 0, 1),
Port: 8080,
}
assert.Equal(t, "127.0.0.1:8080", addr2.String())
assert.Equal(t, "localhost (127.0.0.1):8080", addr2.Address())
}
func TestParseHeader(t *testing.T) {
type args struct {
r io.Reader
}
tests := []struct {
name string
args args
wantHd Header
wantErr bool
}{
{
"SOCKS5 IPV4",
args{bytes.NewReader([]byte{VersionSocks5, CommandConnect, 0, ATYPIPv4, 127, 0, 0, 1, 0x1f, 0x90})},
Header{
VersionSocks5, CommandConnect, 0,
AddrSpec{IP: net.IPv4(127, 0, 0, 1), Port: 8080},
ATYPIPv4,
},
false,
},
{
"SOCKS5 IPV6",
args{bytes.NewReader([]byte{VersionSocks5, CommandConnect, 0, ATYPIPv6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x1f, 0x90})},
Header{
VersionSocks5, CommandConnect, 0,
AddrSpec{IP: net.IPv6zero, Port: 8080},
ATYPIPv6,
},
false,
},
{
"SOCKS5 FQDN",
args{bytes.NewReader([]byte{VersionSocks5, CommandConnect, 0, ATYPDomain, 9, 'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't', 0x1f, 0x90})},
Header{
VersionSocks5, CommandConnect, 0,
AddrSpec{FQDN: "localhost", Port: 8080},
ATYPDomain,
},
false,
},
{
"SOCKS4",
args{bytes.NewReader([]byte{VersionSocks4, CommandConnect, 0x1f, 0x90, 127, 0, 0, 1})},
Header{
VersionSocks4, CommandConnect, 0,
AddrSpec{IP: net.IPv4(127, 0, 0, 1), Port: 8080},
0,
},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotHd, err := ParseHeader(tt.args.r)
if (err != nil) != tt.wantErr {
t.Errorf("ParseHeader() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(gotHd, tt.wantHd) {
t.Errorf("ParseHeader() gotHd = %+v, want %+v", gotHd, tt.wantHd)
}
})
}
}
func TestHeader_Bytes(t *testing.T) {
tests := []struct {
name string
header Header
wantB []byte
}{
{
"SOCKS5 IPV4",
Header{
VersionSocks5, CommandConnect, 0,
AddrSpec{IP: net.IPv4(127, 0, 0, 1), Port: 8080},
ATYPIPv4,
},
[]byte{VersionSocks5, CommandConnect, 0, ATYPIPv4, 127, 0, 0, 1, 0x1f, 0x90},
},
{
"SOCKS5 IPV6",
Header{
VersionSocks5, CommandConnect, 0,
AddrSpec{IP: net.IPv6zero, Port: 8080},
ATYPIPv6,
},
[]byte{VersionSocks5, CommandConnect, 0, ATYPIPv6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x1f, 0x90},
},
{
"SOCKS5 FQDN",
Header{
VersionSocks5, CommandConnect, 0,
AddrSpec{FQDN: "localhost", Port: 8080},
ATYPDomain,
},
[]byte{VersionSocks5, CommandConnect, 0, ATYPDomain, 9, 'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't', 0x1f, 0x90},
},
{
"SOCKS4",
Header{
VersionSocks4, CommandConnect, 0,
AddrSpec{IP: net.IPv4(127, 0, 0, 1), Port: 8080},
0,
},
[]byte{VersionSocks4, CommandConnect, 0x1f, 0x90, 127, 0, 0, 1},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if gotB := tt.header.Bytes(); !reflect.DeepEqual(gotB, tt.wantB) {
t.Errorf("Bytes() = %v, want %v", gotB, tt.wantB)
}
})
}
}

@ -48,7 +48,7 @@ func NewPacket(destAddr string, data []byte) (p Packet, err error) {
p.ATYP = ATYPIPv4
p.DstAddr.IP = ip
} else {
p.ATYP = ATYPIPV6
p.ATYP = ATYPIPv6
p.DstAddr.IP = ip
}
} else {
@ -63,7 +63,7 @@ func NewPacket(destAddr string, data []byte) (p Packet, err error) {
return
}
// Parse parse to packet
// ParseHeader parse to packet
func (sf *Packet) Parse(b []byte) error {
if len(b) <= 4+net.IPv4len+2 { // no data
return errors.New("too short")
@ -79,7 +79,7 @@ func (sf *Packet) Parse(b []byte) error {
headLen += net.IPv4len + 2
sf.DstAddr.IP = net.IPv4(b[4], b[5], b[6], b[7])
sf.DstAddr.Port = buildPort(b[4+net.IPv4len], b[4+net.IPv4len+1])
case ATYPIPV6:
case ATYPIPv6:
headLen += net.IPv6len + 2
if len(b) <= headLen {
return errors.New("too short")
@ -112,8 +112,8 @@ func (sf *Packet) Header() []byte {
case ATYPIPv4:
bs = append(bs, ATYPIPv4)
bs = append(bs, sf.DstAddr.IP...)
case ATYPIPV6:
bs = append(bs, ATYPIPV6)
case ATYPIPv6:
bs = append(bs, ATYPIPv6)
bs = append(bs, sf.DstAddr.IP...)
case ATYPDomain:
bs = append(bs, ATYPDomain)

@ -35,22 +35,17 @@ type Request struct {
RawDestAddr *AddrSpec
}
type conn interface {
Write([]byte) (int, error)
RemoteAddr() net.Addr
}
// NewRequest creates a new Request from the tcp connection
func NewRequest(bufConn io.Reader) (*Request, error) {
/*
The SOCKS request is formed as follows:
+----+-----+-------+------+----------+----------+
|VER | CMD | RSV | ATYP | DST.ADDR | DST.PORT |
+----+-----+-------+------+----------+----------+
| 1 | 1 | X'00' | 1 | Variable | 2 |
+----+-----+-------+------+----------+----------+
+-----+-----+-------+------+----------+----------+
| VER | CMD | RSV | ATYP | DST.ADDR | DST.PORT |
+-----+-----+-------+------+----------+----------+
| 1 | 1 | X'00' | 1 | Variable | 2 |
+-----+-----+-------+------+----------+----------+
*/
hd, err := Parse(bufConn)
hd, err := ParseHeader(bufConn)
if err != nil {
return nil, err
}
@ -354,7 +349,7 @@ func SendReply(w io.Writer, head Header, resp uint8, bindAddr ...net.Addr) error
head.Address.IP = addrSpec.IP.To4()
head.Address.Port = addrSpec.Port
case addrSpec.IP.To16() != nil:
head.addrType = ATYPIPV6
head.addrType = ATYPIPv6
head.Address.IP = addrSpec.IP.To16()
head.Address.Port = addrSpec.Port
default:

@ -110,18 +110,27 @@ func (s *Server) Serve(l net.Listener) error {
}
// ServeConn is used to serve a single connection.
func (s *Server) ServeConn(conn net.Conn) (err error) {
func (s *Server) ServeConn(conn net.Conn) error {
defer conn.Close()
bufConn := bufio.NewReader(conn)
/*
The SOCKS handshake is formed as follows:
+-----+----------+---------------+
| VER | NMETHODS | METHODS |
+-----+----------+---------------+
| 1 | 1 | X'00' - X'FF' |
+-----+----------+---------------+
*/
// Read the version byte
version := []byte{0}
if _, err = bufConn.Read(version); err != nil {
if _, err := bufConn.Read(version); err != nil {
s.logger.Errorf("failed to get version byte: %v", err)
return err
}
var authContext *AuthContext
var err error
// Ensure we are compatible
if version[0] == VersionSocks5 {
// Authenticate the connection
@ -137,6 +146,7 @@ func (s *Server) ServeConn(conn net.Conn) (err error) {
return err
}
// The client request detail
request, err := NewRequest(bufConn)
if err != nil {
if err == errUnrecognizedAddrType {
@ -157,7 +167,6 @@ func (s *Server) ServeConn(conn net.Conn) (err error) {
s.logger.Errorf("%v", err)
return err
}
return nil
}

@ -200,7 +200,7 @@ func TestSOCKS5_Associate(t *testing.T) {
t.Fatalf("bad: %v", out)
}
rspHead, err := Parse(conn)
rspHead, err := ParseHeader(conn)
if err != nil {
t.Fatalf("bad response header: %v", err)
}