From 35e543fcfbf1b62911b97195bca50321d41366a9 Mon Sep 17 00:00:00 2001 From: mo Date: Thu, 6 Aug 2020 08:19:31 +0800 Subject: [PATCH] add datagram test rename request and response api name --- client/client.go | 5 +- request.go | 108 ++++++++++++----------------- request_test.go | 63 ++++++++--------- server.go | 22 ++++-- server_test.go | 55 ++++++++------- statute/datagram.go | 107 ++++++++++++----------------- statute/datagram_test.go | 145 +++++++++++++++++++++++++++++++++++++++ 7 files changed, 301 insertions(+), 204 deletions(-) create mode 100644 statute/datagram_test.go diff --git a/client/client.go b/client/client.go index 5b8cae6..0abaaf8 100644 --- a/client/client.go +++ b/client/client.go @@ -88,8 +88,7 @@ func (c *Client) Read(b []byte) (int, error) { if err != nil { return 0, err } - pkt := statute.Packet{} - err = pkt.Parse(b1[:n]) + pkt, err := statute.ParseDatagram(b1[:n]) if err != nil { return 0, err } @@ -101,7 +100,7 @@ func (c *Client) Write(b []byte) (int, error) { if c.UDPConn == nil { return c.TCPConn.Write(b) } - pkt, err := statute.NewPacket(c.RemoteAddress.String(), b) + pkt, err := statute.NewDatagram(c.RemoteAddress.String(), b) if err != nil { return 0, err } diff --git a/request.go b/request.go index 6dd3697..b37cf82 100644 --- a/request.go +++ b/request.go @@ -33,15 +33,12 @@ type Request struct { RawDestAddr *statute.AddrSpec } -// NewRequest creates a new Request from the tcp connection -func NewRequest(bufConn io.Reader) (*Request, error) { +// ParseRequest creates a new Request from the tcp connection +func ParseRequest(bufConn io.Reader) (*Request, error) { hd, err := statute.ParseRequest(bufConn) if err != nil { return nil, err } - if hd.Command != statute.CommandConnect && hd.Command != statute.CommandBind && hd.Command != statute.CommandAssociate { - return nil, fmt.Errorf("unrecognized command[%d]", hd.Command) - } return &Request{ Request: hd, RawDestAddr: &hd.DstAddress, @@ -51,20 +48,19 @@ func NewRequest(bufConn io.Reader) (*Request, error) { // handleRequest is used for request processing after authentication func (s *Server) handleRequest(write io.Writer, req *Request) error { + var err error ctx := context.Background() // Resolve the address if we have a FQDN dest := req.RawDestAddr if dest.FQDN != "" { - _ctx, addr, err := s.resolver.Resolve(ctx, dest.FQDN) + ctx, dest.IP, err = s.resolver.Resolve(ctx, dest.FQDN) if err != nil { - if err := SendReply(write, req.Request, statute.RepHostUnreachable); err != nil { + if err := SendReply(write, statute.RepHostUnreachable, nil); err != nil { return fmt.Errorf("failed to send reply, %v", err) } return fmt.Errorf("failed to resolve destination[%v], %v", dest.FQDN, err) } - ctx = _ctx - dest.IP = addr } // Apply any address rewrites @@ -74,14 +70,14 @@ func (s *Server) handleRequest(write io.Writer, req *Request) error { } // Check if this is allowed - _ctx, ok := s.rules.Allow(ctx, req) + var ok bool + ctx, ok = s.rules.Allow(ctx, req) if !ok { - if err := SendReply(write, req.Request, statute.RepRuleFailure); err != nil { + if err := SendReply(write, statute.RepRuleFailure, nil); err != nil { return fmt.Errorf("failed to send reply, %v", err) } return fmt.Errorf("bind to %v blocked by rules", req.RawDestAddr) } - ctx = _ctx // Switch on the command switch req.Command { @@ -101,7 +97,7 @@ func (s *Server) handleRequest(write io.Writer, req *Request) error { } return s.handleAssociate(ctx, write, req) default: - if err := SendReply(write, req.Request, statute.RepCommandNotSupported); err != nil { + if err := SendReply(write, statute.RepCommandNotSupported, nil); err != nil { return fmt.Errorf("failed to send reply, %v", err) } return fmt.Errorf("unsupported command[%v]", req.Command) @@ -126,7 +122,7 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, request *R } else if strings.Contains(msg, "network is unreachable") { resp = statute.RepNetworkUnreachable } - if err := SendReply(writer, request.Request, resp); err != nil { + if err := SendReply(writer, resp, nil); err != nil { return fmt.Errorf("failed to send reply, %v", err) } return fmt.Errorf("connect to %v failed, %v", request.RawDestAddr, err) @@ -134,7 +130,7 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, request *R defer target.Close() // Send success - if err := SendReply(writer, request.Request, statute.RepSuccess, target.LocalAddr()); err != nil { + if err := SendReply(writer, statute.RepSuccess, target.LocalAddr()); err != nil { return fmt.Errorf("failed to send reply, %v", err) } @@ -156,7 +152,7 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, request *R // handleBind is used to handle a connect command func (s *Server) handleBind(_ context.Context, writer io.Writer, request *Request) error { // TODO: Support bind - if err := SendReply(writer, request.Request, statute.RepCommandNotSupported); err != nil { + if err := SendReply(writer, statute.RepCommandNotSupported, nil); err != nil { return fmt.Errorf("failed to send reply: %v", err) } return nil @@ -180,7 +176,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, request } else if strings.Contains(msg, "network is unreachable") { resp = statute.RepNetworkUnreachable } - if err := SendReply(writer, request.Request, resp); err != nil { + if err := SendReply(writer, resp, nil); err != nil { return fmt.Errorf("failed to send reply, %v", err) } return fmt.Errorf("connect to %v failed, %v", request.RawDestAddr, err) @@ -189,7 +185,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, request targetUDP, ok := target.(*net.UDPConn) if !ok { - if err := SendReply(writer, request.Request, statute.RepServerFailure); err != nil { + if err := SendReply(writer, statute.RepServerFailure, nil); err != nil { return fmt.Errorf("failed to send reply, %v", err) } return fmt.Errorf("dial udp invalid") @@ -197,7 +193,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, request bindLn, err := net.ListenUDP("udp", nil) if err != nil { - if err := SendReply(writer, request.Request, statute.RepServerFailure); err != nil { + if err := SendReply(writer, statute.RepServerFailure, nil); err != nil { return fmt.Errorf("failed to send reply, %v", err) } return fmt.Errorf("listen udp failed, %v", err) @@ -206,19 +202,11 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, request s.logger.Errorf("target addr %v, listen addr: %s", targetUDP.RemoteAddr(), bindLn.LocalAddr()) // send BND.ADDR and BND.PORT, client must - if err = SendReply(writer, request.Request, statute.RepSuccess, bindLn.LocalAddr()); err != nil { + if err = SendReply(writer, statute.RepSuccess, bindLn.LocalAddr()); err != nil { return fmt.Errorf("failed to send reply, %v", err) } s.submit(func() { - /* - The SOCKS UDP request/response is formed as follows: - +-----+------+-------+----------+----------+----------+ - | RSV | FRAG | ATYP | DST.ADDR | DST.PORT | DATA | - +-----+------+-------+----------+----------+----------+ - | 2 | 1 | X'00' | Variable | 2 | Variable | - +-----+------+-------+----------+----------+----------+ - */ // read from client and write to remote server conns := sync.Map{} bufPool := s.bufferPool.Get() @@ -237,8 +225,8 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, request continue } - pk := statute.Packet{} - if err := pk.Parse(bufPool[:n]); err != nil { + pk, err := statute.ParseDatagram(bufPool[:n]) + if err != nil { continue } @@ -260,7 +248,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, request return } - pkb, err := statute.NewPacket(remote.String(), buf[:n]) + pkb, err := statute.NewDatagram(remote.String(), buf[:n]) if err != nil { continue } @@ -301,45 +289,35 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, request } // SendReply is used to send a reply message -func SendReply(w io.Writer, head statute.Request, resp uint8, bindAddr ...net.Addr) error { - head.Command = resp +func SendReply(w io.Writer, resp uint8, bindAddr net.Addr) error { + rsp := statute.Reply{ + Version: statute.VersionSocks5, + Response: resp, + BndAddress: statute.AddrSpec{ + AddrType: statute.ATYPIPv4, + IP: net.IPv4zero, + Port: 0, + }, + } - if len(bindAddr) == 0 { - head.DstAddress.AddrType = statute.ATYPIPv4 - head.DstAddress.IP = []byte{0, 0, 0, 0} - head.DstAddress.Port = 0 - } else { - addrSpec := statute.AddrSpec{} - if tcpAddr, ok := bindAddr[0].(*net.TCPAddr); ok && tcpAddr != nil { - addrSpec.IP = tcpAddr.IP - addrSpec.Port = tcpAddr.Port - } else if udpAddr, ok := bindAddr[0].(*net.UDPAddr); ok && udpAddr != nil { - addrSpec.IP = udpAddr.IP - addrSpec.Port = udpAddr.Port + if rsp.Response == statute.RepSuccess { + if tcpAddr, ok := bindAddr.(*net.TCPAddr); ok && tcpAddr != nil { + rsp.BndAddress.IP = tcpAddr.IP + rsp.BndAddress.Port = tcpAddr.Port + } else if udpAddr, ok := bindAddr.(*net.UDPAddr); ok && udpAddr != nil { + rsp.BndAddress.IP = udpAddr.IP + rsp.BndAddress.Port = udpAddr.Port } else { - addrSpec.IP = []byte{0, 0, 0, 0} - addrSpec.Port = 0 + rsp.Response = statute.RepAddrTypeNotSupported } - switch { - case addrSpec.FQDN != "": - head.DstAddress.AddrType = statute.ATYPDomain - head.DstAddress.FQDN = addrSpec.FQDN - head.DstAddress.Port = addrSpec.Port - case addrSpec.IP.To4() != nil: - head.DstAddress.AddrType = statute.ATYPIPv4 - head.DstAddress.IP = addrSpec.IP.To4() - head.DstAddress.Port = addrSpec.Port - case addrSpec.IP.To16() != nil: - head.DstAddress.AddrType = statute.ATYPIPv6 - head.DstAddress.IP = addrSpec.IP.To16() - head.DstAddress.Port = addrSpec.Port - default: - return fmt.Errorf("failed to format address[%v]", bindAddr) + if rsp.BndAddress.IP.To4() != nil { + rsp.BndAddress.AddrType = statute.ATYPIPv4 + } else if rsp.BndAddress.IP.To16() != nil { + rsp.BndAddress.AddrType = statute.ATYPIPv6 } - } // Send the message - _, err := w.Write(head.Bytes()) + _, err := w.Write(rsp.Bytes()) return err } @@ -354,7 +332,7 @@ func (s *Server) Proxy(dst io.Writer, src io.Reader) error { defer s.bufferPool.Put(buf) _, err := io.CopyBuffer(dst, src, buf[:cap(buf)]) if tcpConn, ok := dst.(closeWriter); ok { - tcpConn.CloseWrite() + tcpConn.CloseWrite() // nolint: errcheck } return err } diff --git a/request_test.go b/request_test.go index 1879277..5def74e 100644 --- a/request_test.go +++ b/request_test.go @@ -2,7 +2,6 @@ package socks5 import ( "bytes" - "encoding/binary" "io" "log" "net" @@ -10,6 +9,8 @@ import ( "testing" "github.com/stretchr/testify/require" + + "github.com/thinkgos/go-socks5/statute" ) type MockConn struct { @@ -32,7 +33,7 @@ func TestRequest_Connect(t *testing.T) { go func() { conn, err := l.Accept() require.NoError(t, err) - defer conn.Close() + defer conn.Close() // nolint: errcheck buf := make([]byte, 4) _, err = io.ReadAtLeast(conn, buf, 4) @@ -43,7 +44,7 @@ func TestRequest_Connect(t *testing.T) { }() lAddr := l.Addr().(*net.TCPAddr) - // Make server + // Make proxy server proxySrv := &Server{ rules: NewPermitAll(), resolver: DNSResolver{}, @@ -52,33 +53,27 @@ func TestRequest_Connect(t *testing.T) { } // Create the connect request - buf := bytes.NewBuffer(nil) - buf.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1}) // nolint: errcheck - - port := []byte{0, 0} - binary.BigEndian.PutUint16(port, uint16(lAddr.Port)) - buf.Write(port) // nolint: errcheck - + hi, lo := statute.BreakPort(lAddr.Port) + buf := bytes.NewBuffer([]byte{ + statute.VersionSocks5, statute.CommandConnect, 0, + statute.ATYPIPv4, 127, 0, 0, 1, hi, lo, + }) // Send a ping buf.Write([]byte("ping")) // nolint: errcheck // Handle the request - resp := &MockConn{} - req, err := NewRequest(buf) + rsp := new(MockConn) + req, err := ParseRequest(buf) require.NoError(t, err) - err = proxySrv.handleRequest(resp, req) + err = proxySrv.handleRequest(rsp, req) require.NoError(t, err) // Verify response - out := resp.buf.Bytes() + out := rsp.buf.Bytes() expected := []byte{ - 5, - 0, - 0, - 1, - 127, 0, 0, 1, - 0, 0, + statute.VersionSocks5, statute.RepSuccess, 0, + statute.ATYPIPv4, 127, 0, 0, 1, 0, 0, 'p', 'o', 'n', 'g', } @@ -102,6 +97,7 @@ func TestRequest_Connect_RuleFail(t *testing.T) { _, err = io.ReadAtLeast(conn, buf, 4) require.NoError(t, err) require.Equal(t, []byte("ping"), buf) + conn.Write([]byte("pong")) // nolint: errcheck }() lAddr := l.Addr().(*net.TCPAddr) @@ -115,33 +111,28 @@ func TestRequest_Connect_RuleFail(t *testing.T) { } // Create the connect request - buf := bytes.NewBuffer(nil) - buf.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1}) - - port := []byte{0, 0} - binary.BigEndian.PutUint16(port, uint16(lAddr.Port)) - buf.Write(port) + hi, lo := statute.BreakPort(lAddr.Port) + buf := bytes.NewBuffer([]byte{ + statute.VersionSocks5, statute.CommandConnect, 0, + statute.ATYPIPv4, 127, 0, 0, 1, hi, lo, + }) // Send a ping buf.Write([]byte("ping")) // Handle the request - resp := &MockConn{} - req, err := NewRequest(buf) + rsp := new(MockConn) + req, err := ParseRequest(buf) require.NoError(t, err) - err = s.handleRequest(resp, req) + err = s.handleRequest(rsp, req) require.Contains(t, err.Error(), "blocked by rules") // Verify response - out := resp.buf.Bytes() + out := rsp.buf.Bytes() expected := []byte{ - 5, - 2, - 0, - 1, - 0, 0, 0, 0, - 0, 0, + statute.VersionSocks5, statute.RepRuleFailure, 0, + statute.ATYPIPv4, 0, 0, 0, 0, 0, 0, } require.Equal(t, expected, out) } diff --git a/server.go b/server.go index 19233ff..5ead3a1 100644 --- a/server.go +++ b/server.go @@ -3,6 +3,7 @@ package socks5 import ( "bufio" "context" + "errors" "fmt" "io" "io/ioutil" @@ -116,6 +117,7 @@ func (s *Server) ServeConn(conn net.Conn) error { var authContext *AuthContext defer conn.Close() + bufConn := bufio.NewReader(conn) mr, err := statute.ParseMethodRequest(bufConn) @@ -133,24 +135,30 @@ func (s *Server) ServeConn(conn net.Conn) error { } // The client request detail - request, err := NewRequest(bufConn) + request, err := ParseRequest(bufConn) if err != nil { - if err == statute.ErrUnrecognizedAddrType { - if err := SendReply(conn, statute.Request{Version: mr.Ver}, statute.RepAddrTypeNotSupported); err != nil { + if errors.Is(err, statute.ErrUnrecognizedAddrType) { + if err := SendReply(conn, statute.RepAddrTypeNotSupported, nil); err != nil { return fmt.Errorf("failed to send reply %w", err) } } return fmt.Errorf("failed to read destination address, %w", err) } + if request.Request.Command != statute.CommandConnect && + request.Request.Command != statute.CommandBind && + request.Request.Command != statute.CommandAssociate { + if err := SendReply(conn, statute.RepCommandNotSupported, nil); err != nil { + return fmt.Errorf("failed to send reply, %v", err) + } + return fmt.Errorf("unrecognized command[%d]", request.Request.Command) + } + request.AuthContext = authContext request.LocalAddr = conn.LocalAddr() request.RemoteAddr = conn.RemoteAddr() // Process the client request - if err := s.handleRequest(conn, request); err != nil { - return fmt.Errorf("failed to handle request, %v", err) - } - return nil + return s.handleRequest(conn, request) } // authenticate is used to handle connection authentication diff --git a/server_test.go b/server_test.go index 13f2634..cfab349 100644 --- a/server_test.go +++ b/server_test.go @@ -31,11 +31,12 @@ func TestSOCKS5_Connect(t *testing.T) { _, err = io.ReadAtLeast(conn, buf, 4) require.NoError(t, err) assert.Equal(t, []byte("ping"), buf) - _, _ = conn.Write([]byte("pong")) + + conn.Write([]byte("pong")) // nolint: errcheck }() lAddr := l.Addr().(*net.TCPAddr) - // Create a socks server + // Create a socks server with UserPass auth. cator := UserPassAuthenticator{StaticCredentials{"foo": "bar"}} srv := NewServer( WithAuthMethods([]Authenticator{cator}), @@ -54,18 +55,20 @@ func TestSOCKS5_Connect(t *testing.T) { require.NoError(t, err) // Connect, auth and connec to local - req := new(bytes.Buffer) - req.Write([]byte{statute.VersionSocks5, 2, statute.MethodNoAuth, statute.MethodUserPassAuth}) - req.Write([]byte{statute.UserPassAuthVersion, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'}) + req := bytes.NewBuffer( + []byte{ + statute.VersionSocks5, 2, statute.MethodNoAuth, statute.MethodUserPassAuth, // methods + statute.UserPassAuthVersion, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r', // userpass auth + }) reqHead := statute.Request{ Version: statute.VersionSocks5, Command: statute.CommandConnect, Reserved: 0, DstAddress: statute.AddrSpec{ - "", - net.ParseIP("127.0.0.1"), - lAddr.Port, - statute.ATYPIPv4, + FQDN: "", + IP: net.ParseIP("127.0.0.1"), + Port: lAddr.Port, + AddrType: statute.ATYPIPv4, }, } req.Write(reqHead.Bytes()) @@ -73,11 +76,11 @@ func TestSOCKS5_Connect(t *testing.T) { req.Write([]byte("ping")) // Send all the bytes - conn.Write(req.Bytes()) + conn.Write(req.Bytes()) // nolint: errcheck // Verify response expected := []byte{ - statute.VersionSocks5, statute.MethodUserPassAuth, // use user password auth + statute.VersionSocks5, statute.MethodUserPassAuth, // response use UserPass auth statute.UserPassAuthVersion, statute.AuthSuccess, // response auth success } rspHead := statute.Request{ @@ -85,22 +88,19 @@ func TestSOCKS5_Connect(t *testing.T) { Command: statute.RepSuccess, Reserved: 0, DstAddress: statute.AddrSpec{ - "", - net.ParseIP("127.0.0.1"), - 0, // Ignore the port - statute.ATYPIPv4, + FQDN: "", + IP: net.ParseIP("127.0.0.1"), + Port: 0, + AddrType: statute.ATYPIPv4, }, } expected = append(expected, rspHead.Bytes()...) expected = append(expected, []byte("pong")...) out := make([]byte, len(expected)) - _ = conn.SetDeadline(time.Now().Add(time.Second)) + conn.SetDeadline(time.Now().Add(time.Second)) // nolint: errcheck _, err = io.ReadFull(conn, out) require.NoError(t, err) - - t.Logf("proxy bind port: %d", statute.BuildPort(out[12], out[13])) - // Ignore the port out[12] = 0 out[13] = 0 @@ -157,10 +157,10 @@ func TestSOCKS5_Associate(t *testing.T) { Command: statute.CommandAssociate, Reserved: 0, DstAddress: statute.AddrSpec{ - "", - locIP, - lAddr.Port, - statute.ATYPIPv4, + FQDN: "", + IP: locIP, + Port: lAddr.Port, + AddrType: statute.ATYPIPv4, }, } req.Write(reqHead.Bytes()) @@ -216,11 +216,11 @@ func Test_SocksWithProxy(t *testing.T) { require.NoError(t, err) require.Equal(t, []byte("ping"), buf) - conn.Write([]byte("pong")) + conn.Write([]byte("pong")) // nolint: errcheck }() lAddr := l.Addr().(*net.TCPAddr) - // Create a socks server + // Create a socks server with UserPass auth. cator := UserPassAuthenticator{StaticCredentials{"foo": "bar"}} serv := NewServer( WithAuthMethods([]Authenticator{cator}), @@ -245,14 +245,13 @@ func Test_SocksWithProxy(t *testing.T) { conn.Write([]byte("ping")) // nolint: errcheck out := make([]byte, 4) - _ = conn.SetDeadline(time.Now().Add(time.Second)) // nolint: errcheck + conn.SetDeadline(time.Now().Add(time.Second)) // nolint: errcheck _, err = io.ReadFull(conn, out) require.NoError(t, err) - require.Equal(t, []byte("pong"), out) } -/***************************** auth *******************************/ +/***************************** auth *******************************/ func TestNoAuth_Server(t *testing.T) { req := bytes.NewBuffer(nil) diff --git a/statute/datagram.go b/statute/datagram.go index d77a27d..abb515d 100644 --- a/statute/datagram.go +++ b/statute/datagram.go @@ -4,119 +4,96 @@ import ( "errors" "math" "net" - "strconv" ) -/* - The SOCKS UDP request/response is formed as follows: - +-----+------+-------+----------+----------+----------+ - | RSV | FRAG | ATYP | DST.ADDR | DST.PORT | DATA | - +-----+------+-------+----------+----------+----------+ - | 2 | 1 | X'00' | Variable | 2 | Variable | - +-----+------+-------+----------+----------+----------+ -*/ -// Packet udp packet -type Packet struct { +// The SOCKS UDP request/response is formed as follows: +// +-----+------+-------+----------+----------+----------+ +// | RSV | FRAG | ATYP | DST.ADDR | DST.PORT | DATA | +// +-----+------+-------+----------+----------+----------+ +// | 2 | 1 | X'00' | Variable | 2 | Variable | +// +-----+------+-------+----------+----------+----------+ +// Datagram udp packet +type Datagram struct { RSV uint16 Frag uint8 - ATYP uint8 DstAddr AddrSpec Data []byte } -// NewEmptyPacket new empty packet -func NewEmptyPacket() Packet { - return Packet{} -} - -// NewPacket new packet with dest addr and data -func NewPacket(destAddr string, data []byte) (p Packet, err error) { - var host, port string - - host, port, err = net.SplitHostPort(destAddr) +// NewDatagram new packet with dest addr and data +func NewDatagram(destAddr string, data []byte) (p Datagram, err error) { + p.DstAddr, err = ParseAddrSpec(destAddr) if err != nil { return } - p.DstAddr.Port, err = strconv.Atoi(port) - if err != nil { + if p.DstAddr.AddrType == ATYPDomain && len(p.DstAddr.FQDN) > math.MaxUint8 { + err = errors.New("destination host name too long") return } p.RSV = 0 p.Frag = 0 - if ip := net.ParseIP(host); ip != nil { - if ip4 := ip.To4(); ip4 != nil { - p.ATYP = ATYPIPv4 - p.DstAddr.IP = ip - } else { - p.ATYP = ATYPIPv6 - p.DstAddr.IP = ip - } - } else { - if len(host) > math.MaxUint8 { - err = errors.New("destination host name too long") - return - } - p.ATYP = ATYPDomain - p.DstAddr.FQDN = host - } p.Data = data return } // ParseRequest parse to packet -func (sf *Packet) Parse(b []byte) error { - if len(b) <= 4+net.IPv4len+2 { // no data - return errors.New("too short") +func ParseDatagram(b []byte) (da Datagram, err error) { + if len(b) < 4+net.IPv4len+2 { // no data + err = errors.New("datagram to short") + return } // ignore RSV - sf.RSV = 0 + da.RSV = 0 // FRAG - sf.Frag = b[2] - sf.ATYP = b[3] + da.Frag = b[2] + da.DstAddr.AddrType = b[3] headLen := 4 - switch sf.ATYP { + switch da.DstAddr.AddrType { case ATYPIPv4: headLen += net.IPv4len + 2 - sf.DstAddr.IP = net.IPv4(b[4], b[5], b[6], b[7]) - sf.DstAddr.Port = BuildPort(b[4+net.IPv4len], b[4+net.IPv4len+1]) + da.DstAddr.IP = net.IPv4(b[4], b[5], b[6], b[7]) + da.DstAddr.Port = BuildPort(b[4+net.IPv4len], b[4+net.IPv4len+1]) case ATYPIPv6: headLen += net.IPv6len + 2 if len(b) <= headLen { - return errors.New("too short") + err = errors.New("datagram to short") + return } - sf.DstAddr.IP = net.IP{b[4], b[5], b[6], b[7], b[8], b[9], b[10], b[11], b[12], b[13], b[14], b[15], b[16], b[17], b[18], b[19]} - sf.DstAddr.Port = BuildPort(b[4+net.IPv6len], b[4+net.IPv6len+1]) + da.DstAddr.IP = net.IP{b[4], b[5], b[6], b[7], b[8], b[9], b[10], b[11], b[12], b[13], b[14], b[15], b[16], b[17], b[18], b[19]} + da.DstAddr.Port = BuildPort(b[4+net.IPv6len], b[4+net.IPv6len+1]) case ATYPDomain: addrLen := int(b[4]) headLen += 1 + addrLen + 2 if len(b) <= headLen { - return errors.New("too short") + err = errors.New("datagram to short") + return } str := make([]byte, addrLen) copy(str, b[5:5+addrLen]) - sf.DstAddr.FQDN = string(str) - sf.DstAddr.Port = BuildPort(b[5+addrLen], b[5+addrLen+1]) + da.DstAddr.FQDN = string(str) + da.DstAddr.Port = BuildPort(b[5+addrLen], b[5+addrLen+1]) default: - return ErrUnrecognizedAddrType + err = ErrUnrecognizedAddrType + return } - sf.Data = b[headLen:] - return nil + da.Data = b[headLen:] + return } -// Request returns s slice of packet reply -func (sf *Packet) Header() []byte { +// Request returns s slice of datagram header except data +func (sf *Datagram) Header() []byte { bs := make([]byte, 0, 32) bs = append(bs, []byte{byte(sf.RSV << 8), byte(sf.RSV), sf.Frag}...) - switch sf.ATYP { + switch sf.DstAddr.AddrType { case ATYPIPv4: bs = append(bs, ATYPIPv4) - bs = append(bs, sf.DstAddr.IP...) + bs = append(bs, sf.DstAddr.IP.To4()...) case ATYPIPv6: bs = append(bs, ATYPIPv6) - bs = append(bs, sf.DstAddr.IP...) + bs = append(bs, sf.DstAddr.IP.To16()...) case ATYPDomain: - bs = append(bs, ATYPDomain) + bs = append(bs, ATYPDomain, byte(len(sf.DstAddr.FQDN))) bs = append(bs, []byte(sf.DstAddr.FQDN)...) } hi, lo := BreakPort(sf.DstAddr.Port) @@ -124,6 +101,6 @@ func (sf *Packet) Header() []byte { return bs } -func (sf *Packet) Bytes() []byte { +func (sf *Datagram) Bytes() []byte { return append(sf.Header(), sf.Data...) } diff --git a/statute/datagram_test.go b/statute/datagram_test.go new file mode 100644 index 0000000..a849d63 --- /dev/null +++ b/statute/datagram_test.go @@ -0,0 +1,145 @@ +package statute + +import ( + "net" + "reflect" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDatagram(t *testing.T) { + _, err := NewDatagram("localhost", nil) + require.Error(t, err) + + _, err = NewDatagram("localhostlocalhostlocalhostlocalhostlocalhostlocalhostlocalhostlocalhostlocalhostlocalhost"+ + "localhostlocalhostlocalhostlocalhostlocalhostlocalhostlocalhostlocalhostlocalhostlocalhost"+ + "localhostlocalhostlocalhostlocalhostlocalhostlocalhostlocalhostlocalhostlocalhostlocalhost:8080", nil) + require.Error(t, err) + + datagram, err := NewDatagram("localhost:8080", []byte{1, 2, 3}) + require.NoError(t, err) + require.Equal(t, Datagram{ + 0, 0, AddrSpec{ + FQDN: "localhost", + Port: 8080, + AddrType: ATYPDomain, + }, + []byte{1, 2, 3}, + }, datagram) + require.Equal(t, []byte{0, 0, 0, ATYPDomain, 9, 'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't', 0x1f, 0x90}, datagram.Header()) + require.Equal(t, []byte{0, 0, 0, ATYPDomain, 9, 'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't', 0x1f, 0x90, 1, 2, 3}, datagram.Bytes()) + + datagram, err = NewDatagram("127.0.0.1:8080", []byte{1, 2, 3}) + require.NoError(t, err) + require.Equal(t, Datagram{ + 0, 0, AddrSpec{ + IP: net.IPv4(127, 0, 0, 1), + Port: 8080, + AddrType: ATYPIPv4, + }, + []byte{1, 2, 3}, + }, datagram) + require.Equal(t, []byte{0, 0, 0, ATYPIPv4, 127, 0, 0, 1, 0x1f, 0x90}, datagram.Header()) + require.Equal(t, []byte{0, 0, 0, ATYPIPv4, 127, 0, 0, 1, 0x1f, 0x90, 1, 2, 3}, datagram.Bytes()) + datagram, err = NewDatagram("[::1]:8080", []byte{1, 2, 3}) + require.NoError(t, err) + require.Equal(t, Datagram{ + 0, 0, AddrSpec{ + IP: net.IPv6loopback, + Port: 8080, + AddrType: ATYPIPv6, + }, + []byte{1, 2, 3}, + }, datagram) + require.Equal(t, []byte{0, 0, 0, ATYPIPv6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0x1f, 0x90}, datagram.Header()) + require.Equal(t, []byte{0, 0, 0, ATYPIPv6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0x1f, 0x90, 1, 2, 3}, datagram.Bytes()) +} + +func TestParseDatagram(t *testing.T) { + type args struct { + b []byte + } + tests := []struct { + name string + args args + wantDa Datagram + wantErr bool + }{ + { + "IPv4", + args{[]byte{0, 0, 0, ATYPIPv4, 127, 0, 0, 1, 0x1f, 0x90, 1, 2, 3}}, + Datagram{ + 0, 0, AddrSpec{ + IP: net.IPv4(127, 0, 0, 1), + Port: 8080, + AddrType: ATYPIPv4, + }, + []byte{1, 2, 3}, + }, + false, + }, + { + "IPv6", + args{[]byte{0, 0, 0, ATYPIPv6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0x1f, 0x90, 1, 2, 3}}, + Datagram{ + 0, 0, AddrSpec{ + IP: net.IPv6loopback, + Port: 8080, + AddrType: ATYPIPv6, + }, + []byte{1, 2, 3}, + }, + false, + }, + { + "FQDN", + args{[]byte{0, 0, 0, ATYPDomain, 9, 'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't', 0x1f, 0x90, 1, 2, 3}}, + Datagram{ + 0, 0, AddrSpec{ + FQDN: "localhost", + Port: 8080, + AddrType: ATYPDomain, + }, + []byte{1, 2, 3}, + }, + false, + }, + { + "invalid address type", + args{[]byte{0, 0, 0, 0x02, 127, 0, 0, 1, 0x1f, 0x90}}, + Datagram{}, + true, + }, + { + "less min length", + args{[]byte{0, 0, 0, ATYPIPv4, 127, 0, 0, 1, 0x1f}}, + Datagram{}, + true, + }, + { + "less domain length", + args{[]byte{0, 0, 0, ATYPDomain, 10, 127, 0, 0, 1, 0x1f, 0x09}}, + Datagram{}, + true, + }, + { + "less ipv6 length", + args{[]byte{0, 0, 0, ATYPIPv6, 127, 0, 0, 1, 0x1f, 0x09}}, + Datagram{}, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotDa, err := ParseDatagram(tt.args.b) + if (err != nil) != tt.wantErr { + t.Errorf("ParseDatagram() error = %v, wantErr %v", err, tt.wantErr) + return + } + if err == nil && !reflect.DeepEqual(gotDa, tt.wantDa) { + t.Errorf("ParseDatagram() gotDa = %v, want %v", gotDa, tt.wantDa) + } + }) + } +}