From ceaf26cca1d25e057f845c3813fea0f595fa00dc Mon Sep 17 00:00:00 2001 From: mo Date: Wed, 5 Aug 2020 10:34:29 +0800 Subject: [PATCH] add header test --- README.md | 3 +- _example/main.go | 1 + go.mod | 5 +- go.sum | 9 ++ header.go | 105 +++++++++++++--------- header_test.go | 146 +++++++++++++++++++++++++++++++ packet.go | 10 +-- request.go | 19 ++-- socks5.go => server.go | 15 +++- socks5_test.go => server_test.go | 2 +- 10 files changed, 249 insertions(+), 66 deletions(-) create mode 100644 header_test.go rename socks5.go => server.go (92%) rename socks5_test.go => server_test.go (99%) diff --git a/README.md b/README.md index 1dded13..7f667c5 100644 --- a/README.md +++ b/README.md @@ -48,4 +48,5 @@ if err := server.ListenAndServe("tcp", "127.0.0.1:8000"); err != nil { ``` # Reference -original armon go-sock5 [go-sock5](https://github.com/armon/go-socks5) \ No newline at end of file +- [rfc1928](https://www.ietf.org/rfc/rfc1928.txt) +- original armon go-sock5 [go-sock5](https://github.com/armon/go-socks5) \ No newline at end of file diff --git a/_example/main.go b/_example/main.go index 6f7e04f..8f53f62 100644 --- a/_example/main.go +++ b/_example/main.go @@ -15,4 +15,5 @@ func main() { if err := server.ListenAndServe("tcp", "127.0.0.1:1080"); err != nil { panic(err) } + } diff --git a/go.mod b/go.mod index f92ea03..23d5804 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,7 @@ module github.com/thinkgos/go-socks5 go 1.14 -require golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e +require ( + github.com/stretchr/testify v1.6.1 + golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e +) diff --git a/go.sum b/go.sum index 91b3e17..d4fc9ea 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,9 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e h1:3G+cUijn7XD+S4eJFddp53Pv7+slrESplyjG25HgL+k= golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= @@ -5,3 +11,6 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd h1:xhmwyvizuTgC2qz7ZlMluP20uW+C3Rm0FD/WLDX8884= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/header.go b/header.go index 5821045..e886518 100644 --- a/header.go +++ b/header.go @@ -18,8 +18,8 @@ const ( CommandAssociate = uint8(3) // address type ATYPIPv4 = uint8(1) - ATYPDomain = uint8(3) // domain - ATYPIPV6 = uint8(4) + ATYPDomain = uint8(3) + ATYPIPv6 = uint8(4) ) // reply status @@ -36,13 +36,13 @@ const ( // 0x09 - 0xff unassigned ) -// head len defined +// Header represents the SOCKS4/SOCKS5 head len defined const ( - headVERLen = 1 - headCMDLen = 1 - headRSVLen = 1 - headATYPLen = 1 - headPORTLen = 2 + headerVERLen = 1 + headerCMDLen = 1 + headerRSVLen = 1 // only socks5 support + headerATYPLen = 1 + headerPORTLen = 2 ) // AddrSpec is used to return the target AddrSpec @@ -63,9 +63,9 @@ func (a *AddrSpec) String() string { } // Address returns a string which may be specified -// if IPv4,IPv6 will return ip:port -// if FQDN will return domain ip:port -// note: do not used to dial +// if IPv4/IPv6 will return < ip:port > +// if FQDN will return < domain ip:port > +// Note: do not used to dial, Please use String func (a AddrSpec) Address() string { if a.FQDN != "" { return fmt.Sprintf("%s (%s):%d", a.FQDN, a.IP, a.Port) @@ -73,7 +73,19 @@ func (a AddrSpec) Address() string { return fmt.Sprintf("%s:%d", a.IP, a.Port) } -// Header represents the SOCKS5/SOCKS4 header, it contains everything that is not payload +// Header represents the SOCKS4/SOCKS5 header, it contains everything that is not payload +// The SOCKS4 request/response is formed as follows: +// +-----+-----+------+------+ +// | VER | CMD | PORT | IPV4 | +// +-----+-----+------+------+ +// | 1 | 1 | 2 | 2 | +// +-----+-----+------+------+ +// The SOCKS5 request/response is formed as follows: +// +-----+-----+-------+------+----------------+----------------+ +// | VER | CMD | RSV | ATYP | [DST/BND].ADDR | [DST/BND].PORT | +// +-----+-----+-------+------+----------------+----------------+ +// | 1 | 1 | X'00' | 1 | Variable | 2 | +// +-----+-----+-------+------+----------------+----------------+ type Header struct { // Version of socks protocol for message Version uint8 @@ -87,10 +99,10 @@ type Header struct { addrType uint8 } -// Parse to header -func Parse(r io.Reader) (hd Header, err error) { +// ParseHeader to header from io.Reader +func ParseHeader(r io.Reader) (hd Header, err error) { // Read the version and command - tmp := make([]byte, headVERLen+headCMDLen) + tmp := make([]byte, headerVERLen+headerCMDLen) if _, err = io.ReadFull(r, tmp); err != nil { return hd, fmt.Errorf("failed to get header version and command, %v", err) } @@ -102,19 +114,11 @@ func Parse(r io.Reader) (hd Header, err error) { } if hd.Version == VersionSocks4 && hd.Command == CommandAssociate { - return hd, fmt.Errorf("wrong version for command") + return hd, fmt.Errorf("SOCKS4 version not support command: associate") } - if hd.Version == VersionSocks4 { - // read port and ipv4 ip - tmp = make([]byte, headPORTLen+net.IPv4len) - if _, err = io.ReadFull(r, tmp); err != nil { - return hd, fmt.Errorf("failed to get socks4 header port and ip, %v", err) - } - hd.Address.Port = buildPort(tmp[0], tmp[1]) - hd.Address.IP = tmp[2:] - } else if hd.Version == VersionSocks5 { - tmp = make([]byte, headRSVLen+headATYPLen) + if hd.Version == VersionSocks5 { + tmp = make([]byte, headerRSVLen+headerATYPLen) if _, err = io.ReadFull(r, tmp); err != nil { return hd, fmt.Errorf("failed to get header RSV and address type, %v", err) } @@ -125,21 +129,21 @@ func Parse(r io.Reader) (hd Header, err error) { if _, err = io.ReadFull(r, tmp[:1]); err != nil { return hd, fmt.Errorf("failed to get header, %v", err) } - addrLen := int(tmp[0]) - addr := make([]byte, addrLen+2) + domainLen := int(tmp[0]) + addr := make([]byte, domainLen+headerPORTLen) if _, err = io.ReadFull(r, addr); err != nil { return hd, fmt.Errorf("failed to get header, %v", err) } - hd.Address.FQDN = string(addr[:addrLen]) - hd.Address.Port = buildPort(addr[addrLen], addr[addrLen+1]) + hd.Address.FQDN = string(addr[:domainLen]) + hd.Address.Port = buildPort(addr[domainLen], addr[domainLen+1]) case ATYPIPv4: addr := make([]byte, net.IPv4len+2) if _, err = io.ReadFull(r, addr); err != nil { return hd, fmt.Errorf("failed to get header, %v", err) } - hd.Address.IP = addr[:net.IPv4len] + hd.Address.IP = net.IPv4(addr[0], addr[1], addr[2], addr[3]) hd.Address.Port = buildPort(addr[net.IPv4len], addr[net.IPv4len+1]) - case ATYPIPV6: + case ATYPIPv6: addr := make([]byte, net.IPv6len+2) if _, err = io.ReadFull(r, addr); err != nil { return hd, fmt.Errorf("failed to get header, %v", err) @@ -149,19 +153,39 @@ func Parse(r io.Reader) (hd Header, err error) { default: return hd, errUnrecognizedAddrType } + } else { // Socks4 + // read port and ipv4 ip + tmp = make([]byte, headerPORTLen+net.IPv4len) + if _, err = io.ReadFull(r, tmp); err != nil { + return hd, fmt.Errorf("failed to get socks4 header port and ip, %v", err) + } + hd.Address.Port = buildPort(tmp[0], tmp[1]) + hd.Address.IP = net.IPv4(tmp[2], tmp[3], tmp[4], tmp[5]) } return hd, nil } // Bytes returns a slice of header func (h Header) Bytes() (b []byte) { - b = append(b, h.Version) - b = append(b, h.Command) hiPort, loPort := breakPort(h.Address.Port) if h.Version == VersionSocks4 { + b = make([]byte, 0, headerVERLen+headerCMDLen+headerPORTLen+net.IPv4len) + b = append(b, h.Version) + b = append(b, h.Command) b = append(b, hiPort, loPort) - b = append(b, h.Address.IP...) + b = append(b, h.Address.IP.To4()...) } else if h.Version == VersionSocks5 { + length := headerVERLen + headerCMDLen + headerRSVLen + headerATYPLen + headerPORTLen + if h.addrType == ATYPDomain { + length += 1 + len(h.Address.FQDN) + } else if h.addrType == ATYPIPv4 { + length += net.IPv4len + } else if h.addrType == ATYPIPv6 { + length += net.IPv6len + } + b = make([]byte, 0, length) + b = append(b, h.Version) + b = append(b, h.Command) b = append(b, h.Reserved) b = append(b, h.addrType) if h.addrType == ATYPDomain { @@ -169,7 +193,7 @@ func (h Header) Bytes() (b []byte) { b = append(b, []byte(h.Address.FQDN)...) } else if h.addrType == ATYPIPv4 { b = append(b, h.Address.IP.To4()...) - } else if h.addrType == ATYPIPV6 { + } else if h.addrType == ATYPIPv6 { b = append(b, h.Address.IP.To16()...) } b = append(b, hiPort, loPort) @@ -177,10 +201,5 @@ func (h Header) Bytes() (b []byte) { return b } -func buildPort(hi, lo byte) int { - return (int(hi) << 8) | int(lo) -} - -func breakPort(port int) (hi, lo byte) { - return byte(port >> 8), byte(port) -} +func buildPort(hi, lo byte) int { return (int(hi) << 8) | int(lo) } +func breakPort(port int) (hi, lo byte) { return byte(port >> 8), byte(port) } diff --git a/header_test.go b/header_test.go new file mode 100644 index 0000000..3c1f837 --- /dev/null +++ b/header_test.go @@ -0,0 +1,146 @@ +package socks5 + +import ( + "bytes" + "io" + "net" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAddrSpecAddr(t *testing.T) { + addr1 := AddrSpec{ + IP: net.IPv4(127, 0, 0, 1), + Port: 8080, + } + assert.Equal(t, "127.0.0.1:8080", addr1.String()) + assert.Equal(t, "127.0.0.1:8080", addr1.Address()) + + addr2 := AddrSpec{ + FQDN: "localhost", + IP: net.IPv4(127, 0, 0, 1), + Port: 8080, + } + assert.Equal(t, "127.0.0.1:8080", addr2.String()) + assert.Equal(t, "localhost (127.0.0.1):8080", addr2.Address()) +} + +func TestParseHeader(t *testing.T) { + type args struct { + r io.Reader + } + tests := []struct { + name string + args args + wantHd Header + wantErr bool + }{ + { + "SOCKS5 IPV4", + args{bytes.NewReader([]byte{VersionSocks5, CommandConnect, 0, ATYPIPv4, 127, 0, 0, 1, 0x1f, 0x90})}, + Header{ + VersionSocks5, CommandConnect, 0, + AddrSpec{IP: net.IPv4(127, 0, 0, 1), Port: 8080}, + ATYPIPv4, + }, + false, + }, + { + "SOCKS5 IPV6", + args{bytes.NewReader([]byte{VersionSocks5, CommandConnect, 0, ATYPIPv6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x1f, 0x90})}, + Header{ + VersionSocks5, CommandConnect, 0, + AddrSpec{IP: net.IPv6zero, Port: 8080}, + ATYPIPv6, + }, + false, + }, + { + "SOCKS5 FQDN", + args{bytes.NewReader([]byte{VersionSocks5, CommandConnect, 0, ATYPDomain, 9, 'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't', 0x1f, 0x90})}, + Header{ + VersionSocks5, CommandConnect, 0, + AddrSpec{FQDN: "localhost", Port: 8080}, + ATYPDomain, + }, + false, + }, + { + "SOCKS4", + args{bytes.NewReader([]byte{VersionSocks4, CommandConnect, 0x1f, 0x90, 127, 0, 0, 1})}, + Header{ + VersionSocks4, CommandConnect, 0, + AddrSpec{IP: net.IPv4(127, 0, 0, 1), Port: 8080}, + 0, + }, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotHd, err := ParseHeader(tt.args.r) + if (err != nil) != tt.wantErr { + t.Errorf("ParseHeader() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !reflect.DeepEqual(gotHd, tt.wantHd) { + t.Errorf("ParseHeader() gotHd = %+v, want %+v", gotHd, tt.wantHd) + } + }) + } +} + +func TestHeader_Bytes(t *testing.T) { + tests := []struct { + name string + header Header + wantB []byte + }{ + { + "SOCKS5 IPV4", + Header{ + VersionSocks5, CommandConnect, 0, + AddrSpec{IP: net.IPv4(127, 0, 0, 1), Port: 8080}, + ATYPIPv4, + }, + []byte{VersionSocks5, CommandConnect, 0, ATYPIPv4, 127, 0, 0, 1, 0x1f, 0x90}, + }, + { + "SOCKS5 IPV6", + Header{ + VersionSocks5, CommandConnect, 0, + AddrSpec{IP: net.IPv6zero, Port: 8080}, + ATYPIPv6, + }, + []byte{VersionSocks5, CommandConnect, 0, ATYPIPv6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x1f, 0x90}, + }, + { + "SOCKS5 FQDN", + Header{ + VersionSocks5, CommandConnect, 0, + AddrSpec{FQDN: "localhost", Port: 8080}, + ATYPDomain, + }, + []byte{VersionSocks5, CommandConnect, 0, ATYPDomain, 9, 'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't', 0x1f, 0x90}, + }, + { + "SOCKS4", + Header{ + VersionSocks4, CommandConnect, 0, + AddrSpec{IP: net.IPv4(127, 0, 0, 1), Port: 8080}, + 0, + }, + []byte{VersionSocks4, CommandConnect, 0x1f, 0x90, 127, 0, 0, 1}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if gotB := tt.header.Bytes(); !reflect.DeepEqual(gotB, tt.wantB) { + t.Errorf("Bytes() = %v, want %v", gotB, tt.wantB) + } + }) + } +} diff --git a/packet.go b/packet.go index 28cdfe0..1acce00 100644 --- a/packet.go +++ b/packet.go @@ -48,7 +48,7 @@ func NewPacket(destAddr string, data []byte) (p Packet, err error) { p.ATYP = ATYPIPv4 p.DstAddr.IP = ip } else { - p.ATYP = ATYPIPV6 + p.ATYP = ATYPIPv6 p.DstAddr.IP = ip } } else { @@ -63,7 +63,7 @@ func NewPacket(destAddr string, data []byte) (p Packet, err error) { return } -// Parse parse to packet +// ParseHeader parse to packet func (sf *Packet) Parse(b []byte) error { if len(b) <= 4+net.IPv4len+2 { // no data return errors.New("too short") @@ -79,7 +79,7 @@ func (sf *Packet) Parse(b []byte) error { headLen += net.IPv4len + 2 sf.DstAddr.IP = net.IPv4(b[4], b[5], b[6], b[7]) sf.DstAddr.Port = buildPort(b[4+net.IPv4len], b[4+net.IPv4len+1]) - case ATYPIPV6: + case ATYPIPv6: headLen += net.IPv6len + 2 if len(b) <= headLen { return errors.New("too short") @@ -112,8 +112,8 @@ func (sf *Packet) Header() []byte { case ATYPIPv4: bs = append(bs, ATYPIPv4) bs = append(bs, sf.DstAddr.IP...) - case ATYPIPV6: - bs = append(bs, ATYPIPV6) + case ATYPIPv6: + bs = append(bs, ATYPIPv6) bs = append(bs, sf.DstAddr.IP...) case ATYPDomain: bs = append(bs, ATYPDomain) diff --git a/request.go b/request.go index 6f1ccaf..29ecc29 100644 --- a/request.go +++ b/request.go @@ -35,22 +35,17 @@ type Request struct { RawDestAddr *AddrSpec } -type conn interface { - Write([]byte) (int, error) - RemoteAddr() net.Addr -} - // NewRequest creates a new Request from the tcp connection func NewRequest(bufConn io.Reader) (*Request, error) { /* The SOCKS request is formed as follows: - +----+-----+-------+------+----------+----------+ - |VER | CMD | RSV | ATYP | DST.ADDR | DST.PORT | - +----+-----+-------+------+----------+----------+ - | 1 | 1 | X'00' | 1 | Variable | 2 | - +----+-----+-------+------+----------+----------+ + +-----+-----+-------+------+----------+----------+ + | VER | CMD | RSV | ATYP | DST.ADDR | DST.PORT | + +-----+-----+-------+------+----------+----------+ + | 1 | 1 | X'00' | 1 | Variable | 2 | + +-----+-----+-------+------+----------+----------+ */ - hd, err := Parse(bufConn) + hd, err := ParseHeader(bufConn) if err != nil { return nil, err } @@ -354,7 +349,7 @@ func SendReply(w io.Writer, head Header, resp uint8, bindAddr ...net.Addr) error head.Address.IP = addrSpec.IP.To4() head.Address.Port = addrSpec.Port case addrSpec.IP.To16() != nil: - head.addrType = ATYPIPV6 + head.addrType = ATYPIPv6 head.Address.IP = addrSpec.IP.To16() head.Address.Port = addrSpec.Port default: diff --git a/socks5.go b/server.go similarity index 92% rename from socks5.go rename to server.go index 7c58166..2ed45cb 100644 --- a/socks5.go +++ b/server.go @@ -110,18 +110,27 @@ func (s *Server) Serve(l net.Listener) error { } // ServeConn is used to serve a single connection. -func (s *Server) ServeConn(conn net.Conn) (err error) { +func (s *Server) ServeConn(conn net.Conn) error { defer conn.Close() bufConn := bufio.NewReader(conn) + /* + The SOCKS handshake is formed as follows: + +-----+----------+---------------+ + | VER | NMETHODS | METHODS | + +-----+----------+---------------+ + | 1 | 1 | X'00' - X'FF' | + +-----+----------+---------------+ + */ // Read the version byte version := []byte{0} - if _, err = bufConn.Read(version); err != nil { + if _, err := bufConn.Read(version); err != nil { s.logger.Errorf("failed to get version byte: %v", err) return err } var authContext *AuthContext + var err error // Ensure we are compatible if version[0] == VersionSocks5 { // Authenticate the connection @@ -137,6 +146,7 @@ func (s *Server) ServeConn(conn net.Conn) (err error) { return err } + // The client request detail request, err := NewRequest(bufConn) if err != nil { if err == errUnrecognizedAddrType { @@ -157,7 +167,6 @@ func (s *Server) ServeConn(conn net.Conn) (err error) { s.logger.Errorf("%v", err) return err } - return nil } diff --git a/socks5_test.go b/server_test.go similarity index 99% rename from socks5_test.go rename to server_test.go index 1d44de3..5489223 100644 --- a/socks5_test.go +++ b/server_test.go @@ -200,7 +200,7 @@ func TestSOCKS5_Associate(t *testing.T) { t.Fatalf("bad: %v", out) } - rspHead, err := Parse(conn) + rspHead, err := ParseHeader(conn) if err != nil { t.Fatalf("bad response header: %v", err) }