fix request/response statute

This commit is contained in:
mo 2020-08-05 18:10:50 +08:00
parent 75dd487cbf
commit 916af03167
12 changed files with 506 additions and 286 deletions

@ -1,19 +1,58 @@
package main
import (
"io"
"log"
"net"
"os"
"time"
"github.com/thinkgos/go-socks5"
"github.com/thinkgos/go-socks5/client"
)
func handleErr(err error) {
if err != nil {
panic(err)
}
}
func main() {
// Create a local listener
l, err := net.Listen("tcp", "127.0.0.1:0")
handleErr(err)
go func() {
conn, err := l.Accept()
handleErr(err)
defer conn.Close()
buf := make([]byte, 4)
_, err = io.ReadAtLeast(conn, buf, 4)
handleErr(err)
log.Printf("server: %+v", string(buf))
conn.Write([]byte("pong"))
}()
lAddr := l.Addr().(*net.TCPAddr)
go func() {
time.Sleep(time.Second)
c, err := client.NewClient("127.0.0.1:1080")
handleErr(err)
con, err := c.Dial("tcp", lAddr.String())
handleErr(err)
con.Write([]byte("ping"))
out := make([]byte, 4)
_ = con.SetDeadline(time.Now().Add(time.Second)) // nolint: errcheck
_, err = io.ReadFull(con, out)
log.Printf("client: %+v", string(out))
}()
// Create a SOCKS5 server
server := socks5.New(socks5.WithLogger(socks5.NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags))))
server := socks5.NewServer(socks5.WithLogger(socks5.NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags))))
// Create SOCKS5 proxy on localhost port 8000
if err := server.ListenAndServe("tcp", "127.0.0.1:1080"); err != nil {
panic(err)
}
}

@ -129,16 +129,16 @@ func (c *Client) Dial(network, addr string) (net.Conn, error) {
if err != nil {
return nil, err
}
head := statute.Header{
Version: statute.VersionSocks5,
Command: statute.CommandConnect,
Address: a,
head := statute.Request{
Version: statute.VersionSocks5,
Command: statute.CommandConnect,
DstAddress: a,
}
if _, err := conn.Write(head.Bytes()); err != nil {
return nil, err
}
rspHead, err := statute.ParseHeader(conn.TCPConn)
rspHead, err := statute.ParseRequest(conn.TCPConn)
if err != nil {
return nil, err
}
@ -169,15 +169,15 @@ func (c *Client) Dial(network, addr string) (net.Conn, error) {
if err != nil {
return nil, err
}
head := statute.Header{
Version: statute.VersionSocks5,
Command: statute.CommandConnect,
Address: a,
head := statute.Request{
Version: statute.VersionSocks5,
Command: statute.CommandConnect,
DstAddress: a,
}
if _, err := conn.Write(head.Bytes()); err != nil {
return nil, err
}
rspHead, err := statute.ParseHeader(conn.TCPConn)
rspHead, err := statute.ParseRequest(conn.TCPConn)
if err != nil {
return nil, err
}
@ -185,7 +185,7 @@ func (c *Client) Dial(network, addr string) (net.Conn, error) {
return nil, errors.New("host unreachable")
}
raddr, err := net.ResolveUDPAddr("udp", rspHead.Address.String())
raddr, err := net.ResolveUDPAddr("udp", rspHead.DstAddress.String())
if err != nil {
return nil, err
}

@ -18,7 +18,7 @@ type AddressRewriter interface {
// A Request represents request received by a server
type Request struct {
statute.Header
statute.Request
// AuthContext provided during negotiation
AuthContext *AuthContext
// LocalAddr of the the network server listen
@ -35,7 +35,7 @@ type Request struct {
// NewRequest creates a new Request from the tcp connection
func NewRequest(bufConn io.Reader) (*Request, error) {
hd, err := statute.ParseHeader(bufConn)
hd, err := statute.ParseRequest(bufConn)
if err != nil {
return nil, err
}
@ -43,8 +43,8 @@ func NewRequest(bufConn io.Reader) (*Request, error) {
return nil, fmt.Errorf("unrecognized command[%d]", hd.Command)
}
return &Request{
Header: hd,
RawDestAddr: &hd.Address,
Request: hd,
RawDestAddr: &hd.DstAddress,
Reader: bufConn,
}, nil
}
@ -58,7 +58,7 @@ func (s *Server) handleRequest(write io.Writer, req *Request) error {
if dest.FQDN != "" {
_ctx, addr, err := s.resolver.Resolve(ctx, dest.FQDN)
if err != nil {
if err := SendReply(write, req.Header, statute.RepHostUnreachable); err != nil {
if err := SendReply(write, req.Request, statute.RepHostUnreachable); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
return fmt.Errorf("failed to resolve destination[%v], %v", dest.FQDN, err)
@ -76,7 +76,7 @@ func (s *Server) handleRequest(write io.Writer, req *Request) error {
// Check if this is allowed
_ctx, ok := s.rules.Allow(ctx, req)
if !ok {
if err := SendReply(write, req.Header, statute.RepRuleFailure); err != nil {
if err := SendReply(write, req.Request, statute.RepRuleFailure); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
return fmt.Errorf("bind to %v blocked by rules", req.RawDestAddr)
@ -101,7 +101,7 @@ func (s *Server) handleRequest(write io.Writer, req *Request) error {
}
return s.handleAssociate(ctx, write, req)
default:
if err := SendReply(write, req.Header, statute.RepCommandNotSupported); err != nil {
if err := SendReply(write, req.Request, statute.RepCommandNotSupported); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
return fmt.Errorf("unsupported command[%v]", req.Command)
@ -126,7 +126,7 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, request *R
} else if strings.Contains(msg, "network is unreachable") {
resp = statute.RepNetworkUnreachable
}
if err := SendReply(writer, request.Header, resp); err != nil {
if err := SendReply(writer, request.Request, resp); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
return fmt.Errorf("connect to %v failed, %v", request.RawDestAddr, err)
@ -134,7 +134,7 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, request *R
defer target.Close()
// Send success
if err := SendReply(writer, request.Header, statute.RepSuccess, target.LocalAddr()); err != nil {
if err := SendReply(writer, request.Request, statute.RepSuccess, target.LocalAddr()); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
@ -156,7 +156,7 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, request *R
// handleBind is used to handle a connect command
func (s *Server) handleBind(_ context.Context, writer io.Writer, request *Request) error {
// TODO: Support bind
if err := SendReply(writer, request.Header, statute.RepCommandNotSupported); err != nil {
if err := SendReply(writer, request.Request, statute.RepCommandNotSupported); err != nil {
return fmt.Errorf("failed to send reply: %v", err)
}
return nil
@ -180,7 +180,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, request
} else if strings.Contains(msg, "network is unreachable") {
resp = statute.RepNetworkUnreachable
}
if err := SendReply(writer, request.Header, resp); err != nil {
if err := SendReply(writer, request.Request, resp); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
return fmt.Errorf("connect to %v failed, %v", request.RawDestAddr, err)
@ -189,7 +189,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, request
targetUDP, ok := target.(*net.UDPConn)
if !ok {
if err := SendReply(writer, request.Header, statute.RepServerFailure); err != nil {
if err := SendReply(writer, request.Request, statute.RepServerFailure); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
return fmt.Errorf("dial udp invalid")
@ -197,7 +197,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, request
bindLn, err := net.ListenUDP("udp", nil)
if err != nil {
if err := SendReply(writer, request.Header, statute.RepServerFailure); err != nil {
if err := SendReply(writer, request.Request, statute.RepServerFailure); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
return fmt.Errorf("listen udp failed, %v", err)
@ -206,7 +206,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, request
s.logger.Errorf("target addr %v, listen addr: %s", targetUDP.RemoteAddr(), bindLn.LocalAddr())
// send BND.ADDR and BND.PORT, client must
if err = SendReply(writer, request.Header, statute.RepSuccess, bindLn.LocalAddr()); err != nil {
if err = SendReply(writer, request.Request, statute.RepSuccess, bindLn.LocalAddr()); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
@ -301,13 +301,13 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, request
}
// SendReply is used to send a reply message
func SendReply(w io.Writer, head statute.Header, resp uint8, bindAddr ...net.Addr) error {
func SendReply(w io.Writer, head statute.Request, resp uint8, bindAddr ...net.Addr) error {
head.Command = resp
if len(bindAddr) == 0 {
head.Address.AddrType = statute.ATYPIPv4
head.Address.IP = []byte{0, 0, 0, 0}
head.Address.Port = 0
head.DstAddress.AddrType = statute.ATYPIPv4
head.DstAddress.IP = []byte{0, 0, 0, 0}
head.DstAddress.Port = 0
} else {
addrSpec := statute.AddrSpec{}
if tcpAddr, ok := bindAddr[0].(*net.TCPAddr); ok && tcpAddr != nil {
@ -322,17 +322,17 @@ func SendReply(w io.Writer, head statute.Header, resp uint8, bindAddr ...net.Add
}
switch {
case addrSpec.FQDN != "":
head.Address.AddrType = statute.ATYPDomain
head.Address.FQDN = addrSpec.FQDN
head.Address.Port = addrSpec.Port
head.DstAddress.AddrType = statute.ATYPDomain
head.DstAddress.FQDN = addrSpec.FQDN
head.DstAddress.Port = addrSpec.Port
case addrSpec.IP.To4() != nil:
head.Address.AddrType = statute.ATYPIPv4
head.Address.IP = addrSpec.IP.To4()
head.Address.Port = addrSpec.Port
head.DstAddress.AddrType = statute.ATYPIPv4
head.DstAddress.IP = addrSpec.IP.To4()
head.DstAddress.Port = addrSpec.Port
case addrSpec.IP.To16() != nil:
head.Address.AddrType = statute.ATYPIPv6
head.Address.IP = addrSpec.IP.To16()
head.Address.Port = addrSpec.Port
head.DstAddress.AddrType = statute.ATYPIPv6
head.DstAddress.IP = addrSpec.IP.To16()
head.DstAddress.Port = addrSpec.Port
default:
return fmt.Errorf("failed to format address[%v]", bindAddr)
}

@ -16,32 +16,32 @@ func TestPermitCommand(t *testing.T) {
ctx := context.Background()
r = NewPermitAll()
_, ok = r.Allow(ctx, &Request{Header: statute.Header{Command: statute.CommandConnect}})
_, ok = r.Allow(ctx, &Request{Request: statute.Request{Command: statute.CommandConnect}})
require.True(t, ok)
_, ok = r.Allow(ctx, &Request{Header: statute.Header{Command: statute.CommandBind}})
_, ok = r.Allow(ctx, &Request{Request: statute.Request{Command: statute.CommandBind}})
require.True(t, ok)
_, ok = r.Allow(ctx, &Request{Header: statute.Header{Command: statute.CommandAssociate}})
_, ok = r.Allow(ctx, &Request{Request: statute.Request{Command: statute.CommandAssociate}})
require.True(t, ok)
_, ok = r.Allow(ctx, &Request{Header: statute.Header{Command: 0x00}})
_, ok = r.Allow(ctx, &Request{Request: statute.Request{Command: 0x00}})
require.False(t, ok)
r = NewPermitConnAndAss()
_, ok = r.Allow(ctx, &Request{Header: statute.Header{Command: statute.CommandConnect}})
_, ok = r.Allow(ctx, &Request{Request: statute.Request{Command: statute.CommandConnect}})
require.True(t, ok)
_, ok = r.Allow(ctx, &Request{Header: statute.Header{Command: statute.CommandBind}})
_, ok = r.Allow(ctx, &Request{Request: statute.Request{Command: statute.CommandBind}})
require.False(t, ok)
_, ok = r.Allow(ctx, &Request{Header: statute.Header{Command: statute.CommandAssociate}})
_, ok = r.Allow(ctx, &Request{Request: statute.Request{Command: statute.CommandAssociate}})
require.True(t, ok)
_, ok = r.Allow(ctx, &Request{Header: statute.Header{Command: 0x00}})
_, ok = r.Allow(ctx, &Request{Request: statute.Request{Command: 0x00}})
require.False(t, ok)
r = NewPermitNone()
_, ok = r.Allow(ctx, &Request{Header: statute.Header{Command: statute.CommandConnect}})
_, ok = r.Allow(ctx, &Request{Request: statute.Request{Command: statute.CommandConnect}})
require.False(t, ok)
_, ok = r.Allow(ctx, &Request{Header: statute.Header{Command: statute.CommandBind}})
_, ok = r.Allow(ctx, &Request{Request: statute.Request{Command: statute.CommandBind}})
require.False(t, ok)
_, ok = r.Allow(ctx, &Request{Header: statute.Header{Command: statute.CommandAssociate}})
_, ok = r.Allow(ctx, &Request{Request: statute.Request{Command: statute.CommandAssociate}})
require.False(t, ok)
_, ok = r.Allow(ctx, &Request{Header: statute.Header{Command: 0x00}})
_, ok = r.Allow(ctx, &Request{Request: statute.Request{Command: 0x00}})
require.False(t, ok)
}

@ -136,7 +136,7 @@ func (s *Server) ServeConn(conn net.Conn) error {
request, err := NewRequest(bufConn)
if err != nil {
if err == statute.ErrUnrecognizedAddrType {
if err := SendReply(conn, statute.Header{Version: mr.Ver}, statute.RepAddrTypeNotSupported); err != nil {
if err := SendReply(conn, statute.Request{Version: mr.Ver}, statute.RepAddrTypeNotSupported); err != nil {
return fmt.Errorf("failed to send reply %w", err)
}
}

@ -57,11 +57,11 @@ func TestSOCKS5_Connect(t *testing.T) {
req := new(bytes.Buffer)
req.Write([]byte{statute.VersionSocks5, 2, statute.MethodNoAuth, statute.MethodUserPassAuth})
req.Write([]byte{statute.UserPassAuthVersion, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'})
reqHead := statute.Header{
reqHead := statute.Request{
Version: statute.VersionSocks5,
Command: statute.CommandConnect,
Reserved: 0,
Address: statute.AddrSpec{
DstAddress: statute.AddrSpec{
"",
net.ParseIP("127.0.0.1"),
lAddr.Port,
@ -80,11 +80,11 @@ func TestSOCKS5_Connect(t *testing.T) {
statute.VersionSocks5, statute.MethodUserPassAuth, // use user password auth
statute.UserPassAuthVersion, statute.AuthSuccess, // response auth success
}
rspHead := statute.Header{
rspHead := statute.Request{
Version: statute.VersionSocks5,
Command: statute.RepSuccess,
Reserved: 0,
Address: statute.AddrSpec{
DstAddress: statute.AddrSpec{
"",
net.ParseIP("127.0.0.1"),
0, // Ignore the port
@ -152,11 +152,11 @@ func TestSOCKS5_Associate(t *testing.T) {
req := new(bytes.Buffer)
req.Write([]byte{statute.VersionSocks5, 2, statute.MethodNoAuth, statute.MethodUserPassAuth})
req.Write([]byte{statute.UserPassAuthVersion, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'})
reqHead := statute.Header{
reqHead := statute.Request{
Version: statute.VersionSocks5,
Command: statute.CommandAssociate,
Reserved: 0,
Address: statute.AddrSpec{
DstAddress: statute.AddrSpec{
"",
locIP,
lAddr.Port,
@ -179,16 +179,16 @@ func TestSOCKS5_Associate(t *testing.T) {
require.NoError(t, err)
require.Equal(t, expected, out)
rspHead, err := statute.ParseHeader(conn)
rspHead, err := statute.ParseRequest(conn)
require.NoError(t, err)
require.Equal(t, statute.VersionSocks5, rspHead.Version)
require.Equal(t, statute.RepSuccess, rspHead.Command)
t.Logf("proxy bind listen port: %d", rspHead.Address.Port)
t.Logf("proxy bind listen port: %d", rspHead.DstAddress.Port)
udpConn, err := net.DialUDP("udp", nil, &net.UDPAddr{
IP: locIP,
Port: rspHead.Address.Port,
Port: rspHead.DstAddress.Port,
})
require.NoError(t, err)
// Send a ping

@ -63,7 +63,7 @@ func NewPacket(destAddr string, data []byte) (p Packet, err error) {
return
}
// ParseHeader parse to packet
// ParseRequest parse to packet
func (sf *Packet) Parse(b []byte) error {
if len(b) <= 4+net.IPv4len+2 { // no data
return errors.New("too short")
@ -104,7 +104,7 @@ func (sf *Packet) Parse(b []byte) error {
return nil
}
// Header returns s slice of packet header
// Request returns s slice of packet reply
func (sf *Packet) Header() []byte {
bs := make([]byte, 0, 32)
bs = append(bs, []byte{byte(sf.RSV << 8), byte(sf.RSV), sf.Frag}...)

@ -1,117 +0,0 @@
package statute
import (
"fmt"
"io"
"net"
)
// Header represents the SOCKS5 head len defined
const (
headerVERLen = 1
headerCMDLen = 1
headerRSVLen = 1
headerATYPLen = 1
headerPORTLen = 2
)
// Header represents the SOCKS4/SOCKS5 header, it contains everything that is not payload
// The SOCKS5 request/response is formed as follows:
// +-----+-----+-------+------+----------------+----------------+
// | VER | CMD | RSV | ATYP | [DST/BND].ADDR | [DST/BND].PORT |
// +-----+-----+-------+------+----------------+----------------+
// | 1 | 1 | X'00' | 1 | Variable | 2 |
// +-----+-----+-------+------+----------------+----------------+
type Header struct {
// Version of socks protocol for message
Version uint8
// Socks Command "connect","bind","associate"
Command uint8
// Reserved byte
Reserved uint8
// Address in socks message
Address AddrSpec
}
// ParseHeader to header from io.Reader
func ParseHeader(r io.Reader) (hd Header, err error) {
// Read the version and command
tmp := make([]byte, headerVERLen+headerCMDLen)
if _, err = io.ReadFull(r, tmp); err != nil {
return hd, fmt.Errorf("failed to get header version and command, %v", err)
}
hd.Version = tmp[0]
hd.Command = tmp[1]
if hd.Version != VersionSocks5 {
return hd, fmt.Errorf("unrecognized SOCKS version[%d]", hd.Version)
}
tmp = make([]byte, headerRSVLen+headerATYPLen)
if _, err = io.ReadFull(r, tmp); err != nil {
return hd, fmt.Errorf("failed to get header RSV and address type, %v", err)
}
hd.Reserved = tmp[0]
hd.Address.AddrType = tmp[1]
switch hd.Address.AddrType {
case ATYPDomain:
if _, err = io.ReadFull(r, tmp[:1]); err != nil {
return hd, fmt.Errorf("failed to get header, %v", err)
}
domainLen := int(tmp[0])
addr := make([]byte, domainLen+headerPORTLen)
if _, err = io.ReadFull(r, addr); err != nil {
return hd, fmt.Errorf("failed to get header, %v", err)
}
hd.Address.FQDN = string(addr[:domainLen])
hd.Address.Port = BuildPort(addr[domainLen], addr[domainLen+1])
case ATYPIPv4:
addr := make([]byte, net.IPv4len+2)
if _, err = io.ReadFull(r, addr); err != nil {
return hd, fmt.Errorf("failed to get header, %v", err)
}
hd.Address.IP = net.IPv4(addr[0], addr[1], addr[2], addr[3])
hd.Address.Port = BuildPort(addr[net.IPv4len], addr[net.IPv4len+1])
case ATYPIPv6:
addr := make([]byte, net.IPv6len+2)
if _, err = io.ReadFull(r, addr); err != nil {
return hd, fmt.Errorf("failed to get header, %v", err)
}
hd.Address.IP = addr[:net.IPv6len]
hd.Address.Port = BuildPort(addr[net.IPv6len], addr[net.IPv6len+1])
default:
return hd, ErrUnrecognizedAddrType
}
return hd, nil
}
// Bytes returns a slice of header
func (h Header) Bytes() (b []byte) {
var addr []byte
length := headerVERLen + headerCMDLen + headerRSVLen + headerATYPLen + headerPORTLen
if h.Address.AddrType == ATYPIPv4 {
length += net.IPv4len
addr = h.Address.IP.To4()
} else if h.Address.AddrType == ATYPIPv6 {
length += net.IPv6len
addr = h.Address.IP.To16()
} else { //ATYPDomain
length += 1 + len(h.Address.FQDN)
addr = []byte(h.Address.FQDN)
}
b = make([]byte, 0, length)
b = append(b, h.Version)
b = append(b, h.Command)
b = append(b, h.Reserved)
b = append(b, h.Address.AddrType)
if h.Address.AddrType == ATYPDomain {
b = append(b, byte(len(h.Address.FQDN)))
}
b = append(b, addr...)
hiPort, loPort := BreakPort(h.Address.Port)
b = append(b, hiPort, loPort)
return b
}

@ -1,102 +0,0 @@
package statute
import (
"bytes"
"io"
"net"
"reflect"
"testing"
)
func TestParseHeader(t *testing.T) {
type args struct {
r io.Reader
}
tests := []struct {
name string
args args
wantHd Header
wantErr bool
}{
{
"SOCKS5 IPV4",
args{bytes.NewReader([]byte{VersionSocks5, CommandConnect, 0, ATYPIPv4, 127, 0, 0, 1, 0x1f, 0x90})},
Header{
VersionSocks5, CommandConnect, 0,
AddrSpec{IP: net.IPv4(127, 0, 0, 1), Port: 8080, AddrType: ATYPIPv4},
},
false,
},
{
"SOCKS5 IPV6",
args{bytes.NewReader([]byte{VersionSocks5, CommandConnect, 0, ATYPIPv6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x1f, 0x90})},
Header{
VersionSocks5, CommandConnect, 0,
AddrSpec{IP: net.IPv6zero, Port: 8080, AddrType: ATYPIPv6},
},
false,
},
{
"SOCKS5 FQDN",
args{bytes.NewReader([]byte{VersionSocks5, CommandConnect, 0, ATYPDomain, 9, 'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't', 0x1f, 0x90})},
Header{
VersionSocks5, CommandConnect, 0,
AddrSpec{FQDN: "localhost", Port: 8080, AddrType: ATYPDomain},
},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotHd, err := ParseHeader(tt.args.r)
if (err != nil) != tt.wantErr {
t.Errorf("ParseHeader() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(gotHd, tt.wantHd) {
t.Errorf("ParseHeader() gotHd = %+v, want %+v", gotHd, tt.wantHd)
}
})
}
}
func TestHeader_Bytes(t *testing.T) {
tests := []struct {
name string
header Header
wantB []byte
}{
{
"SOCKS5 IPV4",
Header{
VersionSocks5, CommandConnect, 0,
AddrSpec{IP: net.IPv4(127, 0, 0, 1), Port: 8080, AddrType: ATYPIPv4},
},
[]byte{VersionSocks5, CommandConnect, 0, ATYPIPv4, 127, 0, 0, 1, 0x1f, 0x90},
},
{
"SOCKS5 IPV6",
Header{
VersionSocks5, CommandConnect, 0,
AddrSpec{IP: net.IPv6zero, Port: 8080, AddrType: ATYPIPv6},
},
[]byte{VersionSocks5, CommandConnect, 0, ATYPIPv6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x1f, 0x90},
},
{
"SOCKS5 FQDN",
Header{
VersionSocks5, CommandConnect, 0,
AddrSpec{FQDN: "localhost", Port: 8080, AddrType: ATYPDomain},
},
[]byte{VersionSocks5, CommandConnect, 0, ATYPDomain, 9, 'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't', 0x1f, 0x90},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if gotB := tt.header.Bytes(); !reflect.DeepEqual(gotB, tt.wantB) {
t.Errorf("Bytes() = %v, want %v", gotB, tt.wantB)
}
})
}
}

205
statute/message.go Normal file

@ -0,0 +1,205 @@
package statute
import (
"fmt"
"io"
"net"
)
// Request represents the SOCKS5 request, it contains everything that is not payload
// The SOCKS5 request is formed as follows:
// +-----+-----+-------+------+----------+----------+
// | VER | CMD | RSV | ATYP | DST.ADDR | DST.PORT |
// +-----+-----+-------+------+----------+----------+
// | 1 | 1 | X'00' | 1 | Variable | 2 |
// +-----+-----+-------+------+----------+----------+
type Request struct {
// Version of socks protocol for message
Version uint8
// Socks Command "connect","bind","associate"
Command uint8
// Reserved byte
Reserved uint8
// DstAddress in socks message
DstAddress AddrSpec
}
// ParseRequest to request from io.Reader
func ParseRequest(r io.Reader) (req Request, err error) {
// Read the version and command
tmp := []byte{0, 0}
if _, err = io.ReadFull(r, tmp); err != nil {
return req, fmt.Errorf("failed to get request version and command, %v", err)
}
req.Version = tmp[0]
req.Command = tmp[1]
if req.Version != VersionSocks5 {
return req, fmt.Errorf("unrecognized SOCKS version[%d]", req.Version)
}
// Read reserved and address type
if _, err = io.ReadFull(r, tmp); err != nil {
return req, fmt.Errorf("failed to get request RSV and address type, %v", err)
}
req.Reserved = tmp[0]
req.DstAddress.AddrType = tmp[1]
switch req.DstAddress.AddrType {
case ATYPDomain:
if _, err = io.ReadFull(r, tmp[:1]); err != nil {
return req, fmt.Errorf("failed to get request, %v", err)
}
domainLen := int(tmp[0])
addr := make([]byte, domainLen+2)
if _, err = io.ReadFull(r, addr); err != nil {
return req, fmt.Errorf("failed to get request, %v", err)
}
req.DstAddress.FQDN = string(addr[:domainLen])
req.DstAddress.Port = BuildPort(addr[domainLen], addr[domainLen+1])
case ATYPIPv4:
addr := make([]byte, net.IPv4len+2)
if _, err = io.ReadFull(r, addr); err != nil {
return req, fmt.Errorf("failed to get request, %v", err)
}
req.DstAddress.IP = net.IPv4(addr[0], addr[1], addr[2], addr[3])
req.DstAddress.Port = BuildPort(addr[net.IPv4len], addr[net.IPv4len+1])
case ATYPIPv6:
addr := make([]byte, net.IPv6len+2)
if _, err = io.ReadFull(r, addr); err != nil {
return req, fmt.Errorf("failed to get request, %v", err)
}
req.DstAddress.IP = addr[:net.IPv6len]
req.DstAddress.Port = BuildPort(addr[net.IPv6len], addr[net.IPv6len+1])
default:
return req, ErrUnrecognizedAddrType
}
return req, nil
}
// Bytes returns a slice of request
func (h Request) Bytes() (b []byte) {
var addr []byte
length := 6
if h.DstAddress.AddrType == ATYPIPv4 {
length += net.IPv4len
addr = h.DstAddress.IP.To4()
} else if h.DstAddress.AddrType == ATYPIPv6 {
length += net.IPv6len
addr = h.DstAddress.IP.To16()
} else { //ATYPDomain
length += 1 + len(h.DstAddress.FQDN)
addr = []byte(h.DstAddress.FQDN)
}
b = make([]byte, 0, length)
b = append(b, h.Version)
b = append(b, h.Command)
b = append(b, h.Reserved)
b = append(b, h.DstAddress.AddrType)
if h.DstAddress.AddrType == ATYPDomain {
b = append(b, byte(len(h.DstAddress.FQDN)))
}
b = append(b, addr...)
hiPort, loPort := BreakPort(h.DstAddress.Port)
b = append(b, hiPort, loPort)
return b
}
// Reply represents the SOCKS5 reply, it contains everything that is not payload
// The SOCKS5 response is formed as follows:
// +-----+-----+-------+------+----------+-----------+
// | VER | REP | RSV | ATYP | BND.ADDR | BND].PORT |
// +-----+-----+-------+------+----------+----------+
// | 1 | 1 | X'00' | 1 | Variable | 2 |
// +-----+-----+-------+------+----------+----------+
type Reply struct {
// Version of socks protocol for message
Version uint8
// Socks Response status"
Response uint8
// Reserved byte
Reserved uint8
// Bind Address in socks message
BndAddress AddrSpec
}
// Bytes returns a slice of request
func (h Reply) Bytes() (b []byte) {
var addr []byte
length := 6
if h.BndAddress.AddrType == ATYPIPv4 {
length += net.IPv4len
addr = h.BndAddress.IP.To4()
} else if h.BndAddress.AddrType == ATYPIPv6 {
length += net.IPv6len
addr = h.BndAddress.IP.To16()
} else { //ATYPDomain
length += 1 + len(h.BndAddress.FQDN)
addr = []byte(h.BndAddress.FQDN)
}
b = make([]byte, 0, length)
b = append(b, h.Version)
b = append(b, h.Response)
b = append(b, h.Reserved)
b = append(b, h.BndAddress.AddrType)
if h.BndAddress.AddrType == ATYPDomain {
b = append(b, byte(len(h.BndAddress.FQDN)))
}
b = append(b, addr...)
hiPort, loPort := BreakPort(h.BndAddress.Port)
b = append(b, hiPort, loPort)
return b
}
// ParseRequest to request from io.Reader
func ParseReply(r io.Reader) (rep Reply, err error) {
// Read the version and command
tmp := []byte{0, 0}
if _, err = io.ReadFull(r, tmp); err != nil {
return rep, fmt.Errorf("failed to get request version and command, %v", err)
}
rep.Version = tmp[0]
rep.Response = tmp[1]
if rep.Version != VersionSocks5 {
return rep, fmt.Errorf("unrecognized SOCKS version[%d]", rep.Version)
}
// Read reserved and address type
if _, err = io.ReadFull(r, tmp); err != nil {
return rep, fmt.Errorf("failed to get request RSV and address type, %v", err)
}
rep.Reserved = tmp[0]
rep.BndAddress.AddrType = tmp[1]
switch rep.BndAddress.AddrType {
case ATYPDomain:
if _, err = io.ReadFull(r, tmp[:1]); err != nil {
return rep, fmt.Errorf("failed to get request, %v", err)
}
domainLen := int(tmp[0])
addr := make([]byte, domainLen+2)
if _, err = io.ReadFull(r, addr); err != nil {
return rep, fmt.Errorf("failed to get request, %v", err)
}
rep.BndAddress.FQDN = string(addr[:domainLen])
rep.BndAddress.Port = BuildPort(addr[domainLen], addr[domainLen+1])
case ATYPIPv4:
addr := make([]byte, net.IPv4len+2)
if _, err = io.ReadFull(r, addr); err != nil {
return rep, fmt.Errorf("failed to get request, %v", err)
}
rep.BndAddress.IP = net.IPv4(addr[0], addr[1], addr[2], addr[3])
rep.BndAddress.Port = BuildPort(addr[net.IPv4len], addr[net.IPv4len+1])
case ATYPIPv6:
addr := make([]byte, net.IPv6len+2)
if _, err = io.ReadFull(r, addr); err != nil {
return rep, fmt.Errorf("failed to get request, %v", err)
}
rep.BndAddress.IP = addr[:net.IPv6len]
rep.BndAddress.Port = BuildPort(addr[net.IPv6len], addr[net.IPv6len+1])
default:
return rep, ErrUnrecognizedAddrType
}
return rep, nil
}

195
statute/message_test.go Normal file

@ -0,0 +1,195 @@
package statute
import (
"bytes"
"io"
"net"
"reflect"
"testing"
)
func TestParseRequest(t *testing.T) {
type args struct {
r io.Reader
}
tests := []struct {
name string
args args
want Request
wantErr bool
}{
{
"SOCKS5 IPV4",
args{bytes.NewReader([]byte{VersionSocks5, CommandConnect, 0, ATYPIPv4, 127, 0, 0, 1, 0x1f, 0x90})},
Request{
VersionSocks5, CommandConnect, 0,
AddrSpec{IP: net.IPv4(127, 0, 0, 1), Port: 8080, AddrType: ATYPIPv4},
},
false,
},
{
"SOCKS5 IPV6",
args{bytes.NewReader([]byte{VersionSocks5, CommandConnect, 0, ATYPIPv6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x1f, 0x90})},
Request{
VersionSocks5, CommandConnect, 0,
AddrSpec{IP: net.IPv6zero, Port: 8080, AddrType: ATYPIPv6},
},
false,
},
{
"SOCKS5 FQDN",
args{bytes.NewReader([]byte{VersionSocks5, CommandConnect, 0, ATYPDomain, 9, 'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't', 0x1f, 0x90})},
Request{
VersionSocks5, CommandConnect, 0,
AddrSpec{FQDN: "localhost", Port: 8080, AddrType: ATYPDomain},
},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotHd, err := ParseRequest(tt.args.r)
if (err != nil) != tt.wantErr {
t.Errorf("ParseRequest() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(gotHd, tt.want) {
t.Errorf("ParseRequest() gotHd = %+v, want %+v", gotHd, tt.want)
}
})
}
}
func TestRequest_Bytes(t *testing.T) {
tests := []struct {
name string
request Request
wantB []byte
}{
{
"SOCKS5 IPV4",
Request{
VersionSocks5, CommandConnect, 0,
AddrSpec{IP: net.IPv4(127, 0, 0, 1), Port: 8080, AddrType: ATYPIPv4},
},
[]byte{VersionSocks5, CommandConnect, 0, ATYPIPv4, 127, 0, 0, 1, 0x1f, 0x90},
},
{
"SOCKS5 IPV6",
Request{
VersionSocks5, CommandConnect, 0,
AddrSpec{IP: net.IPv6zero, Port: 8080, AddrType: ATYPIPv6},
},
[]byte{VersionSocks5, CommandConnect, 0, ATYPIPv6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x1f, 0x90},
},
{
"SOCKS5 FQDN",
Request{
VersionSocks5, CommandConnect, 0,
AddrSpec{FQDN: "localhost", Port: 8080, AddrType: ATYPDomain},
},
[]byte{VersionSocks5, CommandConnect, 0, ATYPDomain, 9, 'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't', 0x1f, 0x90},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if gotB := tt.request.Bytes(); !reflect.DeepEqual(gotB, tt.wantB) {
t.Errorf("Bytes() = %v, want %v", gotB, tt.wantB)
}
})
}
}
func TestParseReply(t *testing.T) {
type args struct {
r io.Reader
}
tests := []struct {
name string
args args
want Reply
wantErr bool
}{
{
"SOCKS5 IPV4",
args{bytes.NewReader([]byte{VersionSocks5, RepSuccess, 0, ATYPIPv4, 127, 0, 0, 1, 0x1f, 0x90})},
Reply{
VersionSocks5, RepSuccess, 0,
AddrSpec{IP: net.IPv4(127, 0, 0, 1), Port: 8080, AddrType: ATYPIPv4},
},
false,
},
{
"SOCKS5 IPV6",
args{bytes.NewReader([]byte{VersionSocks5, RepSuccess, 0, ATYPIPv6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x1f, 0x90})},
Reply{
VersionSocks5, RepSuccess, 0,
AddrSpec{IP: net.IPv6zero, Port: 8080, AddrType: ATYPIPv6},
},
false,
},
{
"SOCKS5 FQDN",
args{bytes.NewReader([]byte{VersionSocks5, RepSuccess, 0, ATYPDomain, 9, 'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't', 0x1f, 0x90})},
Reply{
VersionSocks5, RepSuccess, 0,
AddrSpec{FQDN: "localhost", Port: 8080, AddrType: ATYPDomain},
},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ParseReply(tt.args.r)
if (err != nil) != tt.wantErr {
t.Errorf("ParseReply() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("ParseReply() got = %+v, want %+v", got, tt.want)
}
})
}
}
func TestReply_Bytes(t *testing.T) {
tests := []struct {
name string
reply Reply
wantB []byte
}{
{
"SOCKS5 IPV4",
Reply{
VersionSocks5, CommandConnect, 0,
AddrSpec{IP: net.IPv4(127, 0, 0, 1), Port: 8080, AddrType: ATYPIPv4},
},
[]byte{VersionSocks5, CommandConnect, 0, ATYPIPv4, 127, 0, 0, 1, 0x1f, 0x90},
},
{
"SOCKS5 IPV6",
Reply{
VersionSocks5, CommandConnect, 0,
AddrSpec{IP: net.IPv6zero, Port: 8080, AddrType: ATYPIPv6},
},
[]byte{VersionSocks5, CommandConnect, 0, ATYPIPv6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x1f, 0x90},
},
{
"SOCKS5 FQDN",
Reply{
VersionSocks5, CommandConnect, 0,
AddrSpec{FQDN: "localhost", Port: 8080, AddrType: ATYPDomain},
},
[]byte{VersionSocks5, CommandConnect, 0, ATYPDomain, 9, 'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't', 0x1f, 0x90},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if gotB := tt.reply.Bytes(); !reflect.DeepEqual(gotB, tt.wantB) {
t.Errorf("Bytes() = %v, want %v", gotB, tt.wantB)
}
})
}
}

@ -12,11 +12,11 @@ type AddrSpec struct {
FQDN string
IP net.IP
Port int
// private stuff set when Header parsed
// private stuff set when Request parsed
AddrType uint8
}
// Address returns a string suitable to dial; prefer returning IP-based
// DstAddress returns a string suitable to dial; prefer returning IP-based
// address, fallback to FQDN
func (a *AddrSpec) String() string {
if 0 != len(a.IP) {
@ -25,7 +25,7 @@ func (a *AddrSpec) String() string {
return net.JoinHostPort(a.FQDN, strconv.Itoa(a.Port))
}
// Address returns a string which may be specified
// DstAddress returns a string which may be specified
// if IPv4/IPv6 will return < ip:port >
// if FQDN will return < domain ip:port >
// Note: do not used to dial, Please use String