moved statute

and fix all file struct
This commit is contained in:
mo 2020-08-05 13:17:05 +08:00
parent ceaf26cca1
commit 6f88e8250a
17 changed files with 600 additions and 201 deletions

82
auth.go

@ -3,23 +3,8 @@ package socks5
import (
"fmt"
"io"
)
// auth defined
const (
MethodNoAuth = uint8(0)
MethodGSSAPI = uint8(1)
MethodUserPassAuth = uint8(2)
MethodNoAcceptable = uint8(255)
UserPassAuthVersion = uint8(1)
AuthSuccess = uint8(0)
AuthFailure = uint8(1)
)
// auth error defined
var (
ErrUserAuthFailed = fmt.Errorf("user authentication failed")
ErrNoSupportedAuth = fmt.Errorf("no supported authentication mechanism")
"github.com/thinkgos/go-socks5/statute"
)
// AuthContext A Request encapsulates authentication state provided
@ -44,13 +29,13 @@ type NoAuthAuthenticator struct{}
// GetCode implement interface Authenticator
func (a NoAuthAuthenticator) GetCode() uint8 {
return MethodNoAuth
return statute.MethodNoAuth
}
// Authenticate implement interface Authenticator
func (a NoAuthAuthenticator) Authenticate(_ io.Reader, writer io.Writer, _ string) (*AuthContext, error) {
_, err := writer.Write([]byte{VersionSocks5, MethodNoAuth})
return &AuthContext{MethodNoAuth, make(map[string]string)}, err
_, err := writer.Write([]byte{statute.VersionSocks5, statute.MethodNoAuth})
return &AuthContext{statute.MethodNoAuth, make(map[string]string)}, err
}
// UserPassAuthenticator is used to handle username/password based
@ -61,60 +46,41 @@ type UserPassAuthenticator struct {
// GetCode implement interface Authenticator
func (a UserPassAuthenticator) GetCode() uint8 {
return MethodUserPassAuth
return statute.MethodUserPassAuth
}
// Authenticate implement interface Authenticator
func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer, userAddr string) (*AuthContext, error) {
// Tell the client to use user/pass auth
if _, err := writer.Write([]byte{VersionSocks5, MethodUserPassAuth}); err != nil {
if _, err := writer.Write([]byte{statute.VersionSocks5, statute.MethodUserPassAuth}); err != nil {
return nil, err
}
// Get the version and username length
header := []byte{0, 0}
if _, err := io.ReadAtLeast(reader, header, 2); err != nil {
return nil, err
}
// Ensure we are compatible
if header[0] != UserPassAuthVersion {
return nil, fmt.Errorf("unsupported auth version: %v", header[0])
}
// Get the user name
userLen := int(header[1])
user := make([]byte, userLen)
if _, err := io.ReadAtLeast(reader, user, userLen); err != nil {
return nil, err
}
// Get the password length
if _, err := reader.Read(header[:1]); err != nil {
return nil, err
}
// Get the password
passLen := int(header[0])
pass := make([]byte, passLen)
if _, err := io.ReadAtLeast(reader, pass, passLen); err != nil {
nup, err := statute.ParseUserPassRequest(reader)
if err != nil {
return nil, err
}
// Verify the password
if a.Credentials.Valid(string(user), string(pass), userAddr) {
if _, err := writer.Write([]byte{UserPassAuthVersion, AuthSuccess}); err != nil {
if !a.Credentials.Valid(string(nup.User), string(nup.Pass), userAddr) {
if _, err := writer.Write([]byte{statute.UserPassAuthVersion, statute.AuthFailure}); err != nil {
return nil, err
}
} else {
if _, err := writer.Write([]byte{UserPassAuthVersion, AuthFailure}); err != nil {
return nil, err
}
return nil, ErrUserAuthFailed
return nil, statute.ErrUserAuthFailed
}
if _, err := writer.Write([]byte{statute.UserPassAuthVersion, statute.AuthSuccess}); err != nil {
return nil, err
}
// Done
return &AuthContext{MethodUserPassAuth, map[string]string{"username": string(user), "password": string(pass)}}, nil
return &AuthContext{
statute.MethodUserPassAuth,
map[string]string{
"username": string(nup.User),
"password": string(nup.Pass),
},
}, nil
}
// authenticate is used to handle connection authentication
@ -140,8 +106,8 @@ func (s *Server) authenticate(conn io.Writer, bufConn io.Reader, userAddr string
// noAcceptableAuth is used to handle when we have no eligible
// authentication mechanism
func noAcceptableAuth(conn io.Writer) error {
conn.Write([]byte{VersionSocks5, MethodNoAcceptable})
return ErrNoSupportedAuth
conn.Write([]byte{statute.VersionSocks5, statute.MethodNoAcceptable})
return statute.ErrNoSupportedAuth
}
// readMethods is used to read the number of methods

@ -3,11 +3,13 @@ package socks5
import (
"bytes"
"testing"
"github.com/thinkgos/go-socks5/statute"
)
func TestNoAuth(t *testing.T) {
req := bytes.NewBuffer(nil)
req.Write([]byte{1, MethodNoAuth})
req.Write([]byte{1, statute.MethodNoAuth})
var resp bytes.Buffer
s := New()
@ -16,19 +18,19 @@ func TestNoAuth(t *testing.T) {
t.Fatalf("err: %v", err)
}
if ctx.Method != MethodNoAuth {
if ctx.Method != statute.MethodNoAuth {
t.Fatal("Invalid Context Method")
}
out := resp.Bytes()
if !bytes.Equal(out, []byte{VersionSocks5, MethodNoAuth}) {
if !bytes.Equal(out, []byte{statute.VersionSocks5, statute.MethodNoAuth}) {
t.Fatalf("bad: %v", out)
}
}
func TestPasswordAuth_Valid(t *testing.T) {
req := bytes.NewBuffer(nil)
req.Write([]byte{2, MethodNoAuth, MethodUserPassAuth})
req.Write([]byte{2, statute.MethodNoAuth, statute.MethodUserPassAuth})
req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'})
var resp bytes.Buffer
@ -45,7 +47,7 @@ func TestPasswordAuth_Valid(t *testing.T) {
t.Fatalf("err: %v", err)
}
if ctx.Method != MethodUserPassAuth {
if ctx.Method != statute.MethodUserPassAuth {
t.Fatal("Invalid Context Method")
}
@ -68,14 +70,14 @@ func TestPasswordAuth_Valid(t *testing.T) {
}
out := resp.Bytes()
if !bytes.Equal(out, []byte{VersionSocks5, MethodUserPassAuth, 1, AuthSuccess}) {
if !bytes.Equal(out, []byte{statute.VersionSocks5, statute.MethodUserPassAuth, 1, statute.AuthSuccess}) {
t.Fatalf("bad: %v", out)
}
}
func TestPasswordAuth_Invalid(t *testing.T) {
req := bytes.NewBuffer(nil)
req.Write([]byte{2, MethodNoAuth, MethodUserPassAuth})
req.Write([]byte{2, statute.MethodNoAuth, statute.MethodUserPassAuth})
req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'z'})
var resp bytes.Buffer
@ -86,7 +88,7 @@ func TestPasswordAuth_Invalid(t *testing.T) {
s := New(WithAuthMethods([]Authenticator{cator}))
ctx, err := s.authenticate(&resp, req, "")
if err != ErrUserAuthFailed {
if err != statute.ErrUserAuthFailed {
t.Fatalf("err: %v", err)
}
@ -95,14 +97,14 @@ func TestPasswordAuth_Invalid(t *testing.T) {
}
out := resp.Bytes()
if !bytes.Equal(out, []byte{VersionSocks5, MethodUserPassAuth, 1, AuthFailure}) {
if !bytes.Equal(out, []byte{statute.VersionSocks5, statute.MethodUserPassAuth, 1, statute.AuthFailure}) {
t.Fatalf("bad: %v", out)
}
}
func TestNoSupportedAuth(t *testing.T) {
req := bytes.NewBuffer(nil)
req.Write([]byte{1, MethodNoAuth})
req.Write([]byte{1, statute.MethodNoAuth})
var resp bytes.Buffer
cred := StaticCredentials{
@ -113,7 +115,7 @@ func TestNoSupportedAuth(t *testing.T) {
s := New(WithAuthMethods([]Authenticator{cator}))
ctx, err := s.authenticate(&resp, req, "")
if err != ErrNoSupportedAuth {
if err != statute.ErrNoSupportedAuth {
t.Fatalf("err: %v", err)
}
@ -122,7 +124,7 @@ func TestNoSupportedAuth(t *testing.T) {
}
out := resp.Bytes()
if !bytes.Equal(out, []byte{VersionSocks5, MethodNoAcceptable}) {
if !bytes.Equal(out, []byte{statute.VersionSocks5, statute.MethodNoAcceptable}) {
t.Fatalf("bad: %v", out)
}
}

221
client.go Normal file

@ -0,0 +1,221 @@
package socks5
import (
"errors"
"net"
"time"
"github.com/thinkgos/go-socks5/statute"
)
// Client is socks5 client wrapper
type Client struct {
Server string
Version byte
UserName string
Password string
// On command UDP, let server control the tcp and udp connection relationship
TCPConn *net.TCPConn
UDPConn *net.UDPConn
RemoteAddress net.Addr
TCPDeadline int
TCPTimeout int
UDPDeadline int
}
// This is just create a client, you need to use Dial to create conn
func NewClient(addr, username, password string, tcpTimeout, tcpDeadline, udpDeadline int) (*Client, error) {
c := &Client{
Server: addr,
Version: statute.VersionSocks5,
UserName: username,
Password: password,
TCPTimeout: tcpTimeout,
TCPDeadline: tcpDeadline,
UDPDeadline: udpDeadline,
}
return c, nil
}
func (c *Client) Close() error {
if c.UDPConn == nil {
return c.TCPConn.Close()
}
if c.TCPConn != nil {
c.TCPConn.Close()
}
return c.UDPConn.Close()
}
func (c *Client) LocalAddr() net.Addr {
if c.UDPConn == nil {
return c.TCPConn.LocalAddr()
}
return c.UDPConn.LocalAddr()
}
func (c *Client) RemoteAddr() net.Addr {
return c.RemoteAddress
}
func (c *Client) SetDeadline(t time.Time) error {
if c.UDPConn == nil {
return c.TCPConn.SetDeadline(t)
}
return c.UDPConn.SetDeadline(t)
}
func (c *Client) SetReadDeadline(t time.Time) error {
if c.UDPConn == nil {
return c.TCPConn.SetReadDeadline(t)
}
return c.UDPConn.SetReadDeadline(t)
}
func (c *Client) SetWriteDeadline(t time.Time) error {
if c.UDPConn == nil {
return c.TCPConn.SetWriteDeadline(t)
}
return c.UDPConn.SetWriteDeadline(t)
}
func (c *Client) Read(b []byte) (int, error) {
if c.UDPConn == nil {
return c.TCPConn.Read(b)
}
// TODO: UDP data
// b1 := make([]byte, 65535)
// n, err := c.UDPConn.Read(b1)
// if err != nil {
// return 0, err
// }
// d, err := NewDatagramFromBytes(b1[0:n])
// if err != nil {
// return 0, err
// }
// if len(b) < len(d.Data) {
// return 0, errors.New("b too small")
// }
// n = copy(b, d.Data)
return 0, nil
}
func (c *Client) Write(b []byte) (int, error) {
if c.UDPConn == nil {
return c.TCPConn.Write(b)
}
// TODO: UPD data
// addr, err := ParseAddrSpec(c.RemoteAddress.String())
// if err != nil {
// return 0, err
// }
// if a == ATYPDomain {
// h = h[1:]
// }
// d := NewDatagram(a, h, p, b)
// b1 := d.Bytes()
return c.UDPConn.Write(b)
}
func (c *Client) Dial(network, addr string) (net.Conn, error) {
// var err error
//
// conn := *c
// if network == "tcp" {
// conn.RemoteAddress, err = net.ResolveTCPAddr("tcp", addr)
// if err != nil {
// return nil, err
// }
// if err := conn.Negotiate(); err != nil {
// return nil, err
// }
// a, h, p, err := ParseAddress(addr)
// if err != nil {
// return nil, err
// }
// if a == ATYPDomain {
// h = h[1:]
// }
// if _, err := conn.Request(NewRequest(CommandConnect, a, h, p)); err != nil {
// return nil, err
// }
// return conn, nil
// }
//
// if network == "udp" {
// conn.RemoteAddress, err = net.ResolveUDPAddr("udp", addr)
// if err != nil {
// return nil, err
// }
// if err := conn.Negotiate(); err != nil {
// return nil, err
// }
//
// laddr := &net.UDPAddr{
// IP: conn.TCPConn.LocalAddr().(*net.TCPAddr).IP,
// Port: conn.TCPConn.LocalAddr().(*net.TCPAddr).Port,
// Zone: conn.TCPConn.LocalAddr().(*net.TCPAddr).Zone,
// }
// a, h, p, err := ParseAddress(laddr.String())
// if err != nil {
// return nil, err
// }
// rp, err := conn.Request(NewRequest(CmdUDP, a, h, p))
// if err != nil {
// return nil, err
// }
// raddr, err := net.ResolveUDPAddr("udp", rp.Address())
// if err != nil {
// return nil, err
// }
// conn.UDPConn, err = net.DialUDP("udp", laddr, raddr)
// if err != nil {
// return nil, err
// }
// return conn, nil
// }
// return nil, errors.New("unsupport network")
return nil, errors.New("aaa")
}
func (c *Client) handshake() error {
methods := statute.MethodNoAuth
if c.UserName != "" && c.Password != "" {
methods = statute.MethodUserPassAuth
}
_, err := c.TCPConn.Write(statute.NewMethodRequest(c.Version, []byte{methods}).Bytes())
if err != nil {
return err
}
reply, err := statute.ParseMethodReply(c.TCPConn)
if err != nil {
return err
}
if reply.Ver != c.Version {
return errors.New("handshake failed cause version not same")
}
if reply.Method != methods {
return errors.New("unsupport method")
}
if methods == statute.MethodUserPassAuth {
_, err = c.TCPConn.Write(statute.NewNegotiationUserPassRequest(statute.UserPassAuthVersion, []byte(c.UserName), []byte(c.Password)).Bytes())
if err != nil {
return err
}
rsp, err := statute.ParseUserPassReply(c.TCPConn)
if err != nil {
return err
}
if rsp.Ver != statute.UserPassAuthVersion {
return errors.New("handshake failed cause version not same")
}
if rsp.Status != statute.RepSuccess {
return statute.ErrUserAuthFailed
}
}
return nil
}

@ -26,5 +26,5 @@ func (sf *pool) Put(b []byte) {
if cap(b) != sf.size {
panic("invalid buffer size that's put into leaky buffer")
}
sf.pool.Put(b[:0])
sf.pool.Put(b[:0]) // nolint: staticcheck
}

@ -3,16 +3,20 @@ package socks5
import (
"sync"
"testing"
"github.com/stretchr/testify/require"
)
func TestPool(t *testing.T) {
p := newPool(2048)
b := p.Get()
bs := b[0:cap(b)]
if len(bs) != cap(b) {
t.Fatalf("invalid buffer")
}
require.Equal(t, cap(b), len(bs))
p.Put(b)
p.Get()
p.Put(b)
p.Put(make([]byte, 2048))
require.Panics(t, func() { p.Put([]byte{}) })
}
func BenchmarkSyncPool(b *testing.B) {

@ -7,20 +7,18 @@ import (
"net"
"strings"
"sync"
)
var (
errUnrecognizedAddrType = fmt.Errorf("Unrecognized address type")
"github.com/thinkgos/go-socks5/statute"
)
// AddressRewriter is used to rewrite a destination transparently
type AddressRewriter interface {
Rewrite(ctx context.Context, request *Request) (context.Context, *AddrSpec)
Rewrite(ctx context.Context, request *Request) (context.Context, *statute.AddrSpec)
}
// A Request represents request received by a server
type Request struct {
Header
statute.Header
// AuthContext provided during negotiation
AuthContext *AuthContext
// LocalAddr of the the network server listen
@ -28,11 +26,11 @@ type Request struct {
// RemoteAddr of the the network that sent the request
RemoteAddr net.Addr
// DestAddr of the actual destination (might be affected by rewrite)
DestAddr *AddrSpec
DestAddr *statute.AddrSpec
// Reader connect of request
Reader io.Reader
// RawDestAddr of the desired destination
RawDestAddr *AddrSpec
RawDestAddr *statute.AddrSpec
}
// NewRequest creates a new Request from the tcp connection
@ -45,11 +43,11 @@ func NewRequest(bufConn io.Reader) (*Request, error) {
| 1 | 1 | X'00' | 1 | Variable | 2 |
+-----+-----+-------+------+----------+----------+
*/
hd, err := ParseHeader(bufConn)
hd, err := statute.ParseHeader(bufConn)
if err != nil {
return nil, err
}
if hd.Command != CommandConnect && hd.Command != CommandBind && hd.Command != CommandAssociate {
if hd.Command != statute.CommandConnect && hd.Command != statute.CommandBind && hd.Command != statute.CommandAssociate {
return nil, fmt.Errorf("unrecognized command[%d]", hd.Command)
}
return &Request{
@ -68,7 +66,7 @@ func (s *Server) handleRequest(write io.Writer, req *Request) error {
if dest.FQDN != "" {
_ctx, addr, err := s.resolver.Resolve(ctx, dest.FQDN)
if err != nil {
if err := SendReply(write, req.Header, RepHostUnreachable); err != nil {
if err := SendReply(write, req.Header, statute.RepHostUnreachable); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
return fmt.Errorf("failed to resolve destination[%v], %v", dest.FQDN, err)
@ -86,7 +84,7 @@ func (s *Server) handleRequest(write io.Writer, req *Request) error {
// Check if this is allowed
_ctx, ok := s.rules.Allow(ctx, req)
if !ok {
if err := SendReply(write, req.Header, RepRuleFailure); err != nil {
if err := SendReply(write, req.Header, statute.RepRuleFailure); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
return fmt.Errorf("bind to %v blocked by rules", req.RawDestAddr)
@ -95,23 +93,23 @@ func (s *Server) handleRequest(write io.Writer, req *Request) error {
// Switch on the command
switch req.Command {
case CommandConnect:
case statute.CommandConnect:
if s.userConnectHandle != nil {
return s.userConnectHandle(ctx, write, req)
}
return s.handleConnect(ctx, write, req)
case CommandBind:
case statute.CommandBind:
if s.userBindHandle != nil {
return s.userBindHandle(ctx, write, req)
}
return s.handleBind(ctx, write, req)
case CommandAssociate:
case statute.CommandAssociate:
if s.userAssociateHandle != nil {
return s.userAssociateHandle(ctx, write, req)
}
return s.handleAssociate(ctx, write, req)
default:
if err := SendReply(write, req.Header, RepCommandNotSupported); err != nil {
if err := SendReply(write, req.Header, statute.RepCommandNotSupported); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
return fmt.Errorf("unsupported command[%v]", req.Command)
@ -130,11 +128,11 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, request *R
target, err := dial(ctx, "tcp", request.DestAddr.String())
if err != nil {
msg := err.Error()
resp := RepHostUnreachable
resp := statute.RepHostUnreachable
if strings.Contains(msg, "refused") {
resp = RepConnectionRefused
resp = statute.RepConnectionRefused
} else if strings.Contains(msg, "network is unreachable") {
resp = RepNetworkUnreachable
resp = statute.RepNetworkUnreachable
}
if err := SendReply(writer, request.Header, resp); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
@ -144,7 +142,7 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, request *R
defer target.Close()
// Send success
if err := SendReply(writer, request.Header, RepSuccess, target.LocalAddr()); err != nil {
if err := SendReply(writer, request.Header, statute.RepSuccess, target.LocalAddr()); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
@ -166,7 +164,7 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, request *R
// handleBind is used to handle a connect command
func (s *Server) handleBind(_ context.Context, writer io.Writer, request *Request) error {
// TODO: Support bind
if err := SendReply(writer, request.Header, RepCommandNotSupported); err != nil {
if err := SendReply(writer, request.Header, statute.RepCommandNotSupported); err != nil {
return fmt.Errorf("failed to send reply: %v", err)
}
return nil
@ -184,11 +182,11 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, request
target, err := dial(ctx, "udp", request.DestAddr.String())
if err != nil {
msg := err.Error()
resp := RepHostUnreachable
resp := statute.RepHostUnreachable
if strings.Contains(msg, "refused") {
resp = RepConnectionRefused
resp = statute.RepConnectionRefused
} else if strings.Contains(msg, "network is unreachable") {
resp = RepNetworkUnreachable
resp = statute.RepNetworkUnreachable
}
if err := SendReply(writer, request.Header, resp); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
@ -199,7 +197,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, request
targetUDP, ok := target.(*net.UDPConn)
if !ok {
if err := SendReply(writer, request.Header, RepServerFailure); err != nil {
if err := SendReply(writer, request.Header, statute.RepServerFailure); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
return fmt.Errorf("dial udp invalid")
@ -207,7 +205,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, request
bindLn, err := net.ListenUDP("udp", nil)
if err != nil {
if err := SendReply(writer, request.Header, RepServerFailure); err != nil {
if err := SendReply(writer, request.Header, statute.RepServerFailure); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
return fmt.Errorf("listen udp failed, %v", err)
@ -216,7 +214,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, request
s.logger.Errorf("target addr %v, listen addr: %s", targetUDP.RemoteAddr(), bindLn.LocalAddr())
// send BND.ADDR and BND.PORT, client must
if err = SendReply(writer, request.Header, RepSuccess, bindLn.LocalAddr()); err != nil {
if err = SendReply(writer, request.Header, statute.RepSuccess, bindLn.LocalAddr()); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
@ -247,7 +245,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, request
continue
}
pk := NewEmptyPacket()
pk := statute.Packet{}
if err := pk.Parse(bufPool[:n]); err != nil {
continue
}
@ -270,7 +268,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, request
return
}
pkb, err := NewPacket(remote.String(), buf[:n])
pkb, err := statute.NewPacket(remote.String(), buf[:n])
if err != nil {
continue
}
@ -311,7 +309,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, request
}
// SendReply is used to send a reply message
func SendReply(w io.Writer, head Header, resp uint8, bindAddr ...net.Addr) error {
func SendReply(w io.Writer, head statute.Header, resp uint8, bindAddr ...net.Addr) error {
/*
The SOCKS response is formed as follows:
+----+-----+-------+------+----------+----------+
@ -324,11 +322,11 @@ func SendReply(w io.Writer, head Header, resp uint8, bindAddr ...net.Addr) error
head.Command = resp
if len(bindAddr) == 0 {
head.addrType = ATYPIPv4
head.AddrType = statute.ATYPIPv4
head.Address.IP = []byte{0, 0, 0, 0}
head.Address.Port = 0
} else {
addrSpec := AddrSpec{}
addrSpec := statute.AddrSpec{}
if tcpAddr, ok := bindAddr[0].(*net.TCPAddr); ok && tcpAddr != nil {
addrSpec.IP = tcpAddr.IP
addrSpec.Port = tcpAddr.Port
@ -341,15 +339,15 @@ func SendReply(w io.Writer, head Header, resp uint8, bindAddr ...net.Addr) error
}
switch {
case addrSpec.FQDN != "":
head.addrType = ATYPDomain
head.AddrType = statute.ATYPDomain
head.Address.FQDN = addrSpec.FQDN
head.Address.Port = addrSpec.Port
case addrSpec.IP.To4() != nil:
head.addrType = ATYPIPv4
head.AddrType = statute.ATYPIPv4
head.Address.IP = addrSpec.IP.To4()
head.Address.Port = addrSpec.Port
case addrSpec.IP.To16() != nil:
head.addrType = ATYPIPv6
head.AddrType = statute.ATYPIPv6
head.Address.IP = addrSpec.IP.To16()
head.Address.Port = addrSpec.Port
default:

@ -2,6 +2,8 @@ package socks5
import (
"context"
"github.com/thinkgos/go-socks5/statute"
)
// RuleSet is used to provide custom rules to allow or prohibit actions
@ -30,11 +32,11 @@ type PermitCommand struct {
// Allow implement interface RuleSet
func (p *PermitCommand) Allow(ctx context.Context, req *Request) (context.Context, bool) {
switch req.Command {
case CommandConnect:
case statute.CommandConnect:
return ctx, p.EnableConnect
case CommandBind:
case statute.CommandBind:
return ctx, p.EnableBind
case CommandAssociate:
case statute.CommandAssociate:
return ctx, p.EnableAssociate
}
return ctx, false

@ -3,21 +3,23 @@ package socks5
import (
"context"
"testing"
"github.com/thinkgos/go-socks5/statute"
)
func TestPermitCommand(t *testing.T) {
ctx := context.Background()
r := &PermitCommand{true, false, false}
if _, ok := r.Allow(ctx, &Request{Header: Header{Command: CommandConnect}}); !ok {
if _, ok := r.Allow(ctx, &Request{Header: statute.Header{Command: statute.CommandConnect}}); !ok {
t.Fatalf("expect connect")
}
if _, ok := r.Allow(ctx, &Request{Header: Header{Command: CommandBind}}); ok {
if _, ok := r.Allow(ctx, &Request{Header: statute.Header{Command: statute.CommandBind}}); ok {
t.Fatalf("do not expect bind")
}
if _, ok := r.Allow(ctx, &Request{Header: Header{Command: CommandAssociate}}); ok {
if _, ok := r.Allow(ctx, &Request{Header: statute.Header{Command: statute.CommandAssociate}}); ok {
t.Fatalf("do not expect associate")
}
}

@ -8,6 +8,8 @@ import (
"io/ioutil"
"log"
"net"
"github.com/thinkgos/go-socks5/statute"
)
// GPool is used to implement custom goroutine pool default use goroutine
@ -132,7 +134,7 @@ func (s *Server) ServeConn(conn net.Conn) error {
var authContext *AuthContext
var err error
// Ensure we are compatible
if version[0] == VersionSocks5 {
if version[0] == statute.VersionSocks5 {
// Authenticate the connection
authContext, err = s.authenticate(conn, bufConn, conn.RemoteAddr().String())
if err != nil {
@ -140,7 +142,7 @@ func (s *Server) ServeConn(conn net.Conn) error {
s.logger.Errorf("%v", err)
return err
}
} else if version[0] != VersionSocks4 {
} else if version[0] != statute.VersionSocks4 {
err := fmt.Errorf("unsupported SOCKS version: %v", version[0])
s.logger.Errorf("%v", err)
return err
@ -149,14 +151,14 @@ func (s *Server) ServeConn(conn net.Conn) error {
// The client request detail
request, err := NewRequest(bufConn)
if err != nil {
if err == errUnrecognizedAddrType {
if err := SendReply(conn, Header{Version: version[0]}, RepAddrTypeNotSupported); err != nil {
if err == statute.ErrUnrecognizedAddrType {
if err := SendReply(conn, statute.Header{Version: version[0]}, statute.RepAddrTypeNotSupported); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
}
return fmt.Errorf("failed to read destination address, %v", err)
}
if request.Header.Version == VersionSocks5 {
if request.Header.Version == statute.VersionSocks5 {
request.AuthContext = authContext
}
request.LocalAddr = conn.LocalAddr()

@ -10,6 +10,8 @@ import (
"time"
"golang.org/x/net/proxy"
"github.com/thinkgos/go-socks5/statute"
)
func TestSOCKS5_Connect(t *testing.T) {
@ -62,18 +64,18 @@ func TestSOCKS5_Connect(t *testing.T) {
// Connect, auth and connec to local
req := new(bytes.Buffer)
req.Write([]byte{VersionSocks5, 2, MethodNoAuth, MethodUserPassAuth})
req.Write([]byte{UserPassAuthVersion, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'})
reqHead := Header{
Version: VersionSocks5,
Command: CommandConnect,
req.Write([]byte{statute.VersionSocks5, 2, statute.MethodNoAuth, statute.MethodUserPassAuth})
req.Write([]byte{statute.UserPassAuthVersion, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'})
reqHead := statute.Header{
Version: statute.VersionSocks5,
Command: statute.CommandConnect,
Reserved: 0,
Address: AddrSpec{
Address: statute.AddrSpec{
"",
net.ParseIP("127.0.0.1"),
lAddr.Port,
},
addrType: ATYPIPv4,
AddrType: statute.ATYPIPv4,
}
req.Write(reqHead.Bytes())
// Send a ping
@ -84,19 +86,19 @@ func TestSOCKS5_Connect(t *testing.T) {
// Verify response
expected := []byte{
VersionSocks5, MethodUserPassAuth, // use user password auth
UserPassAuthVersion, AuthSuccess, // response auth success
statute.VersionSocks5, statute.MethodUserPassAuth, // use user password auth
statute.UserPassAuthVersion, statute.AuthSuccess, // response auth success
}
rspHead := Header{
Version: VersionSocks5,
Command: RepSuccess,
rspHead := statute.Header{
Version: statute.VersionSocks5,
Command: statute.RepSuccess,
Reserved: 0,
Address: AddrSpec{
Address: statute.AddrSpec{
"",
net.ParseIP("127.0.0.1"),
0, // Ignore the port
},
addrType: ATYPIPv4,
AddrType: statute.ATYPIPv4,
}
expected = append(expected, rspHead.Bytes()...)
expected = append(expected, []byte("pong")...)
@ -107,7 +109,7 @@ func TestSOCKS5_Connect(t *testing.T) {
t.Fatalf("err: %v", err)
}
t.Logf("proxy bind port: %d", buildPort(out[12], out[13]))
t.Logf("proxy bind port: %d", statute.BuildPort(out[12], out[13]))
// Ignore the port
out[12] = 0
@ -167,18 +169,18 @@ func TestSOCKS5_Associate(t *testing.T) {
// Connect, auth and connec to local
req := new(bytes.Buffer)
req.Write([]byte{VersionSocks5, 2, MethodNoAuth, MethodUserPassAuth})
req.Write([]byte{UserPassAuthVersion, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'})
reqHead := Header{
Version: VersionSocks5,
Command: CommandAssociate,
req.Write([]byte{statute.VersionSocks5, 2, statute.MethodNoAuth, statute.MethodUserPassAuth})
req.Write([]byte{statute.UserPassAuthVersion, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'})
reqHead := statute.Header{
Version: statute.VersionSocks5,
Command: statute.CommandAssociate,
Reserved: 0,
Address: AddrSpec{
Address: statute.AddrSpec{
"",
locIP,
lAddr.Port,
},
addrType: ATYPIPv4,
AddrType: statute.ATYPIPv4,
}
req.Write(reqHead.Bytes())
// Send all the bytes
@ -186,8 +188,8 @@ func TestSOCKS5_Associate(t *testing.T) {
// Verify response
expected := []byte{
VersionSocks5, MethodUserPassAuth, // use user password auth
UserPassAuthVersion, AuthSuccess, // response auth success
statute.VersionSocks5, statute.MethodUserPassAuth, // use user password auth
statute.UserPassAuthVersion, statute.AuthSuccess, // response auth success
}
out := make([]byte, len(expected))
@ -200,11 +202,11 @@ func TestSOCKS5_Associate(t *testing.T) {
t.Fatalf("bad: %v", out)
}
rspHead, err := ParseHeader(conn)
rspHead, err := statute.ParseHeader(conn)
if err != nil {
t.Fatalf("bad response header: %v", err)
}
if rspHead.Version != VersionSocks5 && rspHead.Command != RepSuccess {
if rspHead.Version != statute.VersionSocks5 && rspHead.Command != statute.RepSuccess {
t.Fatalf("parse success but bad header: %v", rspHead)
}
@ -218,7 +220,7 @@ func TestSOCKS5_Associate(t *testing.T) {
t.Fatalf("bad dial: %v", err)
}
// Send a ping
_, _ = udpConn.Write(append([]byte{0, 0, 0, ATYPIPv4, 0, 0, 0, 0, 0, 0}, []byte("ping")...))
_, _ = udpConn.Write(append([]byte{0, 0, 0, statute.ATYPIPv4, 0, 0, 0, 0, 0, 0}, []byte("ping")...))
response := make([]byte, 1024)
n, _, err := udpConn.ReadFrom(response)
if err != nil || !bytes.Equal(response[n-4:n], []byte("pong")) {

93
statute/auth.go Normal file

@ -0,0 +1,93 @@
package statute
import (
"fmt"
"io"
)
// UserPassRequest is the negotiation user's password request packet
// The SOCKS handshake user's password request is formed as follows:
// +--------------+------+----------+------+----------+
// | USERPASS_VER | ULEN | USER | PLEN | PASS |
// +--------------+------+----------+------+----------+
// | 1 | 1 | Variable | 1 | Variable |
// +--------------+------+----------+------+----------+
type UserPassRequest struct {
Ver byte
Ulen byte
User []byte // 1-255 bytes
Plen byte
Pass []byte // 1-255 bytes
}
func NewNegotiationUserPassRequest(ver byte, user, pass []byte) UserPassRequest {
return UserPassRequest{
ver,
byte(len(user)),
user,
byte(len(pass)),
pass,
}
}
func ParseUserPassRequest(r io.Reader) (nup UserPassRequest, err error) {
// Get the version and username length
header := []byte{0, 0}
if _, err = io.ReadAtLeast(r, header, 2); err != nil {
return
}
nup.Ver = header[0]
// Ensure we are compatible
if header[0] != UserPassAuthVersion {
err = fmt.Errorf("unsupported auth version: %v", header[0])
return
}
// Get the user name
nup.Ulen = header[1]
nup.User = make([]byte, nup.Ulen)
if _, err = io.ReadAtLeast(r, nup.User, int(nup.Ulen)); err != nil {
return
}
// Get the password length
if _, err = r.Read(header[:1]); err != nil {
return
}
// Get the password
nup.Plen = header[0]
nup.Pass = make([]byte, nup.Plen)
_, err = io.ReadAtLeast(r, nup.Pass, int(nup.Plen))
return
}
func (sf UserPassRequest) Bytes() []byte {
b := make([]byte, 0, 3+sf.Ulen+sf.Plen)
b = append(b, sf.Ver, sf.Ulen)
b = append(b, sf.User...)
b = append(b, sf.Plen)
b = append(b, sf.Pass...)
return b
}
// UserPassReply is the negotiation user's password reply packet
// The SOCKS handshake user's password response is formed as follows:
// +-----+--------+
// | VER | status |
// +-----+--------+
// | 1 | 1 |
// +-----+--------+
type UserPassReply struct {
Ver byte
Status byte
}
func ParseUserPassReply(r io.Reader) (n UserPassReply, err error) {
bb := make([]byte, 2)
if _, err = io.ReadFull(r, bb); err != nil {
return
}
n.Ver = bb[0]
n.Status = bb[1]
return
}

@ -1,4 +1,4 @@
package socks5
package statute
import (
"errors"
@ -78,7 +78,7 @@ func (sf *Packet) Parse(b []byte) error {
case ATYPIPv4:
headLen += net.IPv4len + 2
sf.DstAddr.IP = net.IPv4(b[4], b[5], b[6], b[7])
sf.DstAddr.Port = buildPort(b[4+net.IPv4len], b[4+net.IPv4len+1])
sf.DstAddr.Port = BuildPort(b[4+net.IPv4len], b[4+net.IPv4len+1])
case ATYPIPv6:
headLen += net.IPv6len + 2
if len(b) <= headLen {
@ -86,7 +86,7 @@ func (sf *Packet) Parse(b []byte) error {
}
sf.DstAddr.IP = net.IP{b[4], b[5], b[6], b[7], b[8], b[9], b[10], b[11], b[12], b[13], b[14], b[15], b[16], b[17], b[18], b[19]}
sf.DstAddr.Port = buildPort(b[4+net.IPv6len], b[4+net.IPv6len+1])
sf.DstAddr.Port = BuildPort(b[4+net.IPv6len], b[4+net.IPv6len+1])
case ATYPDomain:
addrLen := int(b[4])
headLen += 1 + addrLen + 2
@ -96,9 +96,9 @@ func (sf *Packet) Parse(b []byte) error {
str := make([]byte, addrLen)
copy(str, b[5:5+addrLen])
sf.DstAddr.FQDN = string(str)
sf.DstAddr.Port = buildPort(b[5+addrLen], b[5+addrLen+1])
sf.DstAddr.Port = BuildPort(b[5+addrLen], b[5+addrLen+1])
default:
return errUnrecognizedAddrType
return ErrUnrecognizedAddrType
}
sf.Data = b[headLen:]
return nil
@ -119,7 +119,7 @@ func (sf *Packet) Header() []byte {
bs = append(bs, ATYPDomain)
bs = append(bs, []byte(sf.DstAddr.FQDN)...)
}
hi, lo := breakPort(sf.DstAddr.Port)
hi, lo := BreakPort(sf.DstAddr.Port)
bs = append(bs, hi, lo)
return bs
}

9
statute/errors.go Normal file

@ -0,0 +1,9 @@
package statute
import (
"fmt"
)
var (
ErrUnrecognizedAddrType = fmt.Errorf("Unrecognized address type")
)

@ -1,4 +1,4 @@
package socks5
package statute
import (
"fmt"
@ -7,35 +7,6 @@ import (
"strconv"
)
// socks const defined
const (
// protocol version
VersionSocks4 = uint8(4)
VersionSocks5 = uint8(5)
// request command
CommandConnect = uint8(1)
CommandBind = uint8(2)
CommandAssociate = uint8(3)
// address type
ATYPIPv4 = uint8(1)
ATYPDomain = uint8(3)
ATYPIPv6 = uint8(4)
)
// reply status
const (
RepSuccess uint8 = iota
RepServerFailure
RepRuleFailure
RepNetworkUnreachable
RepHostUnreachable
RepConnectionRefused
RepTTLExpired
RepCommandNotSupported
RepAddrTypeNotSupported
// 0x09 - 0xff unassigned
)
// Header represents the SOCKS4/SOCKS5 head len defined
const (
headerVERLen = 1
@ -73,6 +44,26 @@ func (a AddrSpec) Address() string {
return fmt.Sprintf("%s:%d", a.IP, a.Port)
}
// ParseAddrSpec parse address to the AddrSpec address
func ParseAddrSpec(address string) (a AddrSpec, err error) {
var host, port string
host, port, err = net.SplitHostPort(address)
if err != nil {
return
}
ip := net.ParseIP(host)
if ip4 := ip.To4(); ip4 != nil {
a.IP = ip4
} else if ip6 := ip.To16(); ip6 != nil {
a.IP = ip6
} else {
a.FQDN = host
}
a.Port, err = strconv.Atoi(port)
return
}
// Header represents the SOCKS4/SOCKS5 header, it contains everything that is not payload
// The SOCKS4 request/response is formed as follows:
// +-----+-----+------+------+
@ -96,7 +87,7 @@ type Header struct {
// Address in socks message
Address AddrSpec
// private stuff set when Header parsed
addrType uint8
AddrType uint8
}
// ParseHeader to header from io.Reader
@ -123,8 +114,8 @@ func ParseHeader(r io.Reader) (hd Header, err error) {
return hd, fmt.Errorf("failed to get header RSV and address type, %v", err)
}
hd.Reserved = tmp[0]
hd.addrType = tmp[1]
switch hd.addrType {
hd.AddrType = tmp[1]
switch hd.AddrType {
case ATYPDomain:
if _, err = io.ReadFull(r, tmp[:1]); err != nil {
return hd, fmt.Errorf("failed to get header, %v", err)
@ -135,23 +126,23 @@ func ParseHeader(r io.Reader) (hd Header, err error) {
return hd, fmt.Errorf("failed to get header, %v", err)
}
hd.Address.FQDN = string(addr[:domainLen])
hd.Address.Port = buildPort(addr[domainLen], addr[domainLen+1])
hd.Address.Port = BuildPort(addr[domainLen], addr[domainLen+1])
case ATYPIPv4:
addr := make([]byte, net.IPv4len+2)
if _, err = io.ReadFull(r, addr); err != nil {
return hd, fmt.Errorf("failed to get header, %v", err)
}
hd.Address.IP = net.IPv4(addr[0], addr[1], addr[2], addr[3])
hd.Address.Port = buildPort(addr[net.IPv4len], addr[net.IPv4len+1])
hd.Address.Port = BuildPort(addr[net.IPv4len], addr[net.IPv4len+1])
case ATYPIPv6:
addr := make([]byte, net.IPv6len+2)
if _, err = io.ReadFull(r, addr); err != nil {
return hd, fmt.Errorf("failed to get header, %v", err)
}
hd.Address.IP = addr[:net.IPv6len]
hd.Address.Port = buildPort(addr[net.IPv6len], addr[net.IPv6len+1])
hd.Address.Port = BuildPort(addr[net.IPv6len], addr[net.IPv6len+1])
default:
return hd, errUnrecognizedAddrType
return hd, ErrUnrecognizedAddrType
}
} else { // Socks4
// read port and ipv4 ip
@ -159,7 +150,7 @@ func ParseHeader(r io.Reader) (hd Header, err error) {
if _, err = io.ReadFull(r, tmp); err != nil {
return hd, fmt.Errorf("failed to get socks4 header port and ip, %v", err)
}
hd.Address.Port = buildPort(tmp[0], tmp[1])
hd.Address.Port = BuildPort(tmp[0], tmp[1])
hd.Address.IP = net.IPv4(tmp[2], tmp[3], tmp[4], tmp[5])
}
return hd, nil
@ -167,7 +158,7 @@ func ParseHeader(r io.Reader) (hd Header, err error) {
// Bytes returns a slice of header
func (h Header) Bytes() (b []byte) {
hiPort, loPort := breakPort(h.Address.Port)
hiPort, loPort := BreakPort(h.Address.Port)
if h.Version == VersionSocks4 {
b = make([]byte, 0, headerVERLen+headerCMDLen+headerPORTLen+net.IPv4len)
b = append(b, h.Version)
@ -176,24 +167,24 @@ func (h Header) Bytes() (b []byte) {
b = append(b, h.Address.IP.To4()...)
} else if h.Version == VersionSocks5 {
length := headerVERLen + headerCMDLen + headerRSVLen + headerATYPLen + headerPORTLen
if h.addrType == ATYPDomain {
if h.AddrType == ATYPDomain {
length += 1 + len(h.Address.FQDN)
} else if h.addrType == ATYPIPv4 {
} else if h.AddrType == ATYPIPv4 {
length += net.IPv4len
} else if h.addrType == ATYPIPv6 {
} else if h.AddrType == ATYPIPv6 {
length += net.IPv6len
}
b = make([]byte, 0, length)
b = append(b, h.Version)
b = append(b, h.Command)
b = append(b, h.Reserved)
b = append(b, h.addrType)
if h.addrType == ATYPDomain {
b = append(b, h.AddrType)
if h.AddrType == ATYPDomain {
b = append(b, byte(len(h.Address.FQDN)))
b = append(b, []byte(h.Address.FQDN)...)
} else if h.addrType == ATYPIPv4 {
} else if h.AddrType == ATYPIPv4 {
b = append(b, h.Address.IP.To4()...)
} else if h.addrType == ATYPIPv6 {
} else if h.AddrType == ATYPIPv6 {
b = append(b, h.Address.IP.To16()...)
}
b = append(b, hiPort, loPort)
@ -201,5 +192,5 @@ func (h Header) Bytes() (b []byte) {
return b
}
func buildPort(hi, lo byte) int { return (int(hi) << 8) | int(lo) }
func breakPort(port int) (hi, lo byte) { return byte(port >> 8), byte(port) }
func BuildPort(hi, lo byte) int { return (int(hi) << 8) | int(lo) }
func BreakPort(port int) (hi, lo byte) { return byte(port >> 8), byte(port) }

@ -1,4 +1,4 @@
package socks5
package statute
import (
"bytes"

54
statute/method.go Normal file

@ -0,0 +1,54 @@
package statute
import (
"io"
)
// MethodRequest is the negotiation method request packet
// The SOCKS handshake method request is formed as follows:
// +-----+----------+---------------+
// | VER | NMETHODS | METHODS |
// +-----+----------+---------------+
// | 1 | 1 | X'00' - X'FF' |
// +-----+----------+---------------+
type MethodRequest struct {
Ver byte
NMethods byte
Methods []byte // 1-255 bytes
}
// NewMethodRequest new negotiation method request
func NewMethodRequest(ver byte, medthods []byte) MethodRequest {
return MethodRequest{
ver,
byte(len(medthods)),
medthods,
}
}
func (n MethodRequest) Bytes() []byte {
b := make([]byte, 0, 2+n.NMethods)
return append(append(b, n.Ver, n.NMethods), n.Methods...)
}
// MethodReply is the negotiation method reply packet
// The SOCKS handshake method response is formed as follows:
// +-----+--------+
// | VER | METHOD |
// +-----+--------+
// | 1 | 1 |
// +-----+--------+
type MethodReply struct {
Ver byte
Method byte
}
func ParseMethodReply(r io.Reader) (n MethodReply, err error) {
bb := make([]byte, 2)
if _, err = io.ReadFull(r, bb); err != nil {
return
}
n.Ver = bb[0]
n.Method = bb[1]
return
}

53
statute/statute.go Normal file

@ -0,0 +1,53 @@
package statute
import (
"fmt"
)
// auth defined
const (
MethodNoAuth = byte(0x00)
MethodGSSAPI = byte(0x01)
MethodUserPassAuth = byte(0x02)
MethodNoAcceptable = byte(0xff)
// user password version
UserPassAuthVersion = byte(0x01)
// auth status
AuthSuccess = byte(0x00)
AuthFailure = byte(0x01)
)
// socks const defined
const (
// protocol version
VersionSocks4 = byte(0x04)
VersionSocks5 = byte(0x05)
// request command
CommandConnect = byte(0x01)
CommandBind = byte(0x02)
CommandAssociate = byte(0x03)
// address type
ATYPIPv4 = byte(0x01)
ATYPDomain = byte(0x03)
ATYPIPv6 = byte(0x04)
)
// reply status
const (
RepSuccess uint8 = iota
RepServerFailure
RepRuleFailure
RepNetworkUnreachable
RepHostUnreachable
RepConnectionRefused
RepTTLExpired
RepCommandNotSupported
RepAddrTypeNotSupported
// 0x09 - 0xff unassigned
)
// auth error defined
var (
ErrUserAuthFailed = fmt.Errorf("user authentication failed")
ErrNoSupportedAuth = fmt.Errorf("no supported authentication mechanism")
)