fix client

fix server test
This commit is contained in:
mo 2020-08-05 17:32:37 +08:00
parent 274a120a58
commit 75dd487cbf
15 changed files with 424 additions and 428 deletions

@ -39,7 +39,7 @@ Below is a simple example of usage
```go
// Create a SOCKS5 server
server := socks5.New()
server := socks5.NewServer()
// Create SOCKS5 proxy on localhost port 8000
if err := server.ListenAndServe("tcp", "127.0.0.1:8000"); err != nil {

221
client.go

@ -1,221 +0,0 @@
package socks5
import (
"errors"
"net"
"time"
"github.com/thinkgos/go-socks5/statute"
)
// Client is socks5 client wrapper
type Client struct {
Server string
Version byte
UserName string
Password string
// On command UDP, let server control the tcp and udp connection relationship
TCPConn *net.TCPConn
UDPConn *net.UDPConn
RemoteAddress net.Addr
TCPDeadline int
TCPTimeout int
UDPDeadline int
}
// This is just create a client, you need to use Dial to create conn
func NewClient(addr, username, password string, tcpTimeout, tcpDeadline, udpDeadline int) (*Client, error) {
c := &Client{
Server: addr,
Version: statute.VersionSocks5,
UserName: username,
Password: password,
TCPTimeout: tcpTimeout,
TCPDeadline: tcpDeadline,
UDPDeadline: udpDeadline,
}
return c, nil
}
func (c *Client) Close() error {
if c.UDPConn == nil {
return c.TCPConn.Close()
}
if c.TCPConn != nil {
c.TCPConn.Close()
}
return c.UDPConn.Close()
}
func (c *Client) LocalAddr() net.Addr {
if c.UDPConn == nil {
return c.TCPConn.LocalAddr()
}
return c.UDPConn.LocalAddr()
}
func (c *Client) RemoteAddr() net.Addr {
return c.RemoteAddress
}
func (c *Client) SetDeadline(t time.Time) error {
if c.UDPConn == nil {
return c.TCPConn.SetDeadline(t)
}
return c.UDPConn.SetDeadline(t)
}
func (c *Client) SetReadDeadline(t time.Time) error {
if c.UDPConn == nil {
return c.TCPConn.SetReadDeadline(t)
}
return c.UDPConn.SetReadDeadline(t)
}
func (c *Client) SetWriteDeadline(t time.Time) error {
if c.UDPConn == nil {
return c.TCPConn.SetWriteDeadline(t)
}
return c.UDPConn.SetWriteDeadline(t)
}
func (c *Client) Read(b []byte) (int, error) {
if c.UDPConn == nil {
return c.TCPConn.Read(b)
}
// TODO: UDP data
// b1 := make([]byte, 65535)
// n, err := c.UDPConn.Read(b1)
// if err != nil {
// return 0, err
// }
// d, err := NewDatagramFromBytes(b1[0:n])
// if err != nil {
// return 0, err
// }
// if len(b) < len(d.Data) {
// return 0, errors.New("b too small")
// }
// n = copy(b, d.Data)
return 0, nil
}
func (c *Client) Write(b []byte) (int, error) {
if c.UDPConn == nil {
return c.TCPConn.Write(b)
}
// TODO: UPD data
// addr, err := ParseAddrSpec(c.RemoteAddress.String())
// if err != nil {
// return 0, err
// }
// if a == ATYPDomain {
// h = h[1:]
// }
// d := NewDatagram(a, h, p, b)
// b1 := d.Bytes()
return c.UDPConn.Write(b)
}
func (c *Client) Dial(network, addr string) (net.Conn, error) {
// var err error
//
// conn := *c
// if network == "tcp" {
// conn.RemoteAddress, err = net.ResolveTCPAddr("tcp", addr)
// if err != nil {
// return nil, err
// }
// if err := conn.Negotiate(); err != nil {
// return nil, err
// }
// a, h, p, err := ParseAddress(addr)
// if err != nil {
// return nil, err
// }
// if a == ATYPDomain {
// h = h[1:]
// }
// if _, err := conn.Request(NewRequest(CommandConnect, a, h, p)); err != nil {
// return nil, err
// }
// return conn, nil
// }
//
// if network == "udp" {
// conn.RemoteAddress, err = net.ResolveUDPAddr("udp", addr)
// if err != nil {
// return nil, err
// }
// if err := conn.Negotiate(); err != nil {
// return nil, err
// }
//
// laddr := &net.UDPAddr{
// IP: conn.TCPConn.LocalAddr().(*net.TCPAddr).IP,
// Port: conn.TCPConn.LocalAddr().(*net.TCPAddr).Port,
// Zone: conn.TCPConn.LocalAddr().(*net.TCPAddr).Zone,
// }
// a, h, p, err := ParseAddress(laddr.String())
// if err != nil {
// return nil, err
// }
// rp, err := conn.Request(NewRequest(CmdUDP, a, h, p))
// if err != nil {
// return nil, err
// }
// raddr, err := net.ResolveUDPAddr("udp", rp.Address())
// if err != nil {
// return nil, err
// }
// conn.UDPConn, err = net.DialUDP("udp", laddr, raddr)
// if err != nil {
// return nil, err
// }
// return conn, nil
// }
// return nil, errors.New("unsupport network")
return nil, errors.New("aaa")
}
func (c *Client) handshake() error {
methods := statute.MethodNoAuth
if c.UserName != "" && c.Password != "" {
methods = statute.MethodUserPassAuth
}
_, err := c.TCPConn.Write(statute.NewMethodRequest(c.Version, []byte{methods}).Bytes())
if err != nil {
return err
}
reply, err := statute.ParseMethodReply(c.TCPConn)
if err != nil {
return err
}
if reply.Ver != c.Version {
return errors.New("handshake failed cause version not same")
}
if reply.Method != methods {
return errors.New("unsupport method")
}
if methods == statute.MethodUserPassAuth {
_, err = c.TCPConn.Write(statute.NewUserPassRequest(statute.UserPassAuthVersion, []byte(c.UserName), []byte(c.Password)).Bytes())
if err != nil {
return err
}
rsp, err := statute.ParseUserPassReply(c.TCPConn)
if err != nil {
return err
}
if rsp.Ver != statute.UserPassAuthVersion {
return errors.New("handshake failed cause version not same")
}
if rsp.Status != statute.RepSuccess {
return statute.ErrUserAuthFailed
}
}
return nil
}

262
client/client.go Normal file

@ -0,0 +1,262 @@
package client
import (
"errors"
"net"
"time"
"golang.org/x/net/proxy"
"github.com/thinkgos/go-socks5/statute"
)
// Client is socks5 client wrapper
type Client struct {
Server string
Auth *proxy.Auth
// On command UDP, let server control the tcp and udp connection relationship
TCPConn *net.TCPConn
UDPConn *net.UDPConn
RemoteAddress net.Addr
TCPDeadline time.Duration
TCPTimeout time.Duration
UDPDeadline time.Duration
}
// This is just create a client, you need to use Dial to create conn
func NewClient(addr string, opts ...Option) (*Client, error) {
c := &Client{
Server: addr,
TCPTimeout: time.Second,
TCPDeadline: time.Second,
UDPDeadline: time.Second,
}
for _, opt := range opts {
opt(c)
}
return c, nil
}
func (c *Client) Close() error {
if c.UDPConn == nil {
return c.TCPConn.Close()
}
if c.TCPConn != nil {
c.TCPConn.Close()
}
return c.UDPConn.Close()
}
func (c *Client) LocalAddr() net.Addr {
if c.UDPConn == nil {
return c.TCPConn.LocalAddr()
}
return c.UDPConn.LocalAddr()
}
func (c *Client) RemoteAddr() net.Addr {
return c.RemoteAddress
}
func (c *Client) SetDeadline(t time.Time) error {
if c.UDPConn == nil {
return c.TCPConn.SetDeadline(t)
}
return c.UDPConn.SetDeadline(t)
}
func (c *Client) SetReadDeadline(t time.Time) error {
if c.UDPConn == nil {
return c.TCPConn.SetReadDeadline(t)
}
return c.UDPConn.SetReadDeadline(t)
}
func (c *Client) SetWriteDeadline(t time.Time) error {
if c.UDPConn == nil {
return c.TCPConn.SetWriteDeadline(t)
}
return c.UDPConn.SetWriteDeadline(t)
}
func (c *Client) Read(b []byte) (int, error) {
if c.UDPConn == nil {
return c.TCPConn.Read(b)
}
b1 := make([]byte, 65535)
n, err := c.UDPConn.Read(b1)
if err != nil {
return 0, err
}
pkt := statute.Packet{}
err = pkt.Parse(b1[:n])
if err != nil {
return 0, err
}
n = copy(b, pkt.Data)
return n, nil
}
func (c *Client) Write(b []byte) (int, error) {
if c.UDPConn == nil {
return c.TCPConn.Write(b)
}
pkt, err := statute.NewPacket(c.RemoteAddress.String(), b)
if err != nil {
return 0, err
}
return c.UDPConn.Write(pkt.Bytes())
}
func (c *Client) Dial(network, addr string) (net.Conn, error) {
var err error
conn := *c // clone a client
if network == "tcp" {
conn.RemoteAddress, err = net.ResolveTCPAddr("tcp", addr)
if err != nil {
return nil, err
}
conn.TCPConn, err = conn.dialServer()
if err != nil {
return nil, err
}
if err := conn.handshake(); err != nil {
return nil, err
}
a, err := statute.ParseAddrSpec(addr)
if err != nil {
return nil, err
}
head := statute.Header{
Version: statute.VersionSocks5,
Command: statute.CommandConnect,
Address: a,
}
if _, err := conn.Write(head.Bytes()); err != nil {
return nil, err
}
rspHead, err := statute.ParseHeader(conn.TCPConn)
if err != nil {
return nil, err
}
if rspHead.Command != statute.RepSuccess {
return nil, errors.New("host unreachable")
}
return &conn, nil
}
if network == "udp" {
conn.RemoteAddress, err = net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
conn.TCPConn, err = conn.dialServer()
if err != nil {
return nil, err
}
if err := conn.handshake(); err != nil {
return nil, err
}
laddr := &net.UDPAddr{
IP: conn.TCPConn.LocalAddr().(*net.TCPAddr).IP,
Port: conn.TCPConn.LocalAddr().(*net.TCPAddr).Port,
Zone: conn.TCPConn.LocalAddr().(*net.TCPAddr).Zone,
}
a, err := statute.ParseAddrSpec(laddr.String())
if err != nil {
return nil, err
}
head := statute.Header{
Version: statute.VersionSocks5,
Command: statute.CommandConnect,
Address: a,
}
if _, err := conn.Write(head.Bytes()); err != nil {
return nil, err
}
rspHead, err := statute.ParseHeader(conn.TCPConn)
if err != nil {
return nil, err
}
if rspHead.Command != statute.RepSuccess {
return nil, errors.New("host unreachable")
}
raddr, err := net.ResolveUDPAddr("udp", rspHead.Address.String())
if err != nil {
return nil, err
}
conn.UDPConn, err = net.DialUDP("udp", laddr, raddr)
if err != nil {
return nil, err
}
return &conn, nil
}
return nil, errors.New("not support network")
}
func (c *Client) handshake() error {
methods := statute.MethodNoAuth
if c.Auth != nil {
methods = statute.MethodUserPassAuth
}
_, err := c.TCPConn.Write(statute.NewMethodRequest(statute.VersionSocks5, []byte{methods}).Bytes())
if err != nil {
return err
}
reply, err := statute.ParseMethodReply(c.TCPConn)
if err != nil {
return err
}
if reply.Ver != statute.VersionSocks5 {
return statute.ErrNotSupportVersion
}
if reply.Method != methods {
return statute.ErrNotSupportMethod
}
if methods == statute.MethodUserPassAuth {
_, err = c.TCPConn.Write(statute.NewUserPassRequest(statute.UserPassAuthVersion, []byte(c.Auth.User), []byte(c.Auth.Password)).Bytes())
if err != nil {
return err
}
rsp, err := statute.ParseUserPassReply(c.TCPConn)
if err != nil {
return err
}
if rsp.Ver != statute.UserPassAuthVersion {
return statute.ErrNotSupportMethod
}
if rsp.Status != statute.RepSuccess {
return statute.ErrUserAuthFailed
}
}
return nil
}
func (c *Client) dialServer() (*net.TCPConn, error) {
conn, err := net.Dial("tcp", c.Server)
if err != nil {
return nil, err
}
TCPConn := conn.(*net.TCPConn)
if c.TCPTimeout != 0 {
if err := TCPConn.SetKeepAlivePeriod(c.TCPTimeout); err != nil {
return nil, err
}
}
if c.TCPDeadline != 0 {
if err := TCPConn.SetDeadline(time.Now().Add(c.TCPTimeout)); err != nil {
return nil, err
}
}
return TCPConn, nil
}

33
client/option.go Normal file

@ -0,0 +1,33 @@
package client
import (
"time"
"golang.org/x/net/proxy"
)
type Option func(c *Client)
func WithAuth(auth *proxy.Auth) Option {
return func(c *Client) {
c.Auth = auth
}
}
func WithTCPTimeout(t time.Duration) Option {
return func(c *Client) {
c.TCPTimeout = t
}
}
func WithTCPDeadline(t time.Duration) Option {
return func(c *Client) {
c.TCPDeadline = t
}
}
func WithUDPDeadline(t time.Duration) Option {
return func(c *Client) {
c.UDPDeadline = t
}
}

@ -305,7 +305,7 @@ func SendReply(w io.Writer, head statute.Header, resp uint8, bindAddr ...net.Add
head.Command = resp
if len(bindAddr) == 0 {
head.AddrType = statute.ATYPIPv4
head.Address.AddrType = statute.ATYPIPv4
head.Address.IP = []byte{0, 0, 0, 0}
head.Address.Port = 0
} else {
@ -322,15 +322,15 @@ func SendReply(w io.Writer, head statute.Header, resp uint8, bindAddr ...net.Add
}
switch {
case addrSpec.FQDN != "":
head.AddrType = statute.ATYPDomain
head.Address.AddrType = statute.ATYPDomain
head.Address.FQDN = addrSpec.FQDN
head.Address.Port = addrSpec.Port
case addrSpec.IP.To4() != nil:
head.AddrType = statute.ATYPIPv4
head.Address.AddrType = statute.ATYPIPv4
head.Address.IP = addrSpec.IP.To4()
head.Address.Port = addrSpec.Port
case addrSpec.IP.To16() != nil:
head.AddrType = statute.ATYPIPv6
head.Address.AddrType = statute.ATYPIPv6
head.Address.IP = addrSpec.IP.To16()
head.Address.Port = addrSpec.Port
default:

@ -7,8 +7,9 @@ import (
"log"
"net"
"os"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
type MockConn struct {
@ -26,30 +27,24 @@ func (m *MockConn) RemoteAddr() net.Addr {
func TestRequest_Connect(t *testing.T) {
// Create a local listener
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
}
require.NoError(t, err)
go func() {
conn, err := l.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
require.NoError(t, err)
defer conn.Close()
buf := make([]byte, 4)
if _, err := io.ReadAtLeast(conn, buf, 4); err != nil {
t.Fatalf("err: %v", err)
}
_, err = io.ReadAtLeast(conn, buf, 4)
require.NoError(t, err)
require.Equal(t, []byte("ping"), buf)
if !bytes.Equal(buf, []byte("ping")) {
t.Fatalf("bad: %v", buf)
}
_, _ = conn.Write([]byte("pong"))
conn.Write([]byte("pong")) // nolint: errcheck
}()
lAddr := l.Addr().(*net.TCPAddr)
// Make server
s := &Server{
proxySrv := &Server{
rules: NewPermitAll(),
resolver: DNSResolver{},
logger: NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags)),
@ -58,25 +53,22 @@ func TestRequest_Connect(t *testing.T) {
// Create the connect request
buf := bytes.NewBuffer(nil)
buf.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1})
buf.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1}) // nolint: errcheck
port := []byte{0, 0}
binary.BigEndian.PutUint16(port, uint16(lAddr.Port))
buf.Write(port)
buf.Write(port) // nolint: errcheck
// Send a ping
buf.Write([]byte("ping"))
buf.Write([]byte("ping")) // nolint: errcheck
// Handle the request
resp := &MockConn{}
req, err := NewRequest(buf)
if err != nil {
t.Fatalf("err: %v", err)
}
require.NoError(t, err)
if err := s.handleRequest(resp, req); err != nil {
t.Fatalf("err: %v", err)
}
err = proxySrv.handleRequest(resp, req)
require.NoError(t, err)
// Verify response
out := resp.buf.Bytes()
@ -93,34 +85,24 @@ func TestRequest_Connect(t *testing.T) {
// Ignore the port for both
out[8] = 0
out[9] = 0
if !bytes.Equal(out, expected) {
t.Fatalf("bad: %v %v", out, expected)
}
require.Equal(t, expected, out)
}
func TestRequest_Connect_RuleFail(t *testing.T) {
// Create a local listener
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
}
require.NoError(t, err)
go func() {
conn, err := l.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
require.NoError(t, err)
defer conn.Close()
buf := make([]byte, 4)
if _, err := io.ReadAtLeast(conn, buf, 4); err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(buf, []byte("ping")) {
t.Fatalf("bad: %v", buf)
}
_, _ = conn.Write([]byte("pong"))
_, err = io.ReadAtLeast(conn, buf, 4)
require.NoError(t, err)
require.Equal(t, []byte("ping"), buf)
conn.Write([]byte("pong")) // nolint: errcheck
}()
lAddr := l.Addr().(*net.TCPAddr)
@ -146,13 +128,10 @@ func TestRequest_Connect_RuleFail(t *testing.T) {
// Handle the request
resp := &MockConn{}
req, err := NewRequest(buf)
if err != nil {
t.Fatalf("err: %v", err)
}
require.NoError(t, err)
if err := s.handleRequest(resp, req); !strings.Contains(err.Error(), "blocked by rules") {
t.Fatalf("err: %v", err)
}
err = s.handleRequest(resp, req)
require.Contains(t, err.Error(), "blocked by rules")
// Verify response
out := resp.buf.Bytes()
@ -164,8 +143,5 @@ func TestRequest_Connect_RuleFail(t *testing.T) {
0, 0, 0, 0,
0, 0,
}
if !bytes.Equal(out, expected) {
t.Fatalf("bad: %v %v", out, expected)
}
require.Equal(t, expected, out)
}

@ -56,8 +56,8 @@ type Server struct {
userAssociateHandle func(ctx context.Context, writer io.Writer, request *Request) error
}
// New creates a new Server and potentially returns an error
func New(opts ...Option) *Server {
// NewServer creates a new Server and potentially returns an error
func NewServer(opts ...Option) *Server {
server := &Server{
authMethods: make(map[uint8]Authenticator),
authCustomMethods: []Authenticator{&NoAuthAuthenticator{}},

@ -20,50 +20,38 @@ import (
func TestSOCKS5_Connect(t *testing.T) {
// Create a local listener
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
}
require.NoError(t, err)
go func() {
conn, err := l.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
require.NoError(t, err)
defer conn.Close()
buf := make([]byte, 4)
if _, err := io.ReadAtLeast(conn, buf, 4); err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(buf, []byte("ping")) {
t.Fatalf("bad: %v", buf)
}
_, err = io.ReadAtLeast(conn, buf, 4)
require.NoError(t, err)
assert.Equal(t, []byte("ping"), buf)
_, _ = conn.Write([]byte("pong"))
}()
lAddr := l.Addr().(*net.TCPAddr)
// Create a socks server
cator := UserPassAuthenticator{
Credentials: StaticCredentials{"foo": "bar"},
}
serv := New(
cator := UserPassAuthenticator{StaticCredentials{"foo": "bar"}}
srv := NewServer(
WithAuthMethods([]Authenticator{cator}),
WithLogger(NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags))),
)
// Start listening
go func() {
if err := serv.ListenAndServe("tcp", "127.0.0.1:12365"); err != nil {
t.Fatalf("err: %v", err)
}
err := srv.ListenAndServe("tcp", "127.0.0.1:12365")
require.NoError(t, err)
}()
time.Sleep(10 * time.Millisecond)
// Get a local conn
conn, err := net.Dial("tcp", "127.0.0.1:12365")
if err != nil {
t.Fatalf("err: %v", err)
}
require.NoError(t, err)
// Connect, auth and connec to local
req := new(bytes.Buffer)
@ -77,8 +65,8 @@ func TestSOCKS5_Connect(t *testing.T) {
"",
net.ParseIP("127.0.0.1"),
lAddr.Port,
statute.ATYPIPv4,
},
AddrType: statute.ATYPIPv4,
}
req.Write(reqHead.Bytes())
// Send a ping
@ -100,27 +88,23 @@ func TestSOCKS5_Connect(t *testing.T) {
"",
net.ParseIP("127.0.0.1"),
0, // Ignore the port
statute.ATYPIPv4,
},
AddrType: statute.ATYPIPv4,
}
expected = append(expected, rspHead.Bytes()...)
expected = append(expected, []byte("pong")...)
out := make([]byte, len(expected))
_ = conn.SetDeadline(time.Now().Add(time.Second))
if _, err := io.ReadFull(conn, out); err != nil {
t.Fatalf("err: %v", err)
}
_, err = io.ReadFull(conn, out)
require.NoError(t, err)
t.Logf("proxy bind port: %d", statute.BuildPort(out[12], out[13]))
// Ignore the port
out[12] = 0
out[13] = 0
if !bytes.Equal(out, expected) {
t.Fatalf("bad: %v", out)
}
assert.Equal(t, expected, out)
}
func TestSOCKS5_Associate(t *testing.T) {
@ -131,10 +115,9 @@ func TestSOCKS5_Associate(t *testing.T) {
Port: 12398,
}
l, err := net.ListenUDP("udp", lAddr)
if err != nil {
t.Fatalf("err: %v", err)
}
require.NoError(t, err)
defer l.Close()
go func() {
buf := make([]byte, 2048)
for {
@ -142,33 +125,28 @@ func TestSOCKS5_Associate(t *testing.T) {
if err != nil {
return
}
require.Equal(t, []byte("ping"), buf[:n])
if !bytes.Equal(buf[:n], []byte("ping")) {
t.Fatalf("bad: %v", buf)
}
_, _ = l.WriteTo([]byte("pong"), remote)
l.WriteTo([]byte("pong"), remote) // nolint: errcheck
}
}()
// Create a socks server
cator := UserPassAuthenticator{Credentials: StaticCredentials{"foo": "bar"}}
serv := New(
cator := UserPassAuthenticator{StaticCredentials{"foo": "bar"}}
srv := NewServer(
WithAuthMethods([]Authenticator{cator}),
WithLogger(NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags))),
)
// Start listening
go func() {
if err := serv.ListenAndServe("tcp", "127.0.0.1:12355"); err != nil {
t.Fatalf("err: %v", err)
}
err := srv.ListenAndServe("tcp", "127.0.0.1:12355")
require.NoError(t, err)
}()
time.Sleep(10 * time.Millisecond)
// Get a local conn
conn, err := net.Dial("tcp", "127.0.0.1:12355")
if err != nil {
t.Fatalf("err: %v", err)
}
require.NoError(t, err)
// Connect, auth and connec to local
req := new(bytes.Buffer)
@ -182,12 +160,12 @@ func TestSOCKS5_Associate(t *testing.T) {
"",
locIP,
lAddr.Port,
statute.ATYPIPv4,
},
AddrType: statute.ATYPIPv4,
}
req.Write(reqHead.Bytes())
// Send all the bytes
conn.Write(req.Bytes())
conn.Write(req.Bytes()) // nolint: errcheck
// Verify response
expected := []byte{
@ -197,21 +175,14 @@ func TestSOCKS5_Associate(t *testing.T) {
out := make([]byte, len(expected))
_ = conn.SetDeadline(time.Now().Add(time.Second))
if _, err := io.ReadFull(conn, out); err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(out, expected) {
t.Fatalf("bad: %v", out)
}
_, err = io.ReadFull(conn, out)
require.NoError(t, err)
require.Equal(t, expected, out)
rspHead, err := statute.ParseHeader(conn)
if err != nil {
t.Fatalf("bad response header: %v", err)
}
if rspHead.Version != statute.VersionSocks5 && rspHead.Command != statute.RepSuccess {
t.Fatalf("parse success but bad header: %v", rspHead)
}
require.NoError(t, err)
require.Equal(t, statute.VersionSocks5, rspHead.Version)
require.Equal(t, statute.RepSuccess, rspHead.Command)
t.Logf("proxy bind listen port: %d", rspHead.Address.Port)
@ -219,11 +190,9 @@ func TestSOCKS5_Associate(t *testing.T) {
IP: locIP,
Port: rspHead.Address.Port,
})
if err != nil {
t.Fatalf("bad dial: %v", err)
}
require.NoError(t, err)
// Send a ping
_, _ = udpConn.Write(append([]byte{0, 0, 0, statute.ATYPIPv4, 0, 0, 0, 0, 0, 0}, []byte("ping")...))
udpConn.Write(append([]byte{0, 0, 0, statute.ATYPIPv4, 0, 0, 0, 0, 0, 0}, []byte("ping")...)) // nolint: errcheck
response := make([]byte, 1024)
n, _, err := udpConn.ReadFrom(response)
if err != nil || !bytes.Equal(response[n-4:n], []byte("pong")) {
@ -235,66 +204,52 @@ func TestSOCKS5_Associate(t *testing.T) {
func Test_SocksWithProxy(t *testing.T) {
// Create a local listener
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
}
require.NoError(t, err)
go func() {
conn, err := l.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
require.NoError(t, err)
defer conn.Close()
buf := make([]byte, 4)
if _, err := io.ReadAtLeast(conn, buf, 4); err != nil {
t.Fatalf("err: %v", err)
}
_, err = io.ReadAtLeast(conn, buf, 4)
require.NoError(t, err)
require.Equal(t, []byte("ping"), buf)
if !bytes.Equal(buf, []byte("ping")) {
t.Fatalf("bad: %v", buf)
}
conn.Write([]byte("pong"))
}()
lAddr := l.Addr().(*net.TCPAddr)
// Create a socks server
cator := UserPassAuthenticator{Credentials: StaticCredentials{"foo": "bar"}}
serv := New(
cator := UserPassAuthenticator{StaticCredentials{"foo": "bar"}}
serv := NewServer(
WithAuthMethods([]Authenticator{cator}),
WithLogger(NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags))),
)
// Start listening
// Start socks server
go func() {
if err := serv.ListenAndServe("tcp", "127.0.0.1:12395"); err != nil {
t.Fatalf("err: %v", err)
}
err := serv.ListenAndServe("tcp", "127.0.0.1:12395")
require.NoError(t, err)
}()
time.Sleep(10 * time.Millisecond)
// client
dial, err := proxy.SOCKS5("tcp", "127.0.0.1:12395", &proxy.Auth{User: "foo", Password: "bar"}, proxy.Direct)
if err != nil {
t.Fatalf("err: %v", err)
}
require.NoError(t, err)
// Connect, auth and connect to local
conn, err := dial.Dial("tcp", lAddr.String())
if err != nil {
t.Fatalf("err: %v", err)
}
require.NoError(t, err)
// Send a ping
_, _ = conn.Write([]byte("ping"))
conn.Write([]byte("ping")) // nolint: errcheck
out := make([]byte, 4)
_ = conn.SetDeadline(time.Now().Add(time.Second))
if _, err := io.ReadFull(conn, out); err != nil {
t.Fatalf("err: %v", err)
}
_ = conn.SetDeadline(time.Now().Add(time.Second)) // nolint: errcheck
_, err = io.ReadFull(conn, out)
require.NoError(t, err)
if !bytes.Equal(out, []byte("pong")) {
t.Fatalf("bad: %v", out)
}
require.Equal(t, []byte("pong"), out)
}
/***************************** auth *******************************/
@ -302,7 +257,7 @@ func Test_SocksWithProxy(t *testing.T) {
func TestNoAuth_Server(t *testing.T) {
req := bytes.NewBuffer(nil)
rsp := new(bytes.Buffer)
s := New()
s := NewServer()
ctx, err := s.authenticate(rsp, req, "", []byte{statute.MethodNoAuth})
require.NoError(t, err)
@ -314,11 +269,9 @@ func TestPasswordAuth_Valid_Server(t *testing.T) {
req := bytes.NewBuffer([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'})
rsp := new(bytes.Buffer)
cator := UserPassAuthenticator{
StaticCredentials{
"foo": "bar",
},
StaticCredentials{"foo": "bar"},
}
s := New(WithAuthMethods([]Authenticator{cator}))
s := NewServer(WithAuthMethods([]Authenticator{cator}))
ctx, err := s.authenticate(rsp, req, "", []byte{statute.MethodUserPassAuth})
require.NoError(t, err)
@ -339,11 +292,9 @@ func TestPasswordAuth_Invalid_Server(t *testing.T) {
req := bytes.NewBuffer([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'z'})
rsp := new(bytes.Buffer)
cator := UserPassAuthenticator{
StaticCredentials{
"foo": "bar",
},
StaticCredentials{"foo": "bar"},
}
s := New(WithAuthMethods([]Authenticator{cator}))
s := NewServer(WithAuthMethods([]Authenticator{cator}))
ctx, err := s.authenticate(rsp, req, "", []byte{statute.MethodNoAuth, statute.MethodUserPassAuth})
require.True(t, errors.Is(err, statute.ErrUserAuthFailed))
@ -356,12 +307,10 @@ func TestNoSupportedAuth_Server(t *testing.T) {
req := bytes.NewBuffer(nil)
rsp := new(bytes.Buffer)
cator := UserPassAuthenticator{
StaticCredentials{
"foo": "bar",
},
StaticCredentials{"foo": "bar"},
}
s := New(WithAuthMethods([]Authenticator{cator}))
s := NewServer(WithAuthMethods([]Authenticator{cator}))
ctx, err := s.authenticate(rsp, req, "", []byte{statute.MethodNoAuth})
require.True(t, errors.Is(err, statute.ErrNoSupportedAuth))

@ -5,6 +5,12 @@ import (
"io"
)
// auth error defined
var (
ErrUserAuthFailed = fmt.Errorf("user authentication failed")
ErrNoSupportedAuth = fmt.Errorf("no supported authentication mechanism")
)
// UserPassRequest is the negotiation user's password request packet
// The SOCKS handshake user's password request is formed as follows:
// +--------------+------+----------+------+----------+

@ -7,4 +7,5 @@ import (
var (
ErrUnrecognizedAddrType = errors.New("Unrecognized address type")
ErrNotSupportVersion = errors.New("not support version")
ErrNotSupportMethod = errors.New("not support method")
)

@ -31,8 +31,6 @@ type Header struct {
Reserved uint8
// Address in socks message
Address AddrSpec
// private stuff set when Header parsed
AddrType uint8
}
// ParseHeader to header from io.Reader
@ -54,8 +52,8 @@ func ParseHeader(r io.Reader) (hd Header, err error) {
return hd, fmt.Errorf("failed to get header RSV and address type, %v", err)
}
hd.Reserved = tmp[0]
hd.AddrType = tmp[1]
switch hd.AddrType {
hd.Address.AddrType = tmp[1]
switch hd.Address.AddrType {
case ATYPDomain:
if _, err = io.ReadFull(r, tmp[:1]); err != nil {
return hd, fmt.Errorf("failed to get header, %v", err)
@ -93,10 +91,10 @@ func (h Header) Bytes() (b []byte) {
var addr []byte
length := headerVERLen + headerCMDLen + headerRSVLen + headerATYPLen + headerPORTLen
if h.AddrType == ATYPIPv4 {
if h.Address.AddrType == ATYPIPv4 {
length += net.IPv4len
addr = h.Address.IP.To4()
} else if h.AddrType == ATYPIPv6 {
} else if h.Address.AddrType == ATYPIPv6 {
length += net.IPv6len
addr = h.Address.IP.To16()
} else { //ATYPDomain
@ -108,8 +106,8 @@ func (h Header) Bytes() (b []byte) {
b = append(b, h.Version)
b = append(b, h.Command)
b = append(b, h.Reserved)
b = append(b, h.AddrType)
if h.AddrType == ATYPDomain {
b = append(b, h.Address.AddrType)
if h.Address.AddrType == ATYPDomain {
b = append(b, byte(len(h.Address.FQDN)))
}
b = append(b, addr...)

@ -23,8 +23,7 @@ func TestParseHeader(t *testing.T) {
args{bytes.NewReader([]byte{VersionSocks5, CommandConnect, 0, ATYPIPv4, 127, 0, 0, 1, 0x1f, 0x90})},
Header{
VersionSocks5, CommandConnect, 0,
AddrSpec{IP: net.IPv4(127, 0, 0, 1), Port: 8080},
ATYPIPv4,
AddrSpec{IP: net.IPv4(127, 0, 0, 1), Port: 8080, AddrType: ATYPIPv4},
},
false,
},
@ -33,8 +32,7 @@ func TestParseHeader(t *testing.T) {
args{bytes.NewReader([]byte{VersionSocks5, CommandConnect, 0, ATYPIPv6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x1f, 0x90})},
Header{
VersionSocks5, CommandConnect, 0,
AddrSpec{IP: net.IPv6zero, Port: 8080},
ATYPIPv6,
AddrSpec{IP: net.IPv6zero, Port: 8080, AddrType: ATYPIPv6},
},
false,
},
@ -43,8 +41,7 @@ func TestParseHeader(t *testing.T) {
args{bytes.NewReader([]byte{VersionSocks5, CommandConnect, 0, ATYPDomain, 9, 'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't', 0x1f, 0x90})},
Header{
VersionSocks5, CommandConnect, 0,
AddrSpec{FQDN: "localhost", Port: 8080},
ATYPDomain,
AddrSpec{FQDN: "localhost", Port: 8080, AddrType: ATYPDomain},
},
false,
},
@ -74,8 +71,7 @@ func TestHeader_Bytes(t *testing.T) {
"SOCKS5 IPV4",
Header{
VersionSocks5, CommandConnect, 0,
AddrSpec{IP: net.IPv4(127, 0, 0, 1), Port: 8080},
ATYPIPv4,
AddrSpec{IP: net.IPv4(127, 0, 0, 1), Port: 8080, AddrType: ATYPIPv4},
},
[]byte{VersionSocks5, CommandConnect, 0, ATYPIPv4, 127, 0, 0, 1, 0x1f, 0x90},
},
@ -83,8 +79,7 @@ func TestHeader_Bytes(t *testing.T) {
"SOCKS5 IPV6",
Header{
VersionSocks5, CommandConnect, 0,
AddrSpec{IP: net.IPv6zero, Port: 8080},
ATYPIPv6,
AddrSpec{IP: net.IPv6zero, Port: 8080, AddrType: ATYPIPv6},
},
[]byte{VersionSocks5, CommandConnect, 0, ATYPIPv6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x1f, 0x90},
},
@ -92,8 +87,7 @@ func TestHeader_Bytes(t *testing.T) {
"SOCKS5 FQDN",
Header{
VersionSocks5, CommandConnect, 0,
AddrSpec{FQDN: "localhost", Port: 8080},
ATYPDomain,
AddrSpec{FQDN: "localhost", Port: 8080, AddrType: ATYPDomain},
},
[]byte{VersionSocks5, CommandConnect, 0, ATYPDomain, 9, 'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't', 0x1f, 0x90},
},

@ -1,9 +1,5 @@
package statute
import (
"fmt"
)
// auth defined
const (
MethodNoAuth = byte(0x00)
@ -44,9 +40,3 @@ const (
RepAddrTypeNotSupported
// 0x09 - 0xff unassigned
)
// auth error defined
var (
ErrUserAuthFailed = fmt.Errorf("user authentication failed")
ErrNoSupportedAuth = fmt.Errorf("no supported authentication mechanism")
)

@ -12,6 +12,8 @@ type AddrSpec struct {
FQDN string
IP net.IP
Port int
// private stuff set when Header parsed
AddrType uint8
}
// Address returns a string suitable to dial; prefer returning IP-based
@ -44,10 +46,13 @@ func ParseAddrSpec(address string) (a AddrSpec, err error) {
}
ip := net.ParseIP(host)
if ip4 := ip.To4(); ip4 != nil {
a.AddrType = ATYPIPv4
a.IP = ip
} else if ip6 := ip.To16(); ip6 != nil {
a.AddrType = ATYPIPv6
a.IP = ip
} else {
a.AddrType = ATYPDomain
a.FQDN = host
}
a.Port, err = strconv.Atoi(port)

@ -45,8 +45,9 @@ func TestParseAddrSpec1(t *testing.T) {
"IPv4",
args{"127.0.0.1:8080"},
AddrSpec{
IP: net.IPv4(127, 0, 0, 1),
Port: 8080,
IP: net.IPv4(127, 0, 0, 1),
Port: 8080,
AddrType: ATYPIPv4,
},
false,
},
@ -54,8 +55,9 @@ func TestParseAddrSpec1(t *testing.T) {
"IPv6",
args{"[::1]:8080"},
AddrSpec{
IP: net.IPv6loopback,
Port: 8080,
IP: net.IPv6loopback,
Port: 8080,
AddrType: ATYPIPv6,
},
false,
},
@ -63,8 +65,9 @@ func TestParseAddrSpec1(t *testing.T) {
"FQDN",
args{"localhost:8080"},
AddrSpec{
FQDN: "localhost",
Port: 8080,
FQDN: "localhost",
Port: 8080,
AddrType: ATYPDomain,
},
false,
},