add udp packet encode and decode

This commit is contained in:
mo 2020-04-23 23:17:54 +08:00
parent b16d6bc8f5
commit f2844250ff
2 changed files with 118 additions and 50 deletions

96
packet.go Normal file

@ -0,0 +1,96 @@
package socks5
import (
"errors"
"math"
"net"
"strconv"
)
/*
The SOCKS UDP request/response is formed as follows:
+-----+------+-------+----------+----------+----------+
| RSV | FRAG | ATYP | DST.ADDR | DST.PORT | DATA |
+-----+------+-------+----------+----------+----------+
| 2 | 1 | X'00' | Variable | 2 | Variable |
+-----+------+-------+----------+----------+----------+
*/
// Packet udp packet
type Packet struct {
RSV uint16
Frag uint8
ATYP uint8
DstAddr AddrSpec
Data []byte
}
func NewEmptyPacket() Packet {
return Packet{}
}
func NewPacket(destAddr string, data []byte) (p Packet, err error) {
var host, port string
host, port, err = net.SplitHostPort(destAddr)
if err != nil {
return
}
p.DstAddr.Port, err = strconv.Atoi(port)
if err != nil {
return
}
p.RSV = 0
p.Frag = 0
if ip := net.ParseIP(host); ip != nil {
if ip4 := ip.To4(); ip4 != nil {
p.ATYP = ATYPIPv4
p.DstAddr.IP = ip4
} else {
p.ATYP = ATYPIPV6
p.DstAddr.IP = ip.To16()
}
} else {
if len(host) > math.MaxUint8 {
err = errors.New("destination host name too long")
return
}
p.ATYP = ATYPDomain
p.DstAddr.FQDN = host
}
return
}
func (sf *Packet) Parses(b []byte) error {
if len(b) <= 4+net.IPv4len+2 { // no data
return errors.New("too short")
}
// ignore RSV
sf.RSV = 0
// FRAG
sf.Frag = b[2]
sf.ATYP = b[3]
switch sf.ATYP {
case ATYPIPv4:
sf.DstAddr.IP = net.IPv4(b[4], b[5], b[6], b[7])
sf.DstAddr.Port = buildPort(b[4+net.IPv4len], b[4+net.IPv4len+1])
case ATYPIPV6:
if len(b) <= (4 + net.IPv6len + 2) {
return errors.New("too short")
}
sf.DstAddr.IP = net.IP{b[4], b[5], b[6], b[7], b[8], b[9], b[10], b[11], b[12], b[13], b[14], b[15], b[16], b[17], b[18], b[19]}
sf.DstAddr.Port = buildPort(b[4+net.IPv6len], b[4+net.IPv6len+1])
case ATYPDomain:
addrLen := int(b[4])
if len(b) <= (4 + 1 + addrLen + 2) {
return errors.New("too short")
}
str := make([]byte, addrLen)
copy(str, b[5:5+addrLen])
sf.DstAddr.FQDN = string(str)
sf.DstAddr.Port = buildPort(b[5+addrLen], b[5+addrLen+1])
default:
return errUnrecognizedAddrType
}
return nil
}

@ -252,45 +252,12 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
return
}
if n <= 4+net.IPv4len+2 { // no data
pk := NewEmptyPacket()
if err := pk.Parses(buf[:n]); err != nil {
continue
}
// ignore RSV,FRAG
addrType := buf[3]
addrLen := 0
headLen := 0
var addrSpc AddrSpec
if addrType == ATYPIPv4 {
headLen = 4 + net.IPv4len + 2
addrLen = net.IPv4len
addrSpc.IP = make(net.IP, net.IPv4len)
copy(addrSpc.IP, buf[4:4+net.IPv4len])
addrSpc.Port = buildPort(buf[4+net.IPv4len], buf[4+net.IPv4len+1])
} else if addrType == ATYPIPV6 {
headLen = 4 + net.IPv6len + 2
if n <= headLen {
continue
}
addrLen = net.IPv6len
addrSpc.IP = make(net.IP, net.IPv6len)
copy(addrSpc.IP, buf[4:4+net.IPv6len])
addrSpc.Port = buildPort(buf[4+net.IPv6len], buf[4+net.IPv6len+1])
} else if addrType == ATYPDomain {
addrLen = int(buf[4])
headLen = 4 + 1 + addrLen + 2
if n <= headLen {
continue
}
str := make([]byte, addrLen)
copy(str, buf[5:5+addrLen])
addrSpc.FQDN = string(str)
addrSpc.Port = buildPort(buf[5+addrLen], buf[5+addrLen+1])
} else {
continue
}
// 把消息写给remote sever
if _, err := targetUDP.Write(buf[headLen:n]); err != nil {
if _, err := targetUDP.Write(pk.Data); err != nil {
s.logger.Errorf("write data to remote %s failed, %v", targetUDP.RemoteAddr(), err)
return
}
@ -313,22 +280,27 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
return
}
tmpBufPool := s.bufferPool.Get()
proBuf := tmpBufPool
rAddr, _ := net.ResolveUDPAddr("udp", remote.String())
hi, lo := breakPort(rAddr.Port)
if rAddr.IP.To4() != nil {
proBuf = append(proBuf, []byte{0, 0, 0, ATYPIPv4}...)
proBuf = append(proBuf, rAddr.IP.To4()...)
proBuf = append(proBuf, hi, lo)
} else if rAddr.IP.To16() != nil {
proBuf = append(proBuf, []byte{0, 0, 0, ATYPIPV6}...)
proBuf = append(proBuf, rAddr.IP.To16()...)
proBuf = append(proBuf, hi, lo)
} else { // should never happen
pkb, err := NewPacket(remote.String(), buf[:n])
if err != nil {
continue
}
proBuf = append(proBuf, buf[:n]...)
tmpBufPool := s.bufferPool.Get()
proBuf := tmpBufPool
proBuf = append(proBuf, []byte{byte(pkb.RSV << 8), byte(pkb.RSV), pkb.Frag}...)
hi, lo := breakPort(pkb.DstAddr.Port)
switch pkb.ATYP {
case ATYPIPv4:
proBuf = append(proBuf, ATYPIPv4)
proBuf = append(proBuf, pkb.DstAddr.IP...)
case ATYPIPV6:
proBuf = append(proBuf, ATYPIPV6)
proBuf = append(proBuf, pkb.DstAddr.IP...)
case ATYPDomain:
proBuf = append(proBuf, ATYPDomain)
proBuf = append(proBuf, []byte(pkb.DstAddr.FQDN)...)
}
proBuf = append(proBuf, hi, lo)
proBuf = append(proBuf, pkb.Data...)
if _, err := bindLn.WriteTo(proBuf, srcAddr); err != nil {
s.bufferPool.Put(tmpBufPool)
s.logger.Errorf("write data to client %s failed, %v", bindLn.LocalAddr(), err)