add datagram test

rename request and response api name
This commit is contained in:
mo 2020-08-06 08:19:31 +08:00
parent 916af03167
commit 35e543fcfb
7 changed files with 301 additions and 204 deletions

@ -88,8 +88,7 @@ func (c *Client) Read(b []byte) (int, error) {
if err != nil {
return 0, err
}
pkt := statute.Packet{}
err = pkt.Parse(b1[:n])
pkt, err := statute.ParseDatagram(b1[:n])
if err != nil {
return 0, err
}
@ -101,7 +100,7 @@ func (c *Client) Write(b []byte) (int, error) {
if c.UDPConn == nil {
return c.TCPConn.Write(b)
}
pkt, err := statute.NewPacket(c.RemoteAddress.String(), b)
pkt, err := statute.NewDatagram(c.RemoteAddress.String(), b)
if err != nil {
return 0, err
}

@ -33,15 +33,12 @@ type Request struct {
RawDestAddr *statute.AddrSpec
}
// NewRequest creates a new Request from the tcp connection
func NewRequest(bufConn io.Reader) (*Request, error) {
// ParseRequest creates a new Request from the tcp connection
func ParseRequest(bufConn io.Reader) (*Request, error) {
hd, err := statute.ParseRequest(bufConn)
if err != nil {
return nil, err
}
if hd.Command != statute.CommandConnect && hd.Command != statute.CommandBind && hd.Command != statute.CommandAssociate {
return nil, fmt.Errorf("unrecognized command[%d]", hd.Command)
}
return &Request{
Request: hd,
RawDestAddr: &hd.DstAddress,
@ -51,20 +48,19 @@ func NewRequest(bufConn io.Reader) (*Request, error) {
// handleRequest is used for request processing after authentication
func (s *Server) handleRequest(write io.Writer, req *Request) error {
var err error
ctx := context.Background()
// Resolve the address if we have a FQDN
dest := req.RawDestAddr
if dest.FQDN != "" {
_ctx, addr, err := s.resolver.Resolve(ctx, dest.FQDN)
ctx, dest.IP, err = s.resolver.Resolve(ctx, dest.FQDN)
if err != nil {
if err := SendReply(write, req.Request, statute.RepHostUnreachable); err != nil {
if err := SendReply(write, statute.RepHostUnreachable, nil); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
return fmt.Errorf("failed to resolve destination[%v], %v", dest.FQDN, err)
}
ctx = _ctx
dest.IP = addr
}
// Apply any address rewrites
@ -74,14 +70,14 @@ func (s *Server) handleRequest(write io.Writer, req *Request) error {
}
// Check if this is allowed
_ctx, ok := s.rules.Allow(ctx, req)
var ok bool
ctx, ok = s.rules.Allow(ctx, req)
if !ok {
if err := SendReply(write, req.Request, statute.RepRuleFailure); err != nil {
if err := SendReply(write, statute.RepRuleFailure, nil); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
return fmt.Errorf("bind to %v blocked by rules", req.RawDestAddr)
}
ctx = _ctx
// Switch on the command
switch req.Command {
@ -101,7 +97,7 @@ func (s *Server) handleRequest(write io.Writer, req *Request) error {
}
return s.handleAssociate(ctx, write, req)
default:
if err := SendReply(write, req.Request, statute.RepCommandNotSupported); err != nil {
if err := SendReply(write, statute.RepCommandNotSupported, nil); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
return fmt.Errorf("unsupported command[%v]", req.Command)
@ -126,7 +122,7 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, request *R
} else if strings.Contains(msg, "network is unreachable") {
resp = statute.RepNetworkUnreachable
}
if err := SendReply(writer, request.Request, resp); err != nil {
if err := SendReply(writer, resp, nil); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
return fmt.Errorf("connect to %v failed, %v", request.RawDestAddr, err)
@ -134,7 +130,7 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, request *R
defer target.Close()
// Send success
if err := SendReply(writer, request.Request, statute.RepSuccess, target.LocalAddr()); err != nil {
if err := SendReply(writer, statute.RepSuccess, target.LocalAddr()); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
@ -156,7 +152,7 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, request *R
// handleBind is used to handle a connect command
func (s *Server) handleBind(_ context.Context, writer io.Writer, request *Request) error {
// TODO: Support bind
if err := SendReply(writer, request.Request, statute.RepCommandNotSupported); err != nil {
if err := SendReply(writer, statute.RepCommandNotSupported, nil); err != nil {
return fmt.Errorf("failed to send reply: %v", err)
}
return nil
@ -180,7 +176,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, request
} else if strings.Contains(msg, "network is unreachable") {
resp = statute.RepNetworkUnreachable
}
if err := SendReply(writer, request.Request, resp); err != nil {
if err := SendReply(writer, resp, nil); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
return fmt.Errorf("connect to %v failed, %v", request.RawDestAddr, err)
@ -189,7 +185,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, request
targetUDP, ok := target.(*net.UDPConn)
if !ok {
if err := SendReply(writer, request.Request, statute.RepServerFailure); err != nil {
if err := SendReply(writer, statute.RepServerFailure, nil); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
return fmt.Errorf("dial udp invalid")
@ -197,7 +193,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, request
bindLn, err := net.ListenUDP("udp", nil)
if err != nil {
if err := SendReply(writer, request.Request, statute.RepServerFailure); err != nil {
if err := SendReply(writer, statute.RepServerFailure, nil); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
return fmt.Errorf("listen udp failed, %v", err)
@ -206,19 +202,11 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, request
s.logger.Errorf("target addr %v, listen addr: %s", targetUDP.RemoteAddr(), bindLn.LocalAddr())
// send BND.ADDR and BND.PORT, client must
if err = SendReply(writer, request.Request, statute.RepSuccess, bindLn.LocalAddr()); err != nil {
if err = SendReply(writer, statute.RepSuccess, bindLn.LocalAddr()); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
s.submit(func() {
/*
The SOCKS UDP request/response is formed as follows:
+-----+------+-------+----------+----------+----------+
| RSV | FRAG | ATYP | DST.ADDR | DST.PORT | DATA |
+-----+------+-------+----------+----------+----------+
| 2 | 1 | X'00' | Variable | 2 | Variable |
+-----+------+-------+----------+----------+----------+
*/
// read from client and write to remote server
conns := sync.Map{}
bufPool := s.bufferPool.Get()
@ -237,8 +225,8 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, request
continue
}
pk := statute.Packet{}
if err := pk.Parse(bufPool[:n]); err != nil {
pk, err := statute.ParseDatagram(bufPool[:n])
if err != nil {
continue
}
@ -260,7 +248,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, request
return
}
pkb, err := statute.NewPacket(remote.String(), buf[:n])
pkb, err := statute.NewDatagram(remote.String(), buf[:n])
if err != nil {
continue
}
@ -301,45 +289,35 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, request
}
// SendReply is used to send a reply message
func SendReply(w io.Writer, head statute.Request, resp uint8, bindAddr ...net.Addr) error {
head.Command = resp
func SendReply(w io.Writer, resp uint8, bindAddr net.Addr) error {
rsp := statute.Reply{
Version: statute.VersionSocks5,
Response: resp,
BndAddress: statute.AddrSpec{
AddrType: statute.ATYPIPv4,
IP: net.IPv4zero,
Port: 0,
},
}
if len(bindAddr) == 0 {
head.DstAddress.AddrType = statute.ATYPIPv4
head.DstAddress.IP = []byte{0, 0, 0, 0}
head.DstAddress.Port = 0
} else {
addrSpec := statute.AddrSpec{}
if tcpAddr, ok := bindAddr[0].(*net.TCPAddr); ok && tcpAddr != nil {
addrSpec.IP = tcpAddr.IP
addrSpec.Port = tcpAddr.Port
} else if udpAddr, ok := bindAddr[0].(*net.UDPAddr); ok && udpAddr != nil {
addrSpec.IP = udpAddr.IP
addrSpec.Port = udpAddr.Port
if rsp.Response == statute.RepSuccess {
if tcpAddr, ok := bindAddr.(*net.TCPAddr); ok && tcpAddr != nil {
rsp.BndAddress.IP = tcpAddr.IP
rsp.BndAddress.Port = tcpAddr.Port
} else if udpAddr, ok := bindAddr.(*net.UDPAddr); ok && udpAddr != nil {
rsp.BndAddress.IP = udpAddr.IP
rsp.BndAddress.Port = udpAddr.Port
} else {
addrSpec.IP = []byte{0, 0, 0, 0}
addrSpec.Port = 0
rsp.Response = statute.RepAddrTypeNotSupported
}
switch {
case addrSpec.FQDN != "":
head.DstAddress.AddrType = statute.ATYPDomain
head.DstAddress.FQDN = addrSpec.FQDN
head.DstAddress.Port = addrSpec.Port
case addrSpec.IP.To4() != nil:
head.DstAddress.AddrType = statute.ATYPIPv4
head.DstAddress.IP = addrSpec.IP.To4()
head.DstAddress.Port = addrSpec.Port
case addrSpec.IP.To16() != nil:
head.DstAddress.AddrType = statute.ATYPIPv6
head.DstAddress.IP = addrSpec.IP.To16()
head.DstAddress.Port = addrSpec.Port
default:
return fmt.Errorf("failed to format address[%v]", bindAddr)
if rsp.BndAddress.IP.To4() != nil {
rsp.BndAddress.AddrType = statute.ATYPIPv4
} else if rsp.BndAddress.IP.To16() != nil {
rsp.BndAddress.AddrType = statute.ATYPIPv6
}
}
// Send the message
_, err := w.Write(head.Bytes())
_, err := w.Write(rsp.Bytes())
return err
}
@ -354,7 +332,7 @@ func (s *Server) Proxy(dst io.Writer, src io.Reader) error {
defer s.bufferPool.Put(buf)
_, err := io.CopyBuffer(dst, src, buf[:cap(buf)])
if tcpConn, ok := dst.(closeWriter); ok {
tcpConn.CloseWrite()
tcpConn.CloseWrite() // nolint: errcheck
}
return err
}

@ -2,7 +2,6 @@ package socks5
import (
"bytes"
"encoding/binary"
"io"
"log"
"net"
@ -10,6 +9,8 @@ import (
"testing"
"github.com/stretchr/testify/require"
"github.com/thinkgos/go-socks5/statute"
)
type MockConn struct {
@ -32,7 +33,7 @@ func TestRequest_Connect(t *testing.T) {
go func() {
conn, err := l.Accept()
require.NoError(t, err)
defer conn.Close()
defer conn.Close() // nolint: errcheck
buf := make([]byte, 4)
_, err = io.ReadAtLeast(conn, buf, 4)
@ -43,7 +44,7 @@ func TestRequest_Connect(t *testing.T) {
}()
lAddr := l.Addr().(*net.TCPAddr)
// Make server
// Make proxy server
proxySrv := &Server{
rules: NewPermitAll(),
resolver: DNSResolver{},
@ -52,33 +53,27 @@ func TestRequest_Connect(t *testing.T) {
}
// Create the connect request
buf := bytes.NewBuffer(nil)
buf.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1}) // nolint: errcheck
port := []byte{0, 0}
binary.BigEndian.PutUint16(port, uint16(lAddr.Port))
buf.Write(port) // nolint: errcheck
hi, lo := statute.BreakPort(lAddr.Port)
buf := bytes.NewBuffer([]byte{
statute.VersionSocks5, statute.CommandConnect, 0,
statute.ATYPIPv4, 127, 0, 0, 1, hi, lo,
})
// Send a ping
buf.Write([]byte("ping")) // nolint: errcheck
// Handle the request
resp := &MockConn{}
req, err := NewRequest(buf)
rsp := new(MockConn)
req, err := ParseRequest(buf)
require.NoError(t, err)
err = proxySrv.handleRequest(resp, req)
err = proxySrv.handleRequest(rsp, req)
require.NoError(t, err)
// Verify response
out := resp.buf.Bytes()
out := rsp.buf.Bytes()
expected := []byte{
5,
0,
0,
1,
127, 0, 0, 1,
0, 0,
statute.VersionSocks5, statute.RepSuccess, 0,
statute.ATYPIPv4, 127, 0, 0, 1, 0, 0,
'p', 'o', 'n', 'g',
}
@ -102,6 +97,7 @@ func TestRequest_Connect_RuleFail(t *testing.T) {
_, err = io.ReadAtLeast(conn, buf, 4)
require.NoError(t, err)
require.Equal(t, []byte("ping"), buf)
conn.Write([]byte("pong")) // nolint: errcheck
}()
lAddr := l.Addr().(*net.TCPAddr)
@ -115,33 +111,28 @@ func TestRequest_Connect_RuleFail(t *testing.T) {
}
// Create the connect request
buf := bytes.NewBuffer(nil)
buf.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1})
port := []byte{0, 0}
binary.BigEndian.PutUint16(port, uint16(lAddr.Port))
buf.Write(port)
hi, lo := statute.BreakPort(lAddr.Port)
buf := bytes.NewBuffer([]byte{
statute.VersionSocks5, statute.CommandConnect, 0,
statute.ATYPIPv4, 127, 0, 0, 1, hi, lo,
})
// Send a ping
buf.Write([]byte("ping"))
// Handle the request
resp := &MockConn{}
req, err := NewRequest(buf)
rsp := new(MockConn)
req, err := ParseRequest(buf)
require.NoError(t, err)
err = s.handleRequest(resp, req)
err = s.handleRequest(rsp, req)
require.Contains(t, err.Error(), "blocked by rules")
// Verify response
out := resp.buf.Bytes()
out := rsp.buf.Bytes()
expected := []byte{
5,
2,
0,
1,
0, 0, 0, 0,
0, 0,
statute.VersionSocks5, statute.RepRuleFailure, 0,
statute.ATYPIPv4, 0, 0, 0, 0, 0, 0,
}
require.Equal(t, expected, out)
}

@ -3,6 +3,7 @@ package socks5
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"io/ioutil"
@ -116,6 +117,7 @@ func (s *Server) ServeConn(conn net.Conn) error {
var authContext *AuthContext
defer conn.Close()
bufConn := bufio.NewReader(conn)
mr, err := statute.ParseMethodRequest(bufConn)
@ -133,24 +135,30 @@ func (s *Server) ServeConn(conn net.Conn) error {
}
// The client request detail
request, err := NewRequest(bufConn)
request, err := ParseRequest(bufConn)
if err != nil {
if err == statute.ErrUnrecognizedAddrType {
if err := SendReply(conn, statute.Request{Version: mr.Ver}, statute.RepAddrTypeNotSupported); err != nil {
if errors.Is(err, statute.ErrUnrecognizedAddrType) {
if err := SendReply(conn, statute.RepAddrTypeNotSupported, nil); err != nil {
return fmt.Errorf("failed to send reply %w", err)
}
}
return fmt.Errorf("failed to read destination address, %w", err)
}
if request.Request.Command != statute.CommandConnect &&
request.Request.Command != statute.CommandBind &&
request.Request.Command != statute.CommandAssociate {
if err := SendReply(conn, statute.RepCommandNotSupported, nil); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
return fmt.Errorf("unrecognized command[%d]", request.Request.Command)
}
request.AuthContext = authContext
request.LocalAddr = conn.LocalAddr()
request.RemoteAddr = conn.RemoteAddr()
// Process the client request
if err := s.handleRequest(conn, request); err != nil {
return fmt.Errorf("failed to handle request, %v", err)
}
return nil
return s.handleRequest(conn, request)
}
// authenticate is used to handle connection authentication

@ -31,11 +31,12 @@ func TestSOCKS5_Connect(t *testing.T) {
_, err = io.ReadAtLeast(conn, buf, 4)
require.NoError(t, err)
assert.Equal(t, []byte("ping"), buf)
_, _ = conn.Write([]byte("pong"))
conn.Write([]byte("pong")) // nolint: errcheck
}()
lAddr := l.Addr().(*net.TCPAddr)
// Create a socks server
// Create a socks server with UserPass auth.
cator := UserPassAuthenticator{StaticCredentials{"foo": "bar"}}
srv := NewServer(
WithAuthMethods([]Authenticator{cator}),
@ -54,18 +55,20 @@ func TestSOCKS5_Connect(t *testing.T) {
require.NoError(t, err)
// Connect, auth and connec to local
req := new(bytes.Buffer)
req.Write([]byte{statute.VersionSocks5, 2, statute.MethodNoAuth, statute.MethodUserPassAuth})
req.Write([]byte{statute.UserPassAuthVersion, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'})
req := bytes.NewBuffer(
[]byte{
statute.VersionSocks5, 2, statute.MethodNoAuth, statute.MethodUserPassAuth, // methods
statute.UserPassAuthVersion, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r', // userpass auth
})
reqHead := statute.Request{
Version: statute.VersionSocks5,
Command: statute.CommandConnect,
Reserved: 0,
DstAddress: statute.AddrSpec{
"",
net.ParseIP("127.0.0.1"),
lAddr.Port,
statute.ATYPIPv4,
FQDN: "",
IP: net.ParseIP("127.0.0.1"),
Port: lAddr.Port,
AddrType: statute.ATYPIPv4,
},
}
req.Write(reqHead.Bytes())
@ -73,11 +76,11 @@ func TestSOCKS5_Connect(t *testing.T) {
req.Write([]byte("ping"))
// Send all the bytes
conn.Write(req.Bytes())
conn.Write(req.Bytes()) // nolint: errcheck
// Verify response
expected := []byte{
statute.VersionSocks5, statute.MethodUserPassAuth, // use user password auth
statute.VersionSocks5, statute.MethodUserPassAuth, // response use UserPass auth
statute.UserPassAuthVersion, statute.AuthSuccess, // response auth success
}
rspHead := statute.Request{
@ -85,22 +88,19 @@ func TestSOCKS5_Connect(t *testing.T) {
Command: statute.RepSuccess,
Reserved: 0,
DstAddress: statute.AddrSpec{
"",
net.ParseIP("127.0.0.1"),
0, // Ignore the port
statute.ATYPIPv4,
FQDN: "",
IP: net.ParseIP("127.0.0.1"),
Port: 0,
AddrType: statute.ATYPIPv4,
},
}
expected = append(expected, rspHead.Bytes()...)
expected = append(expected, []byte("pong")...)
out := make([]byte, len(expected))
_ = conn.SetDeadline(time.Now().Add(time.Second))
conn.SetDeadline(time.Now().Add(time.Second)) // nolint: errcheck
_, err = io.ReadFull(conn, out)
require.NoError(t, err)
t.Logf("proxy bind port: %d", statute.BuildPort(out[12], out[13]))
// Ignore the port
out[12] = 0
out[13] = 0
@ -157,10 +157,10 @@ func TestSOCKS5_Associate(t *testing.T) {
Command: statute.CommandAssociate,
Reserved: 0,
DstAddress: statute.AddrSpec{
"",
locIP,
lAddr.Port,
statute.ATYPIPv4,
FQDN: "",
IP: locIP,
Port: lAddr.Port,
AddrType: statute.ATYPIPv4,
},
}
req.Write(reqHead.Bytes())
@ -216,11 +216,11 @@ func Test_SocksWithProxy(t *testing.T) {
require.NoError(t, err)
require.Equal(t, []byte("ping"), buf)
conn.Write([]byte("pong"))
conn.Write([]byte("pong")) // nolint: errcheck
}()
lAddr := l.Addr().(*net.TCPAddr)
// Create a socks server
// Create a socks server with UserPass auth.
cator := UserPassAuthenticator{StaticCredentials{"foo": "bar"}}
serv := NewServer(
WithAuthMethods([]Authenticator{cator}),
@ -245,14 +245,13 @@ func Test_SocksWithProxy(t *testing.T) {
conn.Write([]byte("ping")) // nolint: errcheck
out := make([]byte, 4)
_ = conn.SetDeadline(time.Now().Add(time.Second)) // nolint: errcheck
conn.SetDeadline(time.Now().Add(time.Second)) // nolint: errcheck
_, err = io.ReadFull(conn, out)
require.NoError(t, err)
require.Equal(t, []byte("pong"), out)
}
/***************************** auth *******************************/
/***************************** auth *******************************/
func TestNoAuth_Server(t *testing.T) {
req := bytes.NewBuffer(nil)

@ -4,119 +4,96 @@ import (
"errors"
"math"
"net"
"strconv"
)
/*
The SOCKS UDP request/response is formed as follows:
+-----+------+-------+----------+----------+----------+
| RSV | FRAG | ATYP | DST.ADDR | DST.PORT | DATA |
+-----+------+-------+----------+----------+----------+
| 2 | 1 | X'00' | Variable | 2 | Variable |
+-----+------+-------+----------+----------+----------+
*/
// Packet udp packet
type Packet struct {
// The SOCKS UDP request/response is formed as follows:
// +-----+------+-------+----------+----------+----------+
// | RSV | FRAG | ATYP | DST.ADDR | DST.PORT | DATA |
// +-----+------+-------+----------+----------+----------+
// | 2 | 1 | X'00' | Variable | 2 | Variable |
// +-----+------+-------+----------+----------+----------+
// Datagram udp packet
type Datagram struct {
RSV uint16
Frag uint8
ATYP uint8
DstAddr AddrSpec
Data []byte
}
// NewEmptyPacket new empty packet
func NewEmptyPacket() Packet {
return Packet{}
}
// NewPacket new packet with dest addr and data
func NewPacket(destAddr string, data []byte) (p Packet, err error) {
var host, port string
host, port, err = net.SplitHostPort(destAddr)
// NewDatagram new packet with dest addr and data
func NewDatagram(destAddr string, data []byte) (p Datagram, err error) {
p.DstAddr, err = ParseAddrSpec(destAddr)
if err != nil {
return
}
p.DstAddr.Port, err = strconv.Atoi(port)
if err != nil {
if p.DstAddr.AddrType == ATYPDomain && len(p.DstAddr.FQDN) > math.MaxUint8 {
err = errors.New("destination host name too long")
return
}
p.RSV = 0
p.Frag = 0
if ip := net.ParseIP(host); ip != nil {
if ip4 := ip.To4(); ip4 != nil {
p.ATYP = ATYPIPv4
p.DstAddr.IP = ip
} else {
p.ATYP = ATYPIPv6
p.DstAddr.IP = ip
}
} else {
if len(host) > math.MaxUint8 {
err = errors.New("destination host name too long")
return
}
p.ATYP = ATYPDomain
p.DstAddr.FQDN = host
}
p.Data = data
return
}
// ParseRequest parse to packet
func (sf *Packet) Parse(b []byte) error {
if len(b) <= 4+net.IPv4len+2 { // no data
return errors.New("too short")
func ParseDatagram(b []byte) (da Datagram, err error) {
if len(b) < 4+net.IPv4len+2 { // no data
err = errors.New("datagram to short")
return
}
// ignore RSV
sf.RSV = 0
da.RSV = 0
// FRAG
sf.Frag = b[2]
sf.ATYP = b[3]
da.Frag = b[2]
da.DstAddr.AddrType = b[3]
headLen := 4
switch sf.ATYP {
switch da.DstAddr.AddrType {
case ATYPIPv4:
headLen += net.IPv4len + 2
sf.DstAddr.IP = net.IPv4(b[4], b[5], b[6], b[7])
sf.DstAddr.Port = BuildPort(b[4+net.IPv4len], b[4+net.IPv4len+1])
da.DstAddr.IP = net.IPv4(b[4], b[5], b[6], b[7])
da.DstAddr.Port = BuildPort(b[4+net.IPv4len], b[4+net.IPv4len+1])
case ATYPIPv6:
headLen += net.IPv6len + 2
if len(b) <= headLen {
return errors.New("too short")
err = errors.New("datagram to short")
return
}
sf.DstAddr.IP = net.IP{b[4], b[5], b[6], b[7], b[8], b[9], b[10], b[11], b[12], b[13], b[14], b[15], b[16], b[17], b[18], b[19]}
sf.DstAddr.Port = BuildPort(b[4+net.IPv6len], b[4+net.IPv6len+1])
da.DstAddr.IP = net.IP{b[4], b[5], b[6], b[7], b[8], b[9], b[10], b[11], b[12], b[13], b[14], b[15], b[16], b[17], b[18], b[19]}
da.DstAddr.Port = BuildPort(b[4+net.IPv6len], b[4+net.IPv6len+1])
case ATYPDomain:
addrLen := int(b[4])
headLen += 1 + addrLen + 2
if len(b) <= headLen {
return errors.New("too short")
err = errors.New("datagram to short")
return
}
str := make([]byte, addrLen)
copy(str, b[5:5+addrLen])
sf.DstAddr.FQDN = string(str)
sf.DstAddr.Port = BuildPort(b[5+addrLen], b[5+addrLen+1])
da.DstAddr.FQDN = string(str)
da.DstAddr.Port = BuildPort(b[5+addrLen], b[5+addrLen+1])
default:
return ErrUnrecognizedAddrType
err = ErrUnrecognizedAddrType
return
}
sf.Data = b[headLen:]
return nil
da.Data = b[headLen:]
return
}
// Request returns s slice of packet reply
func (sf *Packet) Header() []byte {
// Request returns s slice of datagram header except data
func (sf *Datagram) Header() []byte {
bs := make([]byte, 0, 32)
bs = append(bs, []byte{byte(sf.RSV << 8), byte(sf.RSV), sf.Frag}...)
switch sf.ATYP {
switch sf.DstAddr.AddrType {
case ATYPIPv4:
bs = append(bs, ATYPIPv4)
bs = append(bs, sf.DstAddr.IP...)
bs = append(bs, sf.DstAddr.IP.To4()...)
case ATYPIPv6:
bs = append(bs, ATYPIPv6)
bs = append(bs, sf.DstAddr.IP...)
bs = append(bs, sf.DstAddr.IP.To16()...)
case ATYPDomain:
bs = append(bs, ATYPDomain)
bs = append(bs, ATYPDomain, byte(len(sf.DstAddr.FQDN)))
bs = append(bs, []byte(sf.DstAddr.FQDN)...)
}
hi, lo := BreakPort(sf.DstAddr.Port)
@ -124,6 +101,6 @@ func (sf *Packet) Header() []byte {
return bs
}
func (sf *Packet) Bytes() []byte {
func (sf *Datagram) Bytes() []byte {
return append(sf.Header(), sf.Data...)
}

145
statute/datagram_test.go Normal file

@ -0,0 +1,145 @@
package statute
import (
"net"
"reflect"
"testing"
"github.com/stretchr/testify/require"
)
func TestDatagram(t *testing.T) {
_, err := NewDatagram("localhost", nil)
require.Error(t, err)
_, err = NewDatagram("localhostlocalhostlocalhostlocalhostlocalhostlocalhostlocalhostlocalhostlocalhostlocalhost"+
"localhostlocalhostlocalhostlocalhostlocalhostlocalhostlocalhostlocalhostlocalhostlocalhost"+
"localhostlocalhostlocalhostlocalhostlocalhostlocalhostlocalhostlocalhostlocalhostlocalhost:8080", nil)
require.Error(t, err)
datagram, err := NewDatagram("localhost:8080", []byte{1, 2, 3})
require.NoError(t, err)
require.Equal(t, Datagram{
0, 0, AddrSpec{
FQDN: "localhost",
Port: 8080,
AddrType: ATYPDomain,
},
[]byte{1, 2, 3},
}, datagram)
require.Equal(t, []byte{0, 0, 0, ATYPDomain, 9, 'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't', 0x1f, 0x90}, datagram.Header())
require.Equal(t, []byte{0, 0, 0, ATYPDomain, 9, 'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't', 0x1f, 0x90, 1, 2, 3}, datagram.Bytes())
datagram, err = NewDatagram("127.0.0.1:8080", []byte{1, 2, 3})
require.NoError(t, err)
require.Equal(t, Datagram{
0, 0, AddrSpec{
IP: net.IPv4(127, 0, 0, 1),
Port: 8080,
AddrType: ATYPIPv4,
},
[]byte{1, 2, 3},
}, datagram)
require.Equal(t, []byte{0, 0, 0, ATYPIPv4, 127, 0, 0, 1, 0x1f, 0x90}, datagram.Header())
require.Equal(t, []byte{0, 0, 0, ATYPIPv4, 127, 0, 0, 1, 0x1f, 0x90, 1, 2, 3}, datagram.Bytes())
datagram, err = NewDatagram("[::1]:8080", []byte{1, 2, 3})
require.NoError(t, err)
require.Equal(t, Datagram{
0, 0, AddrSpec{
IP: net.IPv6loopback,
Port: 8080,
AddrType: ATYPIPv6,
},
[]byte{1, 2, 3},
}, datagram)
require.Equal(t, []byte{0, 0, 0, ATYPIPv6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0x1f, 0x90}, datagram.Header())
require.Equal(t, []byte{0, 0, 0, ATYPIPv6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0x1f, 0x90, 1, 2, 3}, datagram.Bytes())
}
func TestParseDatagram(t *testing.T) {
type args struct {
b []byte
}
tests := []struct {
name string
args args
wantDa Datagram
wantErr bool
}{
{
"IPv4",
args{[]byte{0, 0, 0, ATYPIPv4, 127, 0, 0, 1, 0x1f, 0x90, 1, 2, 3}},
Datagram{
0, 0, AddrSpec{
IP: net.IPv4(127, 0, 0, 1),
Port: 8080,
AddrType: ATYPIPv4,
},
[]byte{1, 2, 3},
},
false,
},
{
"IPv6",
args{[]byte{0, 0, 0, ATYPIPv6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0x1f, 0x90, 1, 2, 3}},
Datagram{
0, 0, AddrSpec{
IP: net.IPv6loopback,
Port: 8080,
AddrType: ATYPIPv6,
},
[]byte{1, 2, 3},
},
false,
},
{
"FQDN",
args{[]byte{0, 0, 0, ATYPDomain, 9, 'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't', 0x1f, 0x90, 1, 2, 3}},
Datagram{
0, 0, AddrSpec{
FQDN: "localhost",
Port: 8080,
AddrType: ATYPDomain,
},
[]byte{1, 2, 3},
},
false,
},
{
"invalid address type",
args{[]byte{0, 0, 0, 0x02, 127, 0, 0, 1, 0x1f, 0x90}},
Datagram{},
true,
},
{
"less min length",
args{[]byte{0, 0, 0, ATYPIPv4, 127, 0, 0, 1, 0x1f}},
Datagram{},
true,
},
{
"less domain length",
args{[]byte{0, 0, 0, ATYPDomain, 10, 127, 0, 0, 1, 0x1f, 0x09}},
Datagram{},
true,
},
{
"less ipv6 length",
args{[]byte{0, 0, 0, ATYPIPv6, 127, 0, 0, 1, 0x1f, 0x09}},
Datagram{},
true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotDa, err := ParseDatagram(tt.args.b)
if (err != nil) != tt.wantErr {
t.Errorf("ParseDatagram() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err == nil && !reflect.DeepEqual(gotDa, tt.wantDa) {
t.Errorf("ParseDatagram() gotDa = %v, want %v", gotDa, tt.wantDa)
}
})
}
}