add datagram test
rename request and response api name
This commit is contained in:
parent
916af03167
commit
35e543fcfb
@ -88,8 +88,7 @@ func (c *Client) Read(b []byte) (int, error) {
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
pkt := statute.Packet{}
|
||||
err = pkt.Parse(b1[:n])
|
||||
pkt, err := statute.ParseDatagram(b1[:n])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -101,7 +100,7 @@ func (c *Client) Write(b []byte) (int, error) {
|
||||
if c.UDPConn == nil {
|
||||
return c.TCPConn.Write(b)
|
||||
}
|
||||
pkt, err := statute.NewPacket(c.RemoteAddress.String(), b)
|
||||
pkt, err := statute.NewDatagram(c.RemoteAddress.String(), b)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
108
request.go
108
request.go
@ -33,15 +33,12 @@ type Request struct {
|
||||
RawDestAddr *statute.AddrSpec
|
||||
}
|
||||
|
||||
// NewRequest creates a new Request from the tcp connection
|
||||
func NewRequest(bufConn io.Reader) (*Request, error) {
|
||||
// ParseRequest creates a new Request from the tcp connection
|
||||
func ParseRequest(bufConn io.Reader) (*Request, error) {
|
||||
hd, err := statute.ParseRequest(bufConn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if hd.Command != statute.CommandConnect && hd.Command != statute.CommandBind && hd.Command != statute.CommandAssociate {
|
||||
return nil, fmt.Errorf("unrecognized command[%d]", hd.Command)
|
||||
}
|
||||
return &Request{
|
||||
Request: hd,
|
||||
RawDestAddr: &hd.DstAddress,
|
||||
@ -51,20 +48,19 @@ func NewRequest(bufConn io.Reader) (*Request, error) {
|
||||
|
||||
// handleRequest is used for request processing after authentication
|
||||
func (s *Server) handleRequest(write io.Writer, req *Request) error {
|
||||
var err error
|
||||
ctx := context.Background()
|
||||
|
||||
// Resolve the address if we have a FQDN
|
||||
dest := req.RawDestAddr
|
||||
if dest.FQDN != "" {
|
||||
_ctx, addr, err := s.resolver.Resolve(ctx, dest.FQDN)
|
||||
ctx, dest.IP, err = s.resolver.Resolve(ctx, dest.FQDN)
|
||||
if err != nil {
|
||||
if err := SendReply(write, req.Request, statute.RepHostUnreachable); err != nil {
|
||||
if err := SendReply(write, statute.RepHostUnreachable, nil); err != nil {
|
||||
return fmt.Errorf("failed to send reply, %v", err)
|
||||
}
|
||||
return fmt.Errorf("failed to resolve destination[%v], %v", dest.FQDN, err)
|
||||
}
|
||||
ctx = _ctx
|
||||
dest.IP = addr
|
||||
}
|
||||
|
||||
// Apply any address rewrites
|
||||
@ -74,14 +70,14 @@ func (s *Server) handleRequest(write io.Writer, req *Request) error {
|
||||
}
|
||||
|
||||
// Check if this is allowed
|
||||
_ctx, ok := s.rules.Allow(ctx, req)
|
||||
var ok bool
|
||||
ctx, ok = s.rules.Allow(ctx, req)
|
||||
if !ok {
|
||||
if err := SendReply(write, req.Request, statute.RepRuleFailure); err != nil {
|
||||
if err := SendReply(write, statute.RepRuleFailure, nil); err != nil {
|
||||
return fmt.Errorf("failed to send reply, %v", err)
|
||||
}
|
||||
return fmt.Errorf("bind to %v blocked by rules", req.RawDestAddr)
|
||||
}
|
||||
ctx = _ctx
|
||||
|
||||
// Switch on the command
|
||||
switch req.Command {
|
||||
@ -101,7 +97,7 @@ func (s *Server) handleRequest(write io.Writer, req *Request) error {
|
||||
}
|
||||
return s.handleAssociate(ctx, write, req)
|
||||
default:
|
||||
if err := SendReply(write, req.Request, statute.RepCommandNotSupported); err != nil {
|
||||
if err := SendReply(write, statute.RepCommandNotSupported, nil); err != nil {
|
||||
return fmt.Errorf("failed to send reply, %v", err)
|
||||
}
|
||||
return fmt.Errorf("unsupported command[%v]", req.Command)
|
||||
@ -126,7 +122,7 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, request *R
|
||||
} else if strings.Contains(msg, "network is unreachable") {
|
||||
resp = statute.RepNetworkUnreachable
|
||||
}
|
||||
if err := SendReply(writer, request.Request, resp); err != nil {
|
||||
if err := SendReply(writer, resp, nil); err != nil {
|
||||
return fmt.Errorf("failed to send reply, %v", err)
|
||||
}
|
||||
return fmt.Errorf("connect to %v failed, %v", request.RawDestAddr, err)
|
||||
@ -134,7 +130,7 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, request *R
|
||||
defer target.Close()
|
||||
|
||||
// Send success
|
||||
if err := SendReply(writer, request.Request, statute.RepSuccess, target.LocalAddr()); err != nil {
|
||||
if err := SendReply(writer, statute.RepSuccess, target.LocalAddr()); err != nil {
|
||||
return fmt.Errorf("failed to send reply, %v", err)
|
||||
}
|
||||
|
||||
@ -156,7 +152,7 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, request *R
|
||||
// handleBind is used to handle a connect command
|
||||
func (s *Server) handleBind(_ context.Context, writer io.Writer, request *Request) error {
|
||||
// TODO: Support bind
|
||||
if err := SendReply(writer, request.Request, statute.RepCommandNotSupported); err != nil {
|
||||
if err := SendReply(writer, statute.RepCommandNotSupported, nil); err != nil {
|
||||
return fmt.Errorf("failed to send reply: %v", err)
|
||||
}
|
||||
return nil
|
||||
@ -180,7 +176,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, request
|
||||
} else if strings.Contains(msg, "network is unreachable") {
|
||||
resp = statute.RepNetworkUnreachable
|
||||
}
|
||||
if err := SendReply(writer, request.Request, resp); err != nil {
|
||||
if err := SendReply(writer, resp, nil); err != nil {
|
||||
return fmt.Errorf("failed to send reply, %v", err)
|
||||
}
|
||||
return fmt.Errorf("connect to %v failed, %v", request.RawDestAddr, err)
|
||||
@ -189,7 +185,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, request
|
||||
|
||||
targetUDP, ok := target.(*net.UDPConn)
|
||||
if !ok {
|
||||
if err := SendReply(writer, request.Request, statute.RepServerFailure); err != nil {
|
||||
if err := SendReply(writer, statute.RepServerFailure, nil); err != nil {
|
||||
return fmt.Errorf("failed to send reply, %v", err)
|
||||
}
|
||||
return fmt.Errorf("dial udp invalid")
|
||||
@ -197,7 +193,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, request
|
||||
|
||||
bindLn, err := net.ListenUDP("udp", nil)
|
||||
if err != nil {
|
||||
if err := SendReply(writer, request.Request, statute.RepServerFailure); err != nil {
|
||||
if err := SendReply(writer, statute.RepServerFailure, nil); err != nil {
|
||||
return fmt.Errorf("failed to send reply, %v", err)
|
||||
}
|
||||
return fmt.Errorf("listen udp failed, %v", err)
|
||||
@ -206,19 +202,11 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, request
|
||||
|
||||
s.logger.Errorf("target addr %v, listen addr: %s", targetUDP.RemoteAddr(), bindLn.LocalAddr())
|
||||
// send BND.ADDR and BND.PORT, client must
|
||||
if err = SendReply(writer, request.Request, statute.RepSuccess, bindLn.LocalAddr()); err != nil {
|
||||
if err = SendReply(writer, statute.RepSuccess, bindLn.LocalAddr()); err != nil {
|
||||
return fmt.Errorf("failed to send reply, %v", err)
|
||||
}
|
||||
|
||||
s.submit(func() {
|
||||
/*
|
||||
The SOCKS UDP request/response is formed as follows:
|
||||
+-----+------+-------+----------+----------+----------+
|
||||
| RSV | FRAG | ATYP | DST.ADDR | DST.PORT | DATA |
|
||||
+-----+------+-------+----------+----------+----------+
|
||||
| 2 | 1 | X'00' | Variable | 2 | Variable |
|
||||
+-----+------+-------+----------+----------+----------+
|
||||
*/
|
||||
// read from client and write to remote server
|
||||
conns := sync.Map{}
|
||||
bufPool := s.bufferPool.Get()
|
||||
@ -237,8 +225,8 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, request
|
||||
continue
|
||||
}
|
||||
|
||||
pk := statute.Packet{}
|
||||
if err := pk.Parse(bufPool[:n]); err != nil {
|
||||
pk, err := statute.ParseDatagram(bufPool[:n])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
@ -260,7 +248,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, request
|
||||
return
|
||||
}
|
||||
|
||||
pkb, err := statute.NewPacket(remote.String(), buf[:n])
|
||||
pkb, err := statute.NewDatagram(remote.String(), buf[:n])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
@ -301,45 +289,35 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, request
|
||||
}
|
||||
|
||||
// SendReply is used to send a reply message
|
||||
func SendReply(w io.Writer, head statute.Request, resp uint8, bindAddr ...net.Addr) error {
|
||||
head.Command = resp
|
||||
func SendReply(w io.Writer, resp uint8, bindAddr net.Addr) error {
|
||||
rsp := statute.Reply{
|
||||
Version: statute.VersionSocks5,
|
||||
Response: resp,
|
||||
BndAddress: statute.AddrSpec{
|
||||
AddrType: statute.ATYPIPv4,
|
||||
IP: net.IPv4zero,
|
||||
Port: 0,
|
||||
},
|
||||
}
|
||||
|
||||
if len(bindAddr) == 0 {
|
||||
head.DstAddress.AddrType = statute.ATYPIPv4
|
||||
head.DstAddress.IP = []byte{0, 0, 0, 0}
|
||||
head.DstAddress.Port = 0
|
||||
} else {
|
||||
addrSpec := statute.AddrSpec{}
|
||||
if tcpAddr, ok := bindAddr[0].(*net.TCPAddr); ok && tcpAddr != nil {
|
||||
addrSpec.IP = tcpAddr.IP
|
||||
addrSpec.Port = tcpAddr.Port
|
||||
} else if udpAddr, ok := bindAddr[0].(*net.UDPAddr); ok && udpAddr != nil {
|
||||
addrSpec.IP = udpAddr.IP
|
||||
addrSpec.Port = udpAddr.Port
|
||||
if rsp.Response == statute.RepSuccess {
|
||||
if tcpAddr, ok := bindAddr.(*net.TCPAddr); ok && tcpAddr != nil {
|
||||
rsp.BndAddress.IP = tcpAddr.IP
|
||||
rsp.BndAddress.Port = tcpAddr.Port
|
||||
} else if udpAddr, ok := bindAddr.(*net.UDPAddr); ok && udpAddr != nil {
|
||||
rsp.BndAddress.IP = udpAddr.IP
|
||||
rsp.BndAddress.Port = udpAddr.Port
|
||||
} else {
|
||||
addrSpec.IP = []byte{0, 0, 0, 0}
|
||||
addrSpec.Port = 0
|
||||
rsp.Response = statute.RepAddrTypeNotSupported
|
||||
}
|
||||
switch {
|
||||
case addrSpec.FQDN != "":
|
||||
head.DstAddress.AddrType = statute.ATYPDomain
|
||||
head.DstAddress.FQDN = addrSpec.FQDN
|
||||
head.DstAddress.Port = addrSpec.Port
|
||||
case addrSpec.IP.To4() != nil:
|
||||
head.DstAddress.AddrType = statute.ATYPIPv4
|
||||
head.DstAddress.IP = addrSpec.IP.To4()
|
||||
head.DstAddress.Port = addrSpec.Port
|
||||
case addrSpec.IP.To16() != nil:
|
||||
head.DstAddress.AddrType = statute.ATYPIPv6
|
||||
head.DstAddress.IP = addrSpec.IP.To16()
|
||||
head.DstAddress.Port = addrSpec.Port
|
||||
default:
|
||||
return fmt.Errorf("failed to format address[%v]", bindAddr)
|
||||
if rsp.BndAddress.IP.To4() != nil {
|
||||
rsp.BndAddress.AddrType = statute.ATYPIPv4
|
||||
} else if rsp.BndAddress.IP.To16() != nil {
|
||||
rsp.BndAddress.AddrType = statute.ATYPIPv6
|
||||
}
|
||||
|
||||
}
|
||||
// Send the message
|
||||
_, err := w.Write(head.Bytes())
|
||||
_, err := w.Write(rsp.Bytes())
|
||||
return err
|
||||
}
|
||||
|
||||
@ -354,7 +332,7 @@ func (s *Server) Proxy(dst io.Writer, src io.Reader) error {
|
||||
defer s.bufferPool.Put(buf)
|
||||
_, err := io.CopyBuffer(dst, src, buf[:cap(buf)])
|
||||
if tcpConn, ok := dst.(closeWriter); ok {
|
||||
tcpConn.CloseWrite()
|
||||
tcpConn.CloseWrite() // nolint: errcheck
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
@ -2,7 +2,6 @@ package socks5
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
@ -10,6 +9,8 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/thinkgos/go-socks5/statute"
|
||||
)
|
||||
|
||||
type MockConn struct {
|
||||
@ -32,7 +33,7 @@ func TestRequest_Connect(t *testing.T) {
|
||||
go func() {
|
||||
conn, err := l.Accept()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
defer conn.Close() // nolint: errcheck
|
||||
|
||||
buf := make([]byte, 4)
|
||||
_, err = io.ReadAtLeast(conn, buf, 4)
|
||||
@ -43,7 +44,7 @@ func TestRequest_Connect(t *testing.T) {
|
||||
}()
|
||||
lAddr := l.Addr().(*net.TCPAddr)
|
||||
|
||||
// Make server
|
||||
// Make proxy server
|
||||
proxySrv := &Server{
|
||||
rules: NewPermitAll(),
|
||||
resolver: DNSResolver{},
|
||||
@ -52,33 +53,27 @@ func TestRequest_Connect(t *testing.T) {
|
||||
}
|
||||
|
||||
// Create the connect request
|
||||
buf := bytes.NewBuffer(nil)
|
||||
buf.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1}) // nolint: errcheck
|
||||
|
||||
port := []byte{0, 0}
|
||||
binary.BigEndian.PutUint16(port, uint16(lAddr.Port))
|
||||
buf.Write(port) // nolint: errcheck
|
||||
|
||||
hi, lo := statute.BreakPort(lAddr.Port)
|
||||
buf := bytes.NewBuffer([]byte{
|
||||
statute.VersionSocks5, statute.CommandConnect, 0,
|
||||
statute.ATYPIPv4, 127, 0, 0, 1, hi, lo,
|
||||
})
|
||||
// Send a ping
|
||||
buf.Write([]byte("ping")) // nolint: errcheck
|
||||
|
||||
// Handle the request
|
||||
resp := &MockConn{}
|
||||
req, err := NewRequest(buf)
|
||||
rsp := new(MockConn)
|
||||
req, err := ParseRequest(buf)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = proxySrv.handleRequest(resp, req)
|
||||
err = proxySrv.handleRequest(rsp, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify response
|
||||
out := resp.buf.Bytes()
|
||||
out := rsp.buf.Bytes()
|
||||
expected := []byte{
|
||||
5,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
127, 0, 0, 1,
|
||||
0, 0,
|
||||
statute.VersionSocks5, statute.RepSuccess, 0,
|
||||
statute.ATYPIPv4, 127, 0, 0, 1, 0, 0,
|
||||
'p', 'o', 'n', 'g',
|
||||
}
|
||||
|
||||
@ -102,6 +97,7 @@ func TestRequest_Connect_RuleFail(t *testing.T) {
|
||||
_, err = io.ReadAtLeast(conn, buf, 4)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("ping"), buf)
|
||||
|
||||
conn.Write([]byte("pong")) // nolint: errcheck
|
||||
}()
|
||||
lAddr := l.Addr().(*net.TCPAddr)
|
||||
@ -115,33 +111,28 @@ func TestRequest_Connect_RuleFail(t *testing.T) {
|
||||
}
|
||||
|
||||
// Create the connect request
|
||||
buf := bytes.NewBuffer(nil)
|
||||
buf.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1})
|
||||
|
||||
port := []byte{0, 0}
|
||||
binary.BigEndian.PutUint16(port, uint16(lAddr.Port))
|
||||
buf.Write(port)
|
||||
hi, lo := statute.BreakPort(lAddr.Port)
|
||||
buf := bytes.NewBuffer([]byte{
|
||||
statute.VersionSocks5, statute.CommandConnect, 0,
|
||||
statute.ATYPIPv4, 127, 0, 0, 1, hi, lo,
|
||||
})
|
||||
|
||||
// Send a ping
|
||||
buf.Write([]byte("ping"))
|
||||
|
||||
// Handle the request
|
||||
resp := &MockConn{}
|
||||
req, err := NewRequest(buf)
|
||||
rsp := new(MockConn)
|
||||
req, err := ParseRequest(buf)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = s.handleRequest(resp, req)
|
||||
err = s.handleRequest(rsp, req)
|
||||
require.Contains(t, err.Error(), "blocked by rules")
|
||||
|
||||
// Verify response
|
||||
out := resp.buf.Bytes()
|
||||
out := rsp.buf.Bytes()
|
||||
expected := []byte{
|
||||
5,
|
||||
2,
|
||||
0,
|
||||
1,
|
||||
0, 0, 0, 0,
|
||||
0, 0,
|
||||
statute.VersionSocks5, statute.RepRuleFailure, 0,
|
||||
statute.ATYPIPv4, 0, 0, 0, 0, 0, 0,
|
||||
}
|
||||
require.Equal(t, expected, out)
|
||||
}
|
||||
|
22
server.go
22
server.go
@ -3,6 +3,7 @@ package socks5
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
@ -116,6 +117,7 @@ func (s *Server) ServeConn(conn net.Conn) error {
|
||||
var authContext *AuthContext
|
||||
|
||||
defer conn.Close()
|
||||
|
||||
bufConn := bufio.NewReader(conn)
|
||||
|
||||
mr, err := statute.ParseMethodRequest(bufConn)
|
||||
@ -133,24 +135,30 @@ func (s *Server) ServeConn(conn net.Conn) error {
|
||||
}
|
||||
|
||||
// The client request detail
|
||||
request, err := NewRequest(bufConn)
|
||||
request, err := ParseRequest(bufConn)
|
||||
if err != nil {
|
||||
if err == statute.ErrUnrecognizedAddrType {
|
||||
if err := SendReply(conn, statute.Request{Version: mr.Ver}, statute.RepAddrTypeNotSupported); err != nil {
|
||||
if errors.Is(err, statute.ErrUnrecognizedAddrType) {
|
||||
if err := SendReply(conn, statute.RepAddrTypeNotSupported, nil); err != nil {
|
||||
return fmt.Errorf("failed to send reply %w", err)
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("failed to read destination address, %w", err)
|
||||
}
|
||||
|
||||
if request.Request.Command != statute.CommandConnect &&
|
||||
request.Request.Command != statute.CommandBind &&
|
||||
request.Request.Command != statute.CommandAssociate {
|
||||
if err := SendReply(conn, statute.RepCommandNotSupported, nil); err != nil {
|
||||
return fmt.Errorf("failed to send reply, %v", err)
|
||||
}
|
||||
return fmt.Errorf("unrecognized command[%d]", request.Request.Command)
|
||||
}
|
||||
|
||||
request.AuthContext = authContext
|
||||
request.LocalAddr = conn.LocalAddr()
|
||||
request.RemoteAddr = conn.RemoteAddr()
|
||||
// Process the client request
|
||||
if err := s.handleRequest(conn, request); err != nil {
|
||||
return fmt.Errorf("failed to handle request, %v", err)
|
||||
}
|
||||
return nil
|
||||
return s.handleRequest(conn, request)
|
||||
}
|
||||
|
||||
// authenticate is used to handle connection authentication
|
||||
|
@ -31,11 +31,12 @@ func TestSOCKS5_Connect(t *testing.T) {
|
||||
_, err = io.ReadAtLeast(conn, buf, 4)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("ping"), buf)
|
||||
_, _ = conn.Write([]byte("pong"))
|
||||
|
||||
conn.Write([]byte("pong")) // nolint: errcheck
|
||||
}()
|
||||
lAddr := l.Addr().(*net.TCPAddr)
|
||||
|
||||
// Create a socks server
|
||||
// Create a socks server with UserPass auth.
|
||||
cator := UserPassAuthenticator{StaticCredentials{"foo": "bar"}}
|
||||
srv := NewServer(
|
||||
WithAuthMethods([]Authenticator{cator}),
|
||||
@ -54,18 +55,20 @@ func TestSOCKS5_Connect(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Connect, auth and connec to local
|
||||
req := new(bytes.Buffer)
|
||||
req.Write([]byte{statute.VersionSocks5, 2, statute.MethodNoAuth, statute.MethodUserPassAuth})
|
||||
req.Write([]byte{statute.UserPassAuthVersion, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'})
|
||||
req := bytes.NewBuffer(
|
||||
[]byte{
|
||||
statute.VersionSocks5, 2, statute.MethodNoAuth, statute.MethodUserPassAuth, // methods
|
||||
statute.UserPassAuthVersion, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r', // userpass auth
|
||||
})
|
||||
reqHead := statute.Request{
|
||||
Version: statute.VersionSocks5,
|
||||
Command: statute.CommandConnect,
|
||||
Reserved: 0,
|
||||
DstAddress: statute.AddrSpec{
|
||||
"",
|
||||
net.ParseIP("127.0.0.1"),
|
||||
lAddr.Port,
|
||||
statute.ATYPIPv4,
|
||||
FQDN: "",
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: lAddr.Port,
|
||||
AddrType: statute.ATYPIPv4,
|
||||
},
|
||||
}
|
||||
req.Write(reqHead.Bytes())
|
||||
@ -73,11 +76,11 @@ func TestSOCKS5_Connect(t *testing.T) {
|
||||
req.Write([]byte("ping"))
|
||||
|
||||
// Send all the bytes
|
||||
conn.Write(req.Bytes())
|
||||
conn.Write(req.Bytes()) // nolint: errcheck
|
||||
|
||||
// Verify response
|
||||
expected := []byte{
|
||||
statute.VersionSocks5, statute.MethodUserPassAuth, // use user password auth
|
||||
statute.VersionSocks5, statute.MethodUserPassAuth, // response use UserPass auth
|
||||
statute.UserPassAuthVersion, statute.AuthSuccess, // response auth success
|
||||
}
|
||||
rspHead := statute.Request{
|
||||
@ -85,22 +88,19 @@ func TestSOCKS5_Connect(t *testing.T) {
|
||||
Command: statute.RepSuccess,
|
||||
Reserved: 0,
|
||||
DstAddress: statute.AddrSpec{
|
||||
"",
|
||||
net.ParseIP("127.0.0.1"),
|
||||
0, // Ignore the port
|
||||
statute.ATYPIPv4,
|
||||
FQDN: "",
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 0,
|
||||
AddrType: statute.ATYPIPv4,
|
||||
},
|
||||
}
|
||||
expected = append(expected, rspHead.Bytes()...)
|
||||
expected = append(expected, []byte("pong")...)
|
||||
|
||||
out := make([]byte, len(expected))
|
||||
_ = conn.SetDeadline(time.Now().Add(time.Second))
|
||||
conn.SetDeadline(time.Now().Add(time.Second)) // nolint: errcheck
|
||||
_, err = io.ReadFull(conn, out)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Logf("proxy bind port: %d", statute.BuildPort(out[12], out[13]))
|
||||
|
||||
// Ignore the port
|
||||
out[12] = 0
|
||||
out[13] = 0
|
||||
@ -157,10 +157,10 @@ func TestSOCKS5_Associate(t *testing.T) {
|
||||
Command: statute.CommandAssociate,
|
||||
Reserved: 0,
|
||||
DstAddress: statute.AddrSpec{
|
||||
"",
|
||||
locIP,
|
||||
lAddr.Port,
|
||||
statute.ATYPIPv4,
|
||||
FQDN: "",
|
||||
IP: locIP,
|
||||
Port: lAddr.Port,
|
||||
AddrType: statute.ATYPIPv4,
|
||||
},
|
||||
}
|
||||
req.Write(reqHead.Bytes())
|
||||
@ -216,11 +216,11 @@ func Test_SocksWithProxy(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("ping"), buf)
|
||||
|
||||
conn.Write([]byte("pong"))
|
||||
conn.Write([]byte("pong")) // nolint: errcheck
|
||||
}()
|
||||
lAddr := l.Addr().(*net.TCPAddr)
|
||||
|
||||
// Create a socks server
|
||||
// Create a socks server with UserPass auth.
|
||||
cator := UserPassAuthenticator{StaticCredentials{"foo": "bar"}}
|
||||
serv := NewServer(
|
||||
WithAuthMethods([]Authenticator{cator}),
|
||||
@ -245,14 +245,13 @@ func Test_SocksWithProxy(t *testing.T) {
|
||||
conn.Write([]byte("ping")) // nolint: errcheck
|
||||
|
||||
out := make([]byte, 4)
|
||||
_ = conn.SetDeadline(time.Now().Add(time.Second)) // nolint: errcheck
|
||||
conn.SetDeadline(time.Now().Add(time.Second)) // nolint: errcheck
|
||||
_, err = io.ReadFull(conn, out)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, []byte("pong"), out)
|
||||
}
|
||||
|
||||
/***************************** auth *******************************/
|
||||
/***************************** auth *******************************/
|
||||
|
||||
func TestNoAuth_Server(t *testing.T) {
|
||||
req := bytes.NewBuffer(nil)
|
||||
|
@ -4,119 +4,96 @@ import (
|
||||
"errors"
|
||||
"math"
|
||||
"net"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
/*
|
||||
The SOCKS UDP request/response is formed as follows:
|
||||
+-----+------+-------+----------+----------+----------+
|
||||
| RSV | FRAG | ATYP | DST.ADDR | DST.PORT | DATA |
|
||||
+-----+------+-------+----------+----------+----------+
|
||||
| 2 | 1 | X'00' | Variable | 2 | Variable |
|
||||
+-----+------+-------+----------+----------+----------+
|
||||
*/
|
||||
// Packet udp packet
|
||||
type Packet struct {
|
||||
// The SOCKS UDP request/response is formed as follows:
|
||||
// +-----+------+-------+----------+----------+----------+
|
||||
// | RSV | FRAG | ATYP | DST.ADDR | DST.PORT | DATA |
|
||||
// +-----+------+-------+----------+----------+----------+
|
||||
// | 2 | 1 | X'00' | Variable | 2 | Variable |
|
||||
// +-----+------+-------+----------+----------+----------+
|
||||
// Datagram udp packet
|
||||
type Datagram struct {
|
||||
RSV uint16
|
||||
Frag uint8
|
||||
ATYP uint8
|
||||
DstAddr AddrSpec
|
||||
Data []byte
|
||||
}
|
||||
|
||||
// NewEmptyPacket new empty packet
|
||||
func NewEmptyPacket() Packet {
|
||||
return Packet{}
|
||||
}
|
||||
|
||||
// NewPacket new packet with dest addr and data
|
||||
func NewPacket(destAddr string, data []byte) (p Packet, err error) {
|
||||
var host, port string
|
||||
|
||||
host, port, err = net.SplitHostPort(destAddr)
|
||||
// NewDatagram new packet with dest addr and data
|
||||
func NewDatagram(destAddr string, data []byte) (p Datagram, err error) {
|
||||
p.DstAddr, err = ParseAddrSpec(destAddr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
p.DstAddr.Port, err = strconv.Atoi(port)
|
||||
if err != nil {
|
||||
if p.DstAddr.AddrType == ATYPDomain && len(p.DstAddr.FQDN) > math.MaxUint8 {
|
||||
err = errors.New("destination host name too long")
|
||||
return
|
||||
}
|
||||
p.RSV = 0
|
||||
p.Frag = 0
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
p.ATYP = ATYPIPv4
|
||||
p.DstAddr.IP = ip
|
||||
} else {
|
||||
p.ATYP = ATYPIPv6
|
||||
p.DstAddr.IP = ip
|
||||
}
|
||||
} else {
|
||||
if len(host) > math.MaxUint8 {
|
||||
err = errors.New("destination host name too long")
|
||||
return
|
||||
}
|
||||
p.ATYP = ATYPDomain
|
||||
p.DstAddr.FQDN = host
|
||||
}
|
||||
p.Data = data
|
||||
return
|
||||
}
|
||||
|
||||
// ParseRequest parse to packet
|
||||
func (sf *Packet) Parse(b []byte) error {
|
||||
if len(b) <= 4+net.IPv4len+2 { // no data
|
||||
return errors.New("too short")
|
||||
func ParseDatagram(b []byte) (da Datagram, err error) {
|
||||
if len(b) < 4+net.IPv4len+2 { // no data
|
||||
err = errors.New("datagram to short")
|
||||
return
|
||||
}
|
||||
// ignore RSV
|
||||
sf.RSV = 0
|
||||
da.RSV = 0
|
||||
// FRAG
|
||||
sf.Frag = b[2]
|
||||
sf.ATYP = b[3]
|
||||
da.Frag = b[2]
|
||||
da.DstAddr.AddrType = b[3]
|
||||
headLen := 4
|
||||
switch sf.ATYP {
|
||||
switch da.DstAddr.AddrType {
|
||||
case ATYPIPv4:
|
||||
headLen += net.IPv4len + 2
|
||||
sf.DstAddr.IP = net.IPv4(b[4], b[5], b[6], b[7])
|
||||
sf.DstAddr.Port = BuildPort(b[4+net.IPv4len], b[4+net.IPv4len+1])
|
||||
da.DstAddr.IP = net.IPv4(b[4], b[5], b[6], b[7])
|
||||
da.DstAddr.Port = BuildPort(b[4+net.IPv4len], b[4+net.IPv4len+1])
|
||||
case ATYPIPv6:
|
||||
headLen += net.IPv6len + 2
|
||||
if len(b) <= headLen {
|
||||
return errors.New("too short")
|
||||
err = errors.New("datagram to short")
|
||||
return
|
||||
}
|
||||
|
||||
sf.DstAddr.IP = net.IP{b[4], b[5], b[6], b[7], b[8], b[9], b[10], b[11], b[12], b[13], b[14], b[15], b[16], b[17], b[18], b[19]}
|
||||
sf.DstAddr.Port = BuildPort(b[4+net.IPv6len], b[4+net.IPv6len+1])
|
||||
da.DstAddr.IP = net.IP{b[4], b[5], b[6], b[7], b[8], b[9], b[10], b[11], b[12], b[13], b[14], b[15], b[16], b[17], b[18], b[19]}
|
||||
da.DstAddr.Port = BuildPort(b[4+net.IPv6len], b[4+net.IPv6len+1])
|
||||
case ATYPDomain:
|
||||
addrLen := int(b[4])
|
||||
headLen += 1 + addrLen + 2
|
||||
if len(b) <= headLen {
|
||||
return errors.New("too short")
|
||||
err = errors.New("datagram to short")
|
||||
return
|
||||
}
|
||||
str := make([]byte, addrLen)
|
||||
copy(str, b[5:5+addrLen])
|
||||
sf.DstAddr.FQDN = string(str)
|
||||
sf.DstAddr.Port = BuildPort(b[5+addrLen], b[5+addrLen+1])
|
||||
da.DstAddr.FQDN = string(str)
|
||||
da.DstAddr.Port = BuildPort(b[5+addrLen], b[5+addrLen+1])
|
||||
default:
|
||||
return ErrUnrecognizedAddrType
|
||||
err = ErrUnrecognizedAddrType
|
||||
return
|
||||
}
|
||||
sf.Data = b[headLen:]
|
||||
return nil
|
||||
da.Data = b[headLen:]
|
||||
return
|
||||
}
|
||||
|
||||
// Request returns s slice of packet reply
|
||||
func (sf *Packet) Header() []byte {
|
||||
// Request returns s slice of datagram header except data
|
||||
func (sf *Datagram) Header() []byte {
|
||||
bs := make([]byte, 0, 32)
|
||||
bs = append(bs, []byte{byte(sf.RSV << 8), byte(sf.RSV), sf.Frag}...)
|
||||
switch sf.ATYP {
|
||||
switch sf.DstAddr.AddrType {
|
||||
case ATYPIPv4:
|
||||
bs = append(bs, ATYPIPv4)
|
||||
bs = append(bs, sf.DstAddr.IP...)
|
||||
bs = append(bs, sf.DstAddr.IP.To4()...)
|
||||
case ATYPIPv6:
|
||||
bs = append(bs, ATYPIPv6)
|
||||
bs = append(bs, sf.DstAddr.IP...)
|
||||
bs = append(bs, sf.DstAddr.IP.To16()...)
|
||||
case ATYPDomain:
|
||||
bs = append(bs, ATYPDomain)
|
||||
bs = append(bs, ATYPDomain, byte(len(sf.DstAddr.FQDN)))
|
||||
bs = append(bs, []byte(sf.DstAddr.FQDN)...)
|
||||
}
|
||||
hi, lo := BreakPort(sf.DstAddr.Port)
|
||||
@ -124,6 +101,6 @@ func (sf *Packet) Header() []byte {
|
||||
return bs
|
||||
}
|
||||
|
||||
func (sf *Packet) Bytes() []byte {
|
||||
func (sf *Datagram) Bytes() []byte {
|
||||
return append(sf.Header(), sf.Data...)
|
||||
}
|
||||
|
145
statute/datagram_test.go
Normal file
145
statute/datagram_test.go
Normal file
@ -0,0 +1,145 @@
|
||||
package statute
|
||||
|
||||
import (
|
||||
"net"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDatagram(t *testing.T) {
|
||||
_, err := NewDatagram("localhost", nil)
|
||||
require.Error(t, err)
|
||||
|
||||
_, err = NewDatagram("localhostlocalhostlocalhostlocalhostlocalhostlocalhostlocalhostlocalhostlocalhostlocalhost"+
|
||||
"localhostlocalhostlocalhostlocalhostlocalhostlocalhostlocalhostlocalhostlocalhostlocalhost"+
|
||||
"localhostlocalhostlocalhostlocalhostlocalhostlocalhostlocalhostlocalhostlocalhostlocalhost:8080", nil)
|
||||
require.Error(t, err)
|
||||
|
||||
datagram, err := NewDatagram("localhost:8080", []byte{1, 2, 3})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, Datagram{
|
||||
0, 0, AddrSpec{
|
||||
FQDN: "localhost",
|
||||
Port: 8080,
|
||||
AddrType: ATYPDomain,
|
||||
},
|
||||
[]byte{1, 2, 3},
|
||||
}, datagram)
|
||||
require.Equal(t, []byte{0, 0, 0, ATYPDomain, 9, 'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't', 0x1f, 0x90}, datagram.Header())
|
||||
require.Equal(t, []byte{0, 0, 0, ATYPDomain, 9, 'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't', 0x1f, 0x90, 1, 2, 3}, datagram.Bytes())
|
||||
|
||||
datagram, err = NewDatagram("127.0.0.1:8080", []byte{1, 2, 3})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, Datagram{
|
||||
0, 0, AddrSpec{
|
||||
IP: net.IPv4(127, 0, 0, 1),
|
||||
Port: 8080,
|
||||
AddrType: ATYPIPv4,
|
||||
},
|
||||
[]byte{1, 2, 3},
|
||||
}, datagram)
|
||||
require.Equal(t, []byte{0, 0, 0, ATYPIPv4, 127, 0, 0, 1, 0x1f, 0x90}, datagram.Header())
|
||||
require.Equal(t, []byte{0, 0, 0, ATYPIPv4, 127, 0, 0, 1, 0x1f, 0x90, 1, 2, 3}, datagram.Bytes())
|
||||
datagram, err = NewDatagram("[::1]:8080", []byte{1, 2, 3})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, Datagram{
|
||||
0, 0, AddrSpec{
|
||||
IP: net.IPv6loopback,
|
||||
Port: 8080,
|
||||
AddrType: ATYPIPv6,
|
||||
},
|
||||
[]byte{1, 2, 3},
|
||||
}, datagram)
|
||||
require.Equal(t, []byte{0, 0, 0, ATYPIPv6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0x1f, 0x90}, datagram.Header())
|
||||
require.Equal(t, []byte{0, 0, 0, ATYPIPv6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0x1f, 0x90, 1, 2, 3}, datagram.Bytes())
|
||||
}
|
||||
|
||||
func TestParseDatagram(t *testing.T) {
|
||||
type args struct {
|
||||
b []byte
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantDa Datagram
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"IPv4",
|
||||
args{[]byte{0, 0, 0, ATYPIPv4, 127, 0, 0, 1, 0x1f, 0x90, 1, 2, 3}},
|
||||
Datagram{
|
||||
0, 0, AddrSpec{
|
||||
IP: net.IPv4(127, 0, 0, 1),
|
||||
Port: 8080,
|
||||
AddrType: ATYPIPv4,
|
||||
},
|
||||
[]byte{1, 2, 3},
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"IPv6",
|
||||
args{[]byte{0, 0, 0, ATYPIPv6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0x1f, 0x90, 1, 2, 3}},
|
||||
Datagram{
|
||||
0, 0, AddrSpec{
|
||||
IP: net.IPv6loopback,
|
||||
Port: 8080,
|
||||
AddrType: ATYPIPv6,
|
||||
},
|
||||
[]byte{1, 2, 3},
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"FQDN",
|
||||
args{[]byte{0, 0, 0, ATYPDomain, 9, 'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't', 0x1f, 0x90, 1, 2, 3}},
|
||||
Datagram{
|
||||
0, 0, AddrSpec{
|
||||
FQDN: "localhost",
|
||||
Port: 8080,
|
||||
AddrType: ATYPDomain,
|
||||
},
|
||||
[]byte{1, 2, 3},
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"invalid address type",
|
||||
args{[]byte{0, 0, 0, 0x02, 127, 0, 0, 1, 0x1f, 0x90}},
|
||||
Datagram{},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"less min length",
|
||||
args{[]byte{0, 0, 0, ATYPIPv4, 127, 0, 0, 1, 0x1f}},
|
||||
Datagram{},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"less domain length",
|
||||
args{[]byte{0, 0, 0, ATYPDomain, 10, 127, 0, 0, 1, 0x1f, 0x09}},
|
||||
Datagram{},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"less ipv6 length",
|
||||
args{[]byte{0, 0, 0, ATYPIPv6, 127, 0, 0, 1, 0x1f, 0x09}},
|
||||
Datagram{},
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotDa, err := ParseDatagram(tt.args.b)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ParseDatagram() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err == nil && !reflect.DeepEqual(gotDa, tt.wantDa) {
|
||||
t.Errorf("ParseDatagram() gotDa = %v, want %v", gotDa, tt.wantDa)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user