diff --git a/packet.go b/packet.go index db70f51..f2f3dde 100644 --- a/packet.go +++ b/packet.go @@ -44,10 +44,10 @@ func NewPacket(destAddr string, data []byte) (p Packet, err error) { if ip := net.ParseIP(host); ip != nil { if ip4 := ip.To4(); ip4 != nil { p.ATYP = ATYPIPv4 - p.DstAddr.IP = ip4 + p.DstAddr.IP = ip } else { p.ATYP = ATYPIPV6 - p.DstAddr.IP = ip.To16() + p.DstAddr.IP = ip } } else { if len(host) > math.MaxUint8 { @@ -57,6 +57,7 @@ func NewPacket(destAddr string, data []byte) (p Packet, err error) { p.ATYP = ATYPDomain p.DstAddr.FQDN = host } + p.Data = data return } @@ -69,12 +70,15 @@ func (sf *Packet) Parses(b []byte) error { // FRAG sf.Frag = b[2] sf.ATYP = b[3] + headLen := 4 switch sf.ATYP { case ATYPIPv4: + headLen += net.IPv4len + 2 sf.DstAddr.IP = net.IPv4(b[4], b[5], b[6], b[7]) sf.DstAddr.Port = buildPort(b[4+net.IPv4len], b[4+net.IPv4len+1]) case ATYPIPV6: - if len(b) <= (4 + net.IPv6len + 2) { + headLen += net.IPv6len + 2 + if len(b) <= headLen { return errors.New("too short") } @@ -82,7 +86,8 @@ func (sf *Packet) Parses(b []byte) error { sf.DstAddr.Port = buildPort(b[4+net.IPv6len], b[4+net.IPv6len+1]) case ATYPDomain: addrLen := int(b[4]) - if len(b) <= (4 + 1 + addrLen + 2) { + headLen += 1 + addrLen + 2 + if len(b) <= headLen { return errors.New("too short") } str := make([]byte, addrLen) @@ -92,5 +97,25 @@ func (sf *Packet) Parses(b []byte) error { default: return errUnrecognizedAddrType } + sf.Data = b[headLen:] return nil } + +func (sf *Packet) Header() []byte { + bs := make([]byte, 0, 32) + bs = append(bs, []byte{byte(sf.RSV << 8), byte(sf.RSV), sf.Frag}...) + switch sf.ATYP { + case ATYPIPv4: + bs = append(bs, ATYPIPv4) + bs = append(bs, sf.DstAddr.IP...) + case ATYPIPV6: + bs = append(bs, ATYPIPV6) + bs = append(bs, sf.DstAddr.IP...) + case ATYPDomain: + bs = append(bs, ATYPDomain) + bs = append(bs, []byte(sf.DstAddr.FQDN)...) + } + hi, lo := breakPort(sf.DstAddr.Port) + bs = append(bs, hi, lo) + return bs +} diff --git a/request.go b/request.go index 48069ea..aeb1aca 100644 --- a/request.go +++ b/request.go @@ -286,20 +286,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req } tmpBufPool := s.bufferPool.Get() proBuf := tmpBufPool - proBuf = append(proBuf, []byte{byte(pkb.RSV << 8), byte(pkb.RSV), pkb.Frag}...) - hi, lo := breakPort(pkb.DstAddr.Port) - switch pkb.ATYP { - case ATYPIPv4: - proBuf = append(proBuf, ATYPIPv4) - proBuf = append(proBuf, pkb.DstAddr.IP...) - case ATYPIPV6: - proBuf = append(proBuf, ATYPIPV6) - proBuf = append(proBuf, pkb.DstAddr.IP...) - case ATYPDomain: - proBuf = append(proBuf, ATYPDomain) - proBuf = append(proBuf, []byte(pkb.DstAddr.FQDN)...) - } - proBuf = append(proBuf, hi, lo) + proBuf = append(proBuf, pkb.Header()...) proBuf = append(proBuf, pkb.Data...) if _, err := bindLn.WriteTo(proBuf, srcAddr); err != nil { s.bufferPool.Put(tmpBufPool) diff --git a/socks5_test.go b/socks5_test.go index 9d62dc9..1d44de3 100644 --- a/socks5_test.go +++ b/socks5_test.go @@ -137,6 +137,7 @@ func TestSOCKS5_Associate(t *testing.T) { if err != nil { return } + if !bytes.Equal(buf[:n], []byte("ping")) { t.Fatalf("bad: %v", buf) }