diff --git a/README.md b/README.md index 7f667c5..afa54b2 100644 --- a/README.md +++ b/README.md @@ -39,7 +39,7 @@ Below is a simple example of usage ```go // Create a SOCKS5 server -server := socks5.New() +server := socks5.NewServer() // Create SOCKS5 proxy on localhost port 8000 if err := server.ListenAndServe("tcp", "127.0.0.1:8000"); err != nil { diff --git a/client.go b/client.go deleted file mode 100644 index 16a0398..0000000 --- a/client.go +++ /dev/null @@ -1,221 +0,0 @@ -package socks5 - -import ( - "errors" - "net" - "time" - - "github.com/thinkgos/go-socks5/statute" -) - -// Client is socks5 client wrapper -type Client struct { - Server string - Version byte - UserName string - Password string - // On command UDP, let server control the tcp and udp connection relationship - TCPConn *net.TCPConn - UDPConn *net.UDPConn - RemoteAddress net.Addr - TCPDeadline int - TCPTimeout int - UDPDeadline int -} - -// This is just create a client, you need to use Dial to create conn -func NewClient(addr, username, password string, tcpTimeout, tcpDeadline, udpDeadline int) (*Client, error) { - c := &Client{ - Server: addr, - Version: statute.VersionSocks5, - UserName: username, - Password: password, - TCPTimeout: tcpTimeout, - TCPDeadline: tcpDeadline, - UDPDeadline: udpDeadline, - } - return c, nil -} - -func (c *Client) Close() error { - if c.UDPConn == nil { - return c.TCPConn.Close() - } - if c.TCPConn != nil { - c.TCPConn.Close() - } - return c.UDPConn.Close() -} - -func (c *Client) LocalAddr() net.Addr { - if c.UDPConn == nil { - return c.TCPConn.LocalAddr() - } - return c.UDPConn.LocalAddr() -} - -func (c *Client) RemoteAddr() net.Addr { - return c.RemoteAddress -} - -func (c *Client) SetDeadline(t time.Time) error { - if c.UDPConn == nil { - return c.TCPConn.SetDeadline(t) - } - return c.UDPConn.SetDeadline(t) -} - -func (c *Client) SetReadDeadline(t time.Time) error { - if c.UDPConn == nil { - return c.TCPConn.SetReadDeadline(t) - } - return c.UDPConn.SetReadDeadline(t) -} - -func (c *Client) SetWriteDeadline(t time.Time) error { - if c.UDPConn == nil { - return c.TCPConn.SetWriteDeadline(t) - } - return c.UDPConn.SetWriteDeadline(t) -} - -func (c *Client) Read(b []byte) (int, error) { - if c.UDPConn == nil { - return c.TCPConn.Read(b) - } - // TODO: UDP data - // b1 := make([]byte, 65535) - // n, err := c.UDPConn.Read(b1) - // if err != nil { - // return 0, err - // } - // d, err := NewDatagramFromBytes(b1[0:n]) - // if err != nil { - // return 0, err - // } - // if len(b) < len(d.Data) { - // return 0, errors.New("b too small") - // } - // n = copy(b, d.Data) - return 0, nil -} - -func (c *Client) Write(b []byte) (int, error) { - if c.UDPConn == nil { - return c.TCPConn.Write(b) - } - // TODO: UPD data - // addr, err := ParseAddrSpec(c.RemoteAddress.String()) - // if err != nil { - // return 0, err - // } - // if a == ATYPDomain { - // h = h[1:] - // } - // d := NewDatagram(a, h, p, b) - // b1 := d.Bytes() - return c.UDPConn.Write(b) -} - -func (c *Client) Dial(network, addr string) (net.Conn, error) { - // var err error - // - // conn := *c - // if network == "tcp" { - // conn.RemoteAddress, err = net.ResolveTCPAddr("tcp", addr) - // if err != nil { - // return nil, err - // } - // if err := conn.Negotiate(); err != nil { - // return nil, err - // } - // a, h, p, err := ParseAddress(addr) - // if err != nil { - // return nil, err - // } - // if a == ATYPDomain { - // h = h[1:] - // } - // if _, err := conn.Request(NewRequest(CommandConnect, a, h, p)); err != nil { - // return nil, err - // } - // return conn, nil - // } - // - // if network == "udp" { - // conn.RemoteAddress, err = net.ResolveUDPAddr("udp", addr) - // if err != nil { - // return nil, err - // } - // if err := conn.Negotiate(); err != nil { - // return nil, err - // } - // - // laddr := &net.UDPAddr{ - // IP: conn.TCPConn.LocalAddr().(*net.TCPAddr).IP, - // Port: conn.TCPConn.LocalAddr().(*net.TCPAddr).Port, - // Zone: conn.TCPConn.LocalAddr().(*net.TCPAddr).Zone, - // } - // a, h, p, err := ParseAddress(laddr.String()) - // if err != nil { - // return nil, err - // } - // rp, err := conn.Request(NewRequest(CmdUDP, a, h, p)) - // if err != nil { - // return nil, err - // } - // raddr, err := net.ResolveUDPAddr("udp", rp.Address()) - // if err != nil { - // return nil, err - // } - // conn.UDPConn, err = net.DialUDP("udp", laddr, raddr) - // if err != nil { - // return nil, err - // } - // return conn, nil - // } - // return nil, errors.New("unsupport network") - return nil, errors.New("aaa") -} - -func (c *Client) handshake() error { - methods := statute.MethodNoAuth - if c.UserName != "" && c.Password != "" { - methods = statute.MethodUserPassAuth - } - - _, err := c.TCPConn.Write(statute.NewMethodRequest(c.Version, []byte{methods}).Bytes()) - if err != nil { - return err - } - reply, err := statute.ParseMethodReply(c.TCPConn) - if err != nil { - return err - } - - if reply.Ver != c.Version { - return errors.New("handshake failed cause version not same") - } - if reply.Method != methods { - return errors.New("unsupport method") - } - - if methods == statute.MethodUserPassAuth { - _, err = c.TCPConn.Write(statute.NewUserPassRequest(statute.UserPassAuthVersion, []byte(c.UserName), []byte(c.Password)).Bytes()) - if err != nil { - return err - } - - rsp, err := statute.ParseUserPassReply(c.TCPConn) - if err != nil { - return err - } - if rsp.Ver != statute.UserPassAuthVersion { - return errors.New("handshake failed cause version not same") - } - if rsp.Status != statute.RepSuccess { - return statute.ErrUserAuthFailed - } - } - return nil -} diff --git a/client/client.go b/client/client.go new file mode 100644 index 0000000..c4b9df2 --- /dev/null +++ b/client/client.go @@ -0,0 +1,262 @@ +package client + +import ( + "errors" + "net" + "time" + + "golang.org/x/net/proxy" + + "github.com/thinkgos/go-socks5/statute" +) + +// Client is socks5 client wrapper +type Client struct { + Server string + Auth *proxy.Auth + // On command UDP, let server control the tcp and udp connection relationship + TCPConn *net.TCPConn + UDPConn *net.UDPConn + RemoteAddress net.Addr + TCPDeadline time.Duration + TCPTimeout time.Duration + UDPDeadline time.Duration +} + +// This is just create a client, you need to use Dial to create conn +func NewClient(addr string, opts ...Option) (*Client, error) { + c := &Client{ + Server: addr, + TCPTimeout: time.Second, + TCPDeadline: time.Second, + UDPDeadline: time.Second, + } + for _, opt := range opts { + opt(c) + } + return c, nil +} + +func (c *Client) Close() error { + if c.UDPConn == nil { + return c.TCPConn.Close() + } + if c.TCPConn != nil { + c.TCPConn.Close() + } + return c.UDPConn.Close() +} + +func (c *Client) LocalAddr() net.Addr { + if c.UDPConn == nil { + return c.TCPConn.LocalAddr() + } + return c.UDPConn.LocalAddr() +} + +func (c *Client) RemoteAddr() net.Addr { + return c.RemoteAddress +} + +func (c *Client) SetDeadline(t time.Time) error { + if c.UDPConn == nil { + return c.TCPConn.SetDeadline(t) + } + return c.UDPConn.SetDeadline(t) +} + +func (c *Client) SetReadDeadline(t time.Time) error { + if c.UDPConn == nil { + return c.TCPConn.SetReadDeadline(t) + } + return c.UDPConn.SetReadDeadline(t) +} + +func (c *Client) SetWriteDeadline(t time.Time) error { + if c.UDPConn == nil { + return c.TCPConn.SetWriteDeadline(t) + } + return c.UDPConn.SetWriteDeadline(t) +} + +func (c *Client) Read(b []byte) (int, error) { + if c.UDPConn == nil { + return c.TCPConn.Read(b) + } + b1 := make([]byte, 65535) + n, err := c.UDPConn.Read(b1) + if err != nil { + return 0, err + } + pkt := statute.Packet{} + err = pkt.Parse(b1[:n]) + if err != nil { + return 0, err + } + n = copy(b, pkt.Data) + return n, nil +} + +func (c *Client) Write(b []byte) (int, error) { + if c.UDPConn == nil { + return c.TCPConn.Write(b) + } + pkt, err := statute.NewPacket(c.RemoteAddress.String(), b) + if err != nil { + return 0, err + } + return c.UDPConn.Write(pkt.Bytes()) +} + +func (c *Client) Dial(network, addr string) (net.Conn, error) { + var err error + + conn := *c // clone a client + if network == "tcp" { + conn.RemoteAddress, err = net.ResolveTCPAddr("tcp", addr) + if err != nil { + return nil, err + } + + conn.TCPConn, err = conn.dialServer() + if err != nil { + return nil, err + } + if err := conn.handshake(); err != nil { + return nil, err + } + a, err := statute.ParseAddrSpec(addr) + if err != nil { + return nil, err + } + head := statute.Header{ + Version: statute.VersionSocks5, + Command: statute.CommandConnect, + Address: a, + } + if _, err := conn.Write(head.Bytes()); err != nil { + return nil, err + } + + rspHead, err := statute.ParseHeader(conn.TCPConn) + if err != nil { + return nil, err + } + if rspHead.Command != statute.RepSuccess { + return nil, errors.New("host unreachable") + } + return &conn, nil + } + + if network == "udp" { + conn.RemoteAddress, err = net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + conn.TCPConn, err = conn.dialServer() + if err != nil { + return nil, err + } + if err := conn.handshake(); err != nil { + return nil, err + } + laddr := &net.UDPAddr{ + IP: conn.TCPConn.LocalAddr().(*net.TCPAddr).IP, + Port: conn.TCPConn.LocalAddr().(*net.TCPAddr).Port, + Zone: conn.TCPConn.LocalAddr().(*net.TCPAddr).Zone, + } + a, err := statute.ParseAddrSpec(laddr.String()) + if err != nil { + return nil, err + } + head := statute.Header{ + Version: statute.VersionSocks5, + Command: statute.CommandConnect, + Address: a, + } + if _, err := conn.Write(head.Bytes()); err != nil { + return nil, err + } + rspHead, err := statute.ParseHeader(conn.TCPConn) + if err != nil { + return nil, err + } + if rspHead.Command != statute.RepSuccess { + return nil, errors.New("host unreachable") + } + + raddr, err := net.ResolveUDPAddr("udp", rspHead.Address.String()) + if err != nil { + return nil, err + } + conn.UDPConn, err = net.DialUDP("udp", laddr, raddr) + if err != nil { + return nil, err + } + return &conn, nil + } + + return nil, errors.New("not support network") +} + +func (c *Client) handshake() error { + methods := statute.MethodNoAuth + if c.Auth != nil { + methods = statute.MethodUserPassAuth + } + + _, err := c.TCPConn.Write(statute.NewMethodRequest(statute.VersionSocks5, []byte{methods}).Bytes()) + if err != nil { + return err + } + reply, err := statute.ParseMethodReply(c.TCPConn) + if err != nil { + return err + } + + if reply.Ver != statute.VersionSocks5 { + return statute.ErrNotSupportVersion + } + if reply.Method != methods { + return statute.ErrNotSupportMethod + } + + if methods == statute.MethodUserPassAuth { + _, err = c.TCPConn.Write(statute.NewUserPassRequest(statute.UserPassAuthVersion, []byte(c.Auth.User), []byte(c.Auth.Password)).Bytes()) + if err != nil { + return err + } + + rsp, err := statute.ParseUserPassReply(c.TCPConn) + if err != nil { + return err + } + if rsp.Ver != statute.UserPassAuthVersion { + return statute.ErrNotSupportMethod + } + if rsp.Status != statute.RepSuccess { + return statute.ErrUserAuthFailed + } + } + return nil +} + +func (c *Client) dialServer() (*net.TCPConn, error) { + conn, err := net.Dial("tcp", c.Server) + if err != nil { + return nil, err + } + + TCPConn := conn.(*net.TCPConn) + if c.TCPTimeout != 0 { + if err := TCPConn.SetKeepAlivePeriod(c.TCPTimeout); err != nil { + return nil, err + } + } + if c.TCPDeadline != 0 { + if err := TCPConn.SetDeadline(time.Now().Add(c.TCPTimeout)); err != nil { + return nil, err + } + } + return TCPConn, nil +} diff --git a/client/option.go b/client/option.go new file mode 100644 index 0000000..85602fb --- /dev/null +++ b/client/option.go @@ -0,0 +1,33 @@ +package client + +import ( + "time" + + "golang.org/x/net/proxy" +) + +type Option func(c *Client) + +func WithAuth(auth *proxy.Auth) Option { + return func(c *Client) { + c.Auth = auth + } +} + +func WithTCPTimeout(t time.Duration) Option { + return func(c *Client) { + c.TCPTimeout = t + } +} + +func WithTCPDeadline(t time.Duration) Option { + return func(c *Client) { + c.TCPDeadline = t + } +} + +func WithUDPDeadline(t time.Duration) Option { + return func(c *Client) { + c.UDPDeadline = t + } +} diff --git a/request.go b/request.go index cd0bc9e..ff59856 100644 --- a/request.go +++ b/request.go @@ -305,7 +305,7 @@ func SendReply(w io.Writer, head statute.Header, resp uint8, bindAddr ...net.Add head.Command = resp if len(bindAddr) == 0 { - head.AddrType = statute.ATYPIPv4 + head.Address.AddrType = statute.ATYPIPv4 head.Address.IP = []byte{0, 0, 0, 0} head.Address.Port = 0 } else { @@ -322,15 +322,15 @@ func SendReply(w io.Writer, head statute.Header, resp uint8, bindAddr ...net.Add } switch { case addrSpec.FQDN != "": - head.AddrType = statute.ATYPDomain + head.Address.AddrType = statute.ATYPDomain head.Address.FQDN = addrSpec.FQDN head.Address.Port = addrSpec.Port case addrSpec.IP.To4() != nil: - head.AddrType = statute.ATYPIPv4 + head.Address.AddrType = statute.ATYPIPv4 head.Address.IP = addrSpec.IP.To4() head.Address.Port = addrSpec.Port case addrSpec.IP.To16() != nil: - head.AddrType = statute.ATYPIPv6 + head.Address.AddrType = statute.ATYPIPv6 head.Address.IP = addrSpec.IP.To16() head.Address.Port = addrSpec.Port default: diff --git a/request_test.go b/request_test.go index d894096..1879277 100644 --- a/request_test.go +++ b/request_test.go @@ -7,8 +7,9 @@ import ( "log" "net" "os" - "strings" "testing" + + "github.com/stretchr/testify/require" ) type MockConn struct { @@ -26,30 +27,24 @@ func (m *MockConn) RemoteAddr() net.Addr { func TestRequest_Connect(t *testing.T) { // Create a local listener l, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("err: %v", err) - } + require.NoError(t, err) + go func() { conn, err := l.Accept() - if err != nil { - t.Fatalf("err: %v", err) - } + require.NoError(t, err) defer conn.Close() buf := make([]byte, 4) - if _, err := io.ReadAtLeast(conn, buf, 4); err != nil { - t.Fatalf("err: %v", err) - } + _, err = io.ReadAtLeast(conn, buf, 4) + require.NoError(t, err) + require.Equal(t, []byte("ping"), buf) - if !bytes.Equal(buf, []byte("ping")) { - t.Fatalf("bad: %v", buf) - } - _, _ = conn.Write([]byte("pong")) + conn.Write([]byte("pong")) // nolint: errcheck }() lAddr := l.Addr().(*net.TCPAddr) // Make server - s := &Server{ + proxySrv := &Server{ rules: NewPermitAll(), resolver: DNSResolver{}, logger: NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags)), @@ -58,25 +53,22 @@ func TestRequest_Connect(t *testing.T) { // Create the connect request buf := bytes.NewBuffer(nil) - buf.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1}) + buf.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1}) // nolint: errcheck port := []byte{0, 0} binary.BigEndian.PutUint16(port, uint16(lAddr.Port)) - buf.Write(port) + buf.Write(port) // nolint: errcheck // Send a ping - buf.Write([]byte("ping")) + buf.Write([]byte("ping")) // nolint: errcheck // Handle the request resp := &MockConn{} req, err := NewRequest(buf) - if err != nil { - t.Fatalf("err: %v", err) - } + require.NoError(t, err) - if err := s.handleRequest(resp, req); err != nil { - t.Fatalf("err: %v", err) - } + err = proxySrv.handleRequest(resp, req) + require.NoError(t, err) // Verify response out := resp.buf.Bytes() @@ -93,34 +85,24 @@ func TestRequest_Connect(t *testing.T) { // Ignore the port for both out[8] = 0 out[9] = 0 - - if !bytes.Equal(out, expected) { - t.Fatalf("bad: %v %v", out, expected) - } + require.Equal(t, expected, out) } func TestRequest_Connect_RuleFail(t *testing.T) { // Create a local listener l, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("err: %v", err) - } + require.NoError(t, err) + go func() { conn, err := l.Accept() - if err != nil { - t.Fatalf("err: %v", err) - } + require.NoError(t, err) defer conn.Close() buf := make([]byte, 4) - if _, err := io.ReadAtLeast(conn, buf, 4); err != nil { - t.Fatalf("err: %v", err) - } - - if !bytes.Equal(buf, []byte("ping")) { - t.Fatalf("bad: %v", buf) - } - _, _ = conn.Write([]byte("pong")) + _, err = io.ReadAtLeast(conn, buf, 4) + require.NoError(t, err) + require.Equal(t, []byte("ping"), buf) + conn.Write([]byte("pong")) // nolint: errcheck }() lAddr := l.Addr().(*net.TCPAddr) @@ -146,13 +128,10 @@ func TestRequest_Connect_RuleFail(t *testing.T) { // Handle the request resp := &MockConn{} req, err := NewRequest(buf) - if err != nil { - t.Fatalf("err: %v", err) - } + require.NoError(t, err) - if err := s.handleRequest(resp, req); !strings.Contains(err.Error(), "blocked by rules") { - t.Fatalf("err: %v", err) - } + err = s.handleRequest(resp, req) + require.Contains(t, err.Error(), "blocked by rules") // Verify response out := resp.buf.Bytes() @@ -164,8 +143,5 @@ func TestRequest_Connect_RuleFail(t *testing.T) { 0, 0, 0, 0, 0, 0, } - - if !bytes.Equal(out, expected) { - t.Fatalf("bad: %v %v", out, expected) - } + require.Equal(t, expected, out) } diff --git a/server.go b/server.go index b3fd3b0..b557abd 100644 --- a/server.go +++ b/server.go @@ -56,8 +56,8 @@ type Server struct { userAssociateHandle func(ctx context.Context, writer io.Writer, request *Request) error } -// New creates a new Server and potentially returns an error -func New(opts ...Option) *Server { +// NewServer creates a new Server and potentially returns an error +func NewServer(opts ...Option) *Server { server := &Server{ authMethods: make(map[uint8]Authenticator), authCustomMethods: []Authenticator{&NoAuthAuthenticator{}}, diff --git a/server_test.go b/server_test.go index b0eca9c..b5a35c7 100644 --- a/server_test.go +++ b/server_test.go @@ -20,50 +20,38 @@ import ( func TestSOCKS5_Connect(t *testing.T) { // Create a local listener l, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("err: %v", err) - } + require.NoError(t, err) + go func() { conn, err := l.Accept() - if err != nil { - t.Fatalf("err: %v", err) - } + require.NoError(t, err) defer conn.Close() buf := make([]byte, 4) - if _, err := io.ReadAtLeast(conn, buf, 4); err != nil { - t.Fatalf("err: %v", err) - } - - if !bytes.Equal(buf, []byte("ping")) { - t.Fatalf("bad: %v", buf) - } + _, err = io.ReadAtLeast(conn, buf, 4) + require.NoError(t, err) + assert.Equal(t, []byte("ping"), buf) _, _ = conn.Write([]byte("pong")) }() lAddr := l.Addr().(*net.TCPAddr) // Create a socks server - cator := UserPassAuthenticator{ - Credentials: StaticCredentials{"foo": "bar"}, - } - serv := New( + cator := UserPassAuthenticator{StaticCredentials{"foo": "bar"}} + srv := NewServer( WithAuthMethods([]Authenticator{cator}), WithLogger(NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags))), ) // Start listening go func() { - if err := serv.ListenAndServe("tcp", "127.0.0.1:12365"); err != nil { - t.Fatalf("err: %v", err) - } + err := srv.ListenAndServe("tcp", "127.0.0.1:12365") + require.NoError(t, err) }() time.Sleep(10 * time.Millisecond) // Get a local conn conn, err := net.Dial("tcp", "127.0.0.1:12365") - if err != nil { - t.Fatalf("err: %v", err) - } + require.NoError(t, err) // Connect, auth and connec to local req := new(bytes.Buffer) @@ -77,8 +65,8 @@ func TestSOCKS5_Connect(t *testing.T) { "", net.ParseIP("127.0.0.1"), lAddr.Port, + statute.ATYPIPv4, }, - AddrType: statute.ATYPIPv4, } req.Write(reqHead.Bytes()) // Send a ping @@ -100,27 +88,23 @@ func TestSOCKS5_Connect(t *testing.T) { "", net.ParseIP("127.0.0.1"), 0, // Ignore the port + statute.ATYPIPv4, }, - AddrType: statute.ATYPIPv4, } expected = append(expected, rspHead.Bytes()...) expected = append(expected, []byte("pong")...) out := make([]byte, len(expected)) _ = conn.SetDeadline(time.Now().Add(time.Second)) - if _, err := io.ReadFull(conn, out); err != nil { - t.Fatalf("err: %v", err) - } + _, err = io.ReadFull(conn, out) + require.NoError(t, err) t.Logf("proxy bind port: %d", statute.BuildPort(out[12], out[13])) // Ignore the port out[12] = 0 out[13] = 0 - - if !bytes.Equal(out, expected) { - t.Fatalf("bad: %v", out) - } + assert.Equal(t, expected, out) } func TestSOCKS5_Associate(t *testing.T) { @@ -131,10 +115,9 @@ func TestSOCKS5_Associate(t *testing.T) { Port: 12398, } l, err := net.ListenUDP("udp", lAddr) - if err != nil { - t.Fatalf("err: %v", err) - } + require.NoError(t, err) defer l.Close() + go func() { buf := make([]byte, 2048) for { @@ -142,33 +125,28 @@ func TestSOCKS5_Associate(t *testing.T) { if err != nil { return } + require.Equal(t, []byte("ping"), buf[:n]) - if !bytes.Equal(buf[:n], []byte("ping")) { - t.Fatalf("bad: %v", buf) - } - _, _ = l.WriteTo([]byte("pong"), remote) + l.WriteTo([]byte("pong"), remote) // nolint: errcheck } }() // Create a socks server - cator := UserPassAuthenticator{Credentials: StaticCredentials{"foo": "bar"}} - serv := New( + cator := UserPassAuthenticator{StaticCredentials{"foo": "bar"}} + srv := NewServer( WithAuthMethods([]Authenticator{cator}), WithLogger(NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags))), ) // Start listening go func() { - if err := serv.ListenAndServe("tcp", "127.0.0.1:12355"); err != nil { - t.Fatalf("err: %v", err) - } + err := srv.ListenAndServe("tcp", "127.0.0.1:12355") + require.NoError(t, err) }() time.Sleep(10 * time.Millisecond) // Get a local conn conn, err := net.Dial("tcp", "127.0.0.1:12355") - if err != nil { - t.Fatalf("err: %v", err) - } + require.NoError(t, err) // Connect, auth and connec to local req := new(bytes.Buffer) @@ -182,12 +160,12 @@ func TestSOCKS5_Associate(t *testing.T) { "", locIP, lAddr.Port, + statute.ATYPIPv4, }, - AddrType: statute.ATYPIPv4, } req.Write(reqHead.Bytes()) // Send all the bytes - conn.Write(req.Bytes()) + conn.Write(req.Bytes()) // nolint: errcheck // Verify response expected := []byte{ @@ -197,21 +175,14 @@ func TestSOCKS5_Associate(t *testing.T) { out := make([]byte, len(expected)) _ = conn.SetDeadline(time.Now().Add(time.Second)) - if _, err := io.ReadFull(conn, out); err != nil { - t.Fatalf("err: %v", err) - } - - if !bytes.Equal(out, expected) { - t.Fatalf("bad: %v", out) - } + _, err = io.ReadFull(conn, out) + require.NoError(t, err) + require.Equal(t, expected, out) rspHead, err := statute.ParseHeader(conn) - if err != nil { - t.Fatalf("bad response header: %v", err) - } - if rspHead.Version != statute.VersionSocks5 && rspHead.Command != statute.RepSuccess { - t.Fatalf("parse success but bad header: %v", rspHead) - } + require.NoError(t, err) + require.Equal(t, statute.VersionSocks5, rspHead.Version) + require.Equal(t, statute.RepSuccess, rspHead.Command) t.Logf("proxy bind listen port: %d", rspHead.Address.Port) @@ -219,11 +190,9 @@ func TestSOCKS5_Associate(t *testing.T) { IP: locIP, Port: rspHead.Address.Port, }) - if err != nil { - t.Fatalf("bad dial: %v", err) - } + require.NoError(t, err) // Send a ping - _, _ = udpConn.Write(append([]byte{0, 0, 0, statute.ATYPIPv4, 0, 0, 0, 0, 0, 0}, []byte("ping")...)) + udpConn.Write(append([]byte{0, 0, 0, statute.ATYPIPv4, 0, 0, 0, 0, 0, 0}, []byte("ping")...)) // nolint: errcheck response := make([]byte, 1024) n, _, err := udpConn.ReadFrom(response) if err != nil || !bytes.Equal(response[n-4:n], []byte("pong")) { @@ -235,66 +204,52 @@ func TestSOCKS5_Associate(t *testing.T) { func Test_SocksWithProxy(t *testing.T) { // Create a local listener l, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("err: %v", err) - } + require.NoError(t, err) + go func() { conn, err := l.Accept() - if err != nil { - t.Fatalf("err: %v", err) - } + require.NoError(t, err) defer conn.Close() buf := make([]byte, 4) - if _, err := io.ReadAtLeast(conn, buf, 4); err != nil { - t.Fatalf("err: %v", err) - } + _, err = io.ReadAtLeast(conn, buf, 4) + require.NoError(t, err) + require.Equal(t, []byte("ping"), buf) - if !bytes.Equal(buf, []byte("ping")) { - t.Fatalf("bad: %v", buf) - } conn.Write([]byte("pong")) }() lAddr := l.Addr().(*net.TCPAddr) // Create a socks server - cator := UserPassAuthenticator{Credentials: StaticCredentials{"foo": "bar"}} - serv := New( + cator := UserPassAuthenticator{StaticCredentials{"foo": "bar"}} + serv := NewServer( WithAuthMethods([]Authenticator{cator}), WithLogger(NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags))), ) - - // Start listening + // Start socks server go func() { - if err := serv.ListenAndServe("tcp", "127.0.0.1:12395"); err != nil { - t.Fatalf("err: %v", err) - } + err := serv.ListenAndServe("tcp", "127.0.0.1:12395") + require.NoError(t, err) }() time.Sleep(10 * time.Millisecond) + // client dial, err := proxy.SOCKS5("tcp", "127.0.0.1:12395", &proxy.Auth{User: "foo", Password: "bar"}, proxy.Direct) - if err != nil { - t.Fatalf("err: %v", err) - } + require.NoError(t, err) // Connect, auth and connect to local conn, err := dial.Dial("tcp", lAddr.String()) - if err != nil { - t.Fatalf("err: %v", err) - } + require.NoError(t, err) // Send a ping - _, _ = conn.Write([]byte("ping")) + conn.Write([]byte("ping")) // nolint: errcheck out := make([]byte, 4) - _ = conn.SetDeadline(time.Now().Add(time.Second)) - if _, err := io.ReadFull(conn, out); err != nil { - t.Fatalf("err: %v", err) - } + _ = conn.SetDeadline(time.Now().Add(time.Second)) // nolint: errcheck + _, err = io.ReadFull(conn, out) + require.NoError(t, err) - if !bytes.Equal(out, []byte("pong")) { - t.Fatalf("bad: %v", out) - } + require.Equal(t, []byte("pong"), out) } /***************************** auth *******************************/ @@ -302,7 +257,7 @@ func Test_SocksWithProxy(t *testing.T) { func TestNoAuth_Server(t *testing.T) { req := bytes.NewBuffer(nil) rsp := new(bytes.Buffer) - s := New() + s := NewServer() ctx, err := s.authenticate(rsp, req, "", []byte{statute.MethodNoAuth}) require.NoError(t, err) @@ -314,11 +269,9 @@ func TestPasswordAuth_Valid_Server(t *testing.T) { req := bytes.NewBuffer([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'}) rsp := new(bytes.Buffer) cator := UserPassAuthenticator{ - StaticCredentials{ - "foo": "bar", - }, + StaticCredentials{"foo": "bar"}, } - s := New(WithAuthMethods([]Authenticator{cator})) + s := NewServer(WithAuthMethods([]Authenticator{cator})) ctx, err := s.authenticate(rsp, req, "", []byte{statute.MethodUserPassAuth}) require.NoError(t, err) @@ -339,11 +292,9 @@ func TestPasswordAuth_Invalid_Server(t *testing.T) { req := bytes.NewBuffer([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'z'}) rsp := new(bytes.Buffer) cator := UserPassAuthenticator{ - StaticCredentials{ - "foo": "bar", - }, + StaticCredentials{"foo": "bar"}, } - s := New(WithAuthMethods([]Authenticator{cator})) + s := NewServer(WithAuthMethods([]Authenticator{cator})) ctx, err := s.authenticate(rsp, req, "", []byte{statute.MethodNoAuth, statute.MethodUserPassAuth}) require.True(t, errors.Is(err, statute.ErrUserAuthFailed)) @@ -356,12 +307,10 @@ func TestNoSupportedAuth_Server(t *testing.T) { req := bytes.NewBuffer(nil) rsp := new(bytes.Buffer) cator := UserPassAuthenticator{ - StaticCredentials{ - "foo": "bar", - }, + StaticCredentials{"foo": "bar"}, } - s := New(WithAuthMethods([]Authenticator{cator})) + s := NewServer(WithAuthMethods([]Authenticator{cator})) ctx, err := s.authenticate(rsp, req, "", []byte{statute.MethodNoAuth}) require.True(t, errors.Is(err, statute.ErrNoSupportedAuth)) diff --git a/statute/auth.go b/statute/auth.go index 82eba2c..d437b0a 100644 --- a/statute/auth.go +++ b/statute/auth.go @@ -5,6 +5,12 @@ import ( "io" ) +// auth error defined +var ( + ErrUserAuthFailed = fmt.Errorf("user authentication failed") + ErrNoSupportedAuth = fmt.Errorf("no supported authentication mechanism") +) + // UserPassRequest is the negotiation user's password request packet // The SOCKS handshake user's password request is formed as follows: // +--------------+------+----------+------+----------+ diff --git a/statute/errors.go b/statute/errors.go index 807a6d5..4ba1136 100644 --- a/statute/errors.go +++ b/statute/errors.go @@ -7,4 +7,5 @@ import ( var ( ErrUnrecognizedAddrType = errors.New("Unrecognized address type") ErrNotSupportVersion = errors.New("not support version") + ErrNotSupportMethod = errors.New("not support method") ) diff --git a/statute/header.go b/statute/header.go index efbef58..d93bfef 100644 --- a/statute/header.go +++ b/statute/header.go @@ -31,8 +31,6 @@ type Header struct { Reserved uint8 // Address in socks message Address AddrSpec - // private stuff set when Header parsed - AddrType uint8 } // ParseHeader to header from io.Reader @@ -54,8 +52,8 @@ func ParseHeader(r io.Reader) (hd Header, err error) { return hd, fmt.Errorf("failed to get header RSV and address type, %v", err) } hd.Reserved = tmp[0] - hd.AddrType = tmp[1] - switch hd.AddrType { + hd.Address.AddrType = tmp[1] + switch hd.Address.AddrType { case ATYPDomain: if _, err = io.ReadFull(r, tmp[:1]); err != nil { return hd, fmt.Errorf("failed to get header, %v", err) @@ -93,10 +91,10 @@ func (h Header) Bytes() (b []byte) { var addr []byte length := headerVERLen + headerCMDLen + headerRSVLen + headerATYPLen + headerPORTLen - if h.AddrType == ATYPIPv4 { + if h.Address.AddrType == ATYPIPv4 { length += net.IPv4len addr = h.Address.IP.To4() - } else if h.AddrType == ATYPIPv6 { + } else if h.Address.AddrType == ATYPIPv6 { length += net.IPv6len addr = h.Address.IP.To16() } else { //ATYPDomain @@ -108,8 +106,8 @@ func (h Header) Bytes() (b []byte) { b = append(b, h.Version) b = append(b, h.Command) b = append(b, h.Reserved) - b = append(b, h.AddrType) - if h.AddrType == ATYPDomain { + b = append(b, h.Address.AddrType) + if h.Address.AddrType == ATYPDomain { b = append(b, byte(len(h.Address.FQDN))) } b = append(b, addr...) diff --git a/statute/header_test.go b/statute/header_test.go index 8251e49..13a07c7 100644 --- a/statute/header_test.go +++ b/statute/header_test.go @@ -23,8 +23,7 @@ func TestParseHeader(t *testing.T) { args{bytes.NewReader([]byte{VersionSocks5, CommandConnect, 0, ATYPIPv4, 127, 0, 0, 1, 0x1f, 0x90})}, Header{ VersionSocks5, CommandConnect, 0, - AddrSpec{IP: net.IPv4(127, 0, 0, 1), Port: 8080}, - ATYPIPv4, + AddrSpec{IP: net.IPv4(127, 0, 0, 1), Port: 8080, AddrType: ATYPIPv4}, }, false, }, @@ -33,8 +32,7 @@ func TestParseHeader(t *testing.T) { args{bytes.NewReader([]byte{VersionSocks5, CommandConnect, 0, ATYPIPv6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x1f, 0x90})}, Header{ VersionSocks5, CommandConnect, 0, - AddrSpec{IP: net.IPv6zero, Port: 8080}, - ATYPIPv6, + AddrSpec{IP: net.IPv6zero, Port: 8080, AddrType: ATYPIPv6}, }, false, }, @@ -43,8 +41,7 @@ func TestParseHeader(t *testing.T) { args{bytes.NewReader([]byte{VersionSocks5, CommandConnect, 0, ATYPDomain, 9, 'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't', 0x1f, 0x90})}, Header{ VersionSocks5, CommandConnect, 0, - AddrSpec{FQDN: "localhost", Port: 8080}, - ATYPDomain, + AddrSpec{FQDN: "localhost", Port: 8080, AddrType: ATYPDomain}, }, false, }, @@ -74,8 +71,7 @@ func TestHeader_Bytes(t *testing.T) { "SOCKS5 IPV4", Header{ VersionSocks5, CommandConnect, 0, - AddrSpec{IP: net.IPv4(127, 0, 0, 1), Port: 8080}, - ATYPIPv4, + AddrSpec{IP: net.IPv4(127, 0, 0, 1), Port: 8080, AddrType: ATYPIPv4}, }, []byte{VersionSocks5, CommandConnect, 0, ATYPIPv4, 127, 0, 0, 1, 0x1f, 0x90}, }, @@ -83,8 +79,7 @@ func TestHeader_Bytes(t *testing.T) { "SOCKS5 IPV6", Header{ VersionSocks5, CommandConnect, 0, - AddrSpec{IP: net.IPv6zero, Port: 8080}, - ATYPIPv6, + AddrSpec{IP: net.IPv6zero, Port: 8080, AddrType: ATYPIPv6}, }, []byte{VersionSocks5, CommandConnect, 0, ATYPIPv6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x1f, 0x90}, }, @@ -92,8 +87,7 @@ func TestHeader_Bytes(t *testing.T) { "SOCKS5 FQDN", Header{ VersionSocks5, CommandConnect, 0, - AddrSpec{FQDN: "localhost", Port: 8080}, - ATYPDomain, + AddrSpec{FQDN: "localhost", Port: 8080, AddrType: ATYPDomain}, }, []byte{VersionSocks5, CommandConnect, 0, ATYPDomain, 9, 'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't', 0x1f, 0x90}, }, diff --git a/statute/statute.go b/statute/statute.go index a51ab5d..d805fbe 100644 --- a/statute/statute.go +++ b/statute/statute.go @@ -1,9 +1,5 @@ package statute -import ( - "fmt" -) - // auth defined const ( MethodNoAuth = byte(0x00) @@ -44,9 +40,3 @@ const ( RepAddrTypeNotSupported // 0x09 - 0xff unassigned ) - -// auth error defined -var ( - ErrUserAuthFailed = fmt.Errorf("user authentication failed") - ErrNoSupportedAuth = fmt.Errorf("no supported authentication mechanism") -) diff --git a/statute/util.go b/statute/util.go index 2c960e5..83108d8 100644 --- a/statute/util.go +++ b/statute/util.go @@ -12,6 +12,8 @@ type AddrSpec struct { FQDN string IP net.IP Port int + // private stuff set when Header parsed + AddrType uint8 } // Address returns a string suitable to dial; prefer returning IP-based @@ -44,10 +46,13 @@ func ParseAddrSpec(address string) (a AddrSpec, err error) { } ip := net.ParseIP(host) if ip4 := ip.To4(); ip4 != nil { + a.AddrType = ATYPIPv4 a.IP = ip } else if ip6 := ip.To16(); ip6 != nil { + a.AddrType = ATYPIPv6 a.IP = ip } else { + a.AddrType = ATYPDomain a.FQDN = host } a.Port, err = strconv.Atoi(port) diff --git a/statute/util_test.go b/statute/util_test.go index 39863b6..74199dd 100644 --- a/statute/util_test.go +++ b/statute/util_test.go @@ -45,8 +45,9 @@ func TestParseAddrSpec1(t *testing.T) { "IPv4", args{"127.0.0.1:8080"}, AddrSpec{ - IP: net.IPv4(127, 0, 0, 1), - Port: 8080, + IP: net.IPv4(127, 0, 0, 1), + Port: 8080, + AddrType: ATYPIPv4, }, false, }, @@ -54,8 +55,9 @@ func TestParseAddrSpec1(t *testing.T) { "IPv6", args{"[::1]:8080"}, AddrSpec{ - IP: net.IPv6loopback, - Port: 8080, + IP: net.IPv6loopback, + Port: 8080, + AddrType: ATYPIPv6, }, false, }, @@ -63,8 +65,9 @@ func TestParseAddrSpec1(t *testing.T) { "FQDN", args{"localhost:8080"}, AddrSpec{ - FQDN: "localhost", - Port: 8080, + FQDN: "localhost", + Port: 8080, + AddrType: ATYPDomain, }, false, },