diff --git a/ccsocks5/client.go b/ccsocks5/client.go index 19c2d1c..330e6af 100644 --- a/ccsocks5/client.go +++ b/ccsocks5/client.go @@ -171,9 +171,9 @@ func (sf *Client) handshake(command byte, addr string) (string, error) { return "", err } reqHead := statute.Request{ - Version: statute.VersionSocks5, - Command: command, - DstAddress: a, + Version: statute.VersionSocks5, + Command: command, + DstAddr: a, } if _, err := sf.proxyConn.Write(reqHead.Bytes()); err != nil { return "", err @@ -186,7 +186,7 @@ func (sf *Client) handshake(command byte, addr string) (string, error) { if rspHead.Response != statute.RepSuccess { return "", errors.New("host unreachable") } - return rspHead.BndAddress.String(), nil + return rspHead.BndAddr.String(), nil } // SetKeepAlive sets whether the operating system should send diff --git a/handle.go b/handle.go index 3f64e12..efbfc94 100644 --- a/handle.go +++ b/handle.go @@ -41,7 +41,7 @@ func ParseRequest(bufConn io.Reader) (*Request, error) { } return &Request{ Request: hd, - RawDestAddr: &hd.DstAddress, + RawDestAddr: &hd.DstAddr, Reader: bufConn, }, nil } @@ -294,7 +294,7 @@ func SendReply(w io.Writer, rep uint8, bindAddr net.Addr) error { rsp := statute.Reply{ Version: statute.VersionSocks5, Response: rep, - BndAddress: statute.AddrSpec{ + BndAddr: statute.AddrSpec{ AddrType: statute.ATYPIPv4, IP: net.IPv4zero, Port: 0, @@ -303,19 +303,19 @@ func SendReply(w io.Writer, rep uint8, bindAddr net.Addr) error { if rsp.Response == statute.RepSuccess { if tcpAddr, ok := bindAddr.(*net.TCPAddr); ok && tcpAddr != nil { - rsp.BndAddress.IP = tcpAddr.IP - rsp.BndAddress.Port = tcpAddr.Port + rsp.BndAddr.IP = tcpAddr.IP + rsp.BndAddr.Port = tcpAddr.Port } else if udpAddr, ok := bindAddr.(*net.UDPAddr); ok && udpAddr != nil { - rsp.BndAddress.IP = udpAddr.IP - rsp.BndAddress.Port = udpAddr.Port + rsp.BndAddr.IP = udpAddr.IP + rsp.BndAddr.Port = udpAddr.Port } else { rsp.Response = statute.RepAddrTypeNotSupported } - if rsp.BndAddress.IP.To4() != nil { - rsp.BndAddress.AddrType = statute.ATYPIPv4 - } else if rsp.BndAddress.IP.To16() != nil { - rsp.BndAddress.AddrType = statute.ATYPIPv6 + if rsp.BndAddr.IP.To4() != nil { + rsp.BndAddr.AddrType = statute.ATYPIPv4 + } else if rsp.BndAddr.IP.To16() != nil { + rsp.BndAddr.AddrType = statute.ATYPIPv6 } } // Send the message diff --git a/server_test.go b/server_test.go index 6934fb9..7022c2e 100644 --- a/server_test.go +++ b/server_test.go @@ -64,7 +64,7 @@ func TestSOCKS5_Connect(t *testing.T) { Version: statute.VersionSocks5, Command: statute.CommandConnect, Reserved: 0, - DstAddress: statute.AddrSpec{ + DstAddr: statute.AddrSpec{ FQDN: "", IP: net.ParseIP("127.0.0.1"), Port: lAddr.Port, @@ -87,7 +87,7 @@ func TestSOCKS5_Connect(t *testing.T) { Version: statute.VersionSocks5, Command: statute.RepSuccess, Reserved: 0, - DstAddress: statute.AddrSpec{ + DstAddr: statute.AddrSpec{ FQDN: "", IP: net.ParseIP("127.0.0.1"), Port: 0, @@ -156,7 +156,7 @@ func TestSOCKS5_Associate(t *testing.T) { Version: statute.VersionSocks5, Command: statute.CommandAssociate, Reserved: 0, - DstAddress: statute.AddrSpec{ + DstAddr: statute.AddrSpec{ FQDN: "", IP: locIP, Port: lAddr.Port, @@ -185,11 +185,11 @@ func TestSOCKS5_Associate(t *testing.T) { require.Equal(t, statute.VersionSocks5, rspHead.Version) require.Equal(t, statute.RepSuccess, rspHead.Response) - // t.Logf("proxy bind listen port: %d", rspHead.BndAddress.Port) + // t.Logf("proxy bind listen port: %d", rspHead.BndAddr.Port) udpConn, err := net.DialUDP("udp", nil, &net.UDPAddr{ IP: locIP, - Port: rspHead.BndAddress.Port, + Port: rspHead.BndAddr.Port, }) require.NoError(t, err) // Send a ping diff --git a/statute/util.go b/statute/addr.go similarity index 68% rename from statute/util.go rename to statute/addr.go index 2d18bc4..a28bd61 100644 --- a/statute/util.go +++ b/statute/addr.go @@ -18,22 +18,22 @@ type AddrSpec struct { // String returns a string suitable to dial; prefer returning IP-based // address, fallback to FQDN -func (a *AddrSpec) String() string { - if 0 != len(a.IP) { - return net.JoinHostPort(a.IP.String(), strconv.Itoa(a.Port)) +func (sf *AddrSpec) String() string { + if 0 != len(sf.IP) { + return net.JoinHostPort(sf.IP.String(), strconv.Itoa(sf.Port)) } - return net.JoinHostPort(a.FQDN, strconv.Itoa(a.Port)) + return net.JoinHostPort(sf.FQDN, strconv.Itoa(sf.Port)) } // Address returns a string which may be specified // if IPv4/IPv6 will return < ip:port > // if FQDN will return < domain ip:port > // Note: do not used to dial, Please use String -func (a AddrSpec) Address() string { - if a.FQDN != "" { - return fmt.Sprintf("%s (%s):%d", a.FQDN, a.IP, a.Port) +func (sf AddrSpec) Address() string { + if sf.FQDN != "" { + return fmt.Sprintf("%s (%s):%d", sf.FQDN, sf.IP, sf.Port) } - return fmt.Sprintf("%s:%d", a.IP, a.Port) + return fmt.Sprintf("%s:%d", sf.IP, sf.Port) } // ParseAddrSpec parse address to the AddrSpec address @@ -55,6 +55,3 @@ func ParseAddrSpec(address string) (as AddrSpec, err error) { as.Port, err = strconv.Atoi(port) return } - -func buildPort(hi, lo byte) int { return (int(hi) << 8) | int(lo) } -func breakPort(port int) (hi, lo byte) { return byte(port >> 8), byte(port) } diff --git a/statute/util_test.go b/statute/addr_test.go similarity index 100% rename from statute/util_test.go rename to statute/addr_test.go diff --git a/statute/datagram.go b/statute/datagram.go index af7e671..ed697cc 100644 --- a/statute/datagram.go +++ b/statute/datagram.go @@ -1,6 +1,7 @@ package statute import ( + "encoding/binary" "errors" "math" "net" @@ -49,7 +50,7 @@ func ParseDatagram(b []byte) (da Datagram, err error) { case ATYPIPv4: headLen += net.IPv4len + 2 da.DstAddr.IP = net.IPv4(b[4], b[5], b[6], b[7]) - da.DstAddr.Port = buildPort(b[4+net.IPv4len], b[4+net.IPv4len+1]) + da.DstAddr.Port = int(binary.BigEndian.Uint16((b[4+net.IPv4len:]))) case ATYPIPv6: headLen += net.IPv6len + 2 if len(b) <= headLen { @@ -58,7 +59,7 @@ func ParseDatagram(b []byte) (da Datagram, err error) { } da.DstAddr.IP = b[4 : 4+net.IPv6len] - da.DstAddr.Port = buildPort(b[4+net.IPv6len], b[4+net.IPv6len+1]) + da.DstAddr.Port = int(binary.BigEndian.Uint16(b[4+net.IPv6len:])) case ATYPDomain: addrLen := int(b[4]) headLen += 1 + addrLen + 2 @@ -69,7 +70,7 @@ func ParseDatagram(b []byte) (da Datagram, err error) { str := make([]byte, addrLen) copy(str, b[5:5+addrLen]) da.DstAddr.FQDN = string(str) - da.DstAddr.Port = buildPort(b[5+addrLen], b[5+addrLen+1]) + da.DstAddr.Port = int(binary.BigEndian.Uint16(b[5+addrLen:])) default: err = ErrUnrecognizedAddrType return @@ -114,8 +115,7 @@ func (sf *Datagram) values(hasData bool) (bs []byte) { bs = append(bs, byte(len(sf.DstAddr.FQDN))) } bs = append(bs, addr...) - hi, lo := breakPort(sf.DstAddr.Port) - bs = append(bs, hi, lo) + bs = append(bs, byte(sf.DstAddr.Port>>8), byte(sf.DstAddr.Port)) if hasData { bs = append(bs, sf.Data...) } diff --git a/statute/message.go b/statute/message.go index 986a0ef..1649e42 100644 --- a/statute/message.go +++ b/statute/message.go @@ -1,6 +1,7 @@ package statute import ( + "encoding/binary" "fmt" "io" "net" @@ -20,8 +21,8 @@ type Request struct { Command byte // Reserved byte Reserved byte - // DstAddress in socks message - DstAddress AddrSpec + // DstAddr in socks message + DstAddr AddrSpec } // ParseRequest to request from io.Reader @@ -40,23 +41,23 @@ func ParseRequest(r io.Reader) (req Request, err error) { if _, err = io.ReadFull(r, tmp); err != nil { return req, fmt.Errorf("failed to get request RSV and address type, %v", err) } - req.Reserved, req.DstAddress.AddrType = tmp[0], tmp[1] + req.Reserved, req.DstAddr.AddrType = tmp[0], tmp[1] - switch req.DstAddress.AddrType { + switch req.DstAddr.AddrType { case ATYPIPv4: addr := make([]byte, net.IPv4len+2) if _, err = io.ReadFull(r, addr); err != nil { return req, fmt.Errorf("failed to get request, %v", err) } - req.DstAddress.IP = net.IPv4(addr[0], addr[1], addr[2], addr[3]) - req.DstAddress.Port = buildPort(addr[net.IPv4len], addr[net.IPv4len+1]) + req.DstAddr.IP = net.IPv4(addr[0], addr[1], addr[2], addr[3]) + req.DstAddr.Port = int(binary.BigEndian.Uint16(addr[net.IPv4len:])) case ATYPIPv6: addr := make([]byte, net.IPv6len+2) if _, err = io.ReadFull(r, addr); err != nil { return req, fmt.Errorf("failed to get request, %v", err) } - req.DstAddress.IP = addr[:net.IPv6len] - req.DstAddress.Port = buildPort(addr[net.IPv6len], addr[net.IPv6len+1]) + req.DstAddr.IP = addr[:net.IPv6len] + req.DstAddr.Port = int(binary.BigEndian.Uint16(addr[net.IPv6len:])) case ATYPDomain: if _, err = io.ReadFull(r, tmp[:1]); err != nil { return req, fmt.Errorf("failed to get request, %v", err) @@ -66,8 +67,8 @@ func ParseRequest(r io.Reader) (req Request, err error) { if _, err = io.ReadFull(r, addr); err != nil { return req, fmt.Errorf("failed to get request, %v", err) } - req.DstAddress.FQDN = string(addr[:domainLen]) - req.DstAddress.Port = buildPort(addr[domainLen], addr[domainLen+1]) + req.DstAddr.FQDN = string(addr[:domainLen]) + req.DstAddr.Port = int(binary.BigEndian.Uint16(addr[domainLen:])) default: return req, ErrUnrecognizedAddrType } @@ -79,25 +80,24 @@ func (h Request) Bytes() (b []byte) { var addr []byte length := 6 - if h.DstAddress.AddrType == ATYPIPv4 { + if h.DstAddr.AddrType == ATYPIPv4 { length += net.IPv4len - addr = h.DstAddress.IP.To4() - } else if h.DstAddress.AddrType == ATYPIPv6 { + addr = h.DstAddr.IP.To4() + } else if h.DstAddr.AddrType == ATYPIPv6 { length += net.IPv6len - addr = h.DstAddress.IP.To16() + addr = h.DstAddr.IP.To16() } else { //ATYPDomain - length += 1 + len(h.DstAddress.FQDN) - addr = []byte(h.DstAddress.FQDN) + length += 1 + len(h.DstAddr.FQDN) + addr = []byte(h.DstAddr.FQDN) } b = make([]byte, 0, length) - b = append(b, h.Version, h.Command, h.Reserved, h.DstAddress.AddrType) - if h.DstAddress.AddrType == ATYPDomain { - b = append(b, byte(len(h.DstAddress.FQDN))) + b = append(b, h.Version, h.Command, h.Reserved, h.DstAddr.AddrType) + if h.DstAddr.AddrType == ATYPDomain { + b = append(b, byte(len(h.DstAddr.FQDN))) } b = append(b, addr...) - hiPort, loPort := breakPort(h.DstAddress.Port) - b = append(b, hiPort, loPort) + b = append(b, byte(h.DstAddr.Port>>8), byte(h.DstAddr.Port)) return b } @@ -116,33 +116,32 @@ type Reply struct { // Reserved byte Reserved byte // Bind Address in socks message - BndAddress AddrSpec + BndAddr AddrSpec } // Bytes returns a slice of request -func (h Reply) Bytes() (b []byte) { +func (sf Reply) Bytes() (b []byte) { var addr []byte length := 6 - if h.BndAddress.AddrType == ATYPIPv4 { + if sf.BndAddr.AddrType == ATYPIPv4 { length += net.IPv4len - addr = h.BndAddress.IP.To4() - } else if h.BndAddress.AddrType == ATYPIPv6 { + addr = sf.BndAddr.IP.To4() + } else if sf.BndAddr.AddrType == ATYPIPv6 { length += net.IPv6len - addr = h.BndAddress.IP.To16() + addr = sf.BndAddr.IP.To16() } else { //ATYPDomain - length += 1 + len(h.BndAddress.FQDN) - addr = []byte(h.BndAddress.FQDN) + length += 1 + len(sf.BndAddr.FQDN) + addr = []byte(sf.BndAddr.FQDN) } b = make([]byte, 0, length) - b = append(b, h.Version, h.Response, h.Reserved, h.BndAddress.AddrType) - if h.BndAddress.AddrType == ATYPDomain { - b = append(b, byte(len(h.BndAddress.FQDN))) + b = append(b, sf.Version, sf.Response, sf.Reserved, sf.BndAddr.AddrType) + if sf.BndAddr.AddrType == ATYPDomain { + b = append(b, byte(len(sf.BndAddr.FQDN))) } b = append(b, addr...) - hiPort, loPort := breakPort(h.BndAddress.Port) - b = append(b, hiPort, loPort) + b = append(b, byte(sf.BndAddr.Port>>8), byte(sf.BndAddr.Port)) return b } @@ -161,9 +160,9 @@ func ParseReply(r io.Reader) (rep Reply, err error) { if _, err = io.ReadFull(r, tmp); err != nil { return rep, fmt.Errorf("failed to get request RSV and address type, %v", err) } - rep.Reserved, rep.BndAddress.AddrType = tmp[0], tmp[1] + rep.Reserved, rep.BndAddr.AddrType = tmp[0], tmp[1] - switch rep.BndAddress.AddrType { + switch rep.BndAddr.AddrType { case ATYPDomain: if _, err = io.ReadFull(r, tmp[:1]); err != nil { return rep, fmt.Errorf("failed to get request, %v", err) @@ -173,22 +172,22 @@ func ParseReply(r io.Reader) (rep Reply, err error) { if _, err = io.ReadFull(r, addr); err != nil { return rep, fmt.Errorf("failed to get request, %v", err) } - rep.BndAddress.FQDN = string(addr[:domainLen]) - rep.BndAddress.Port = buildPort(addr[domainLen], addr[domainLen+1]) + rep.BndAddr.FQDN = string(addr[:domainLen]) + rep.BndAddr.Port = int(binary.BigEndian.Uint16(addr[domainLen:])) case ATYPIPv4: addr := make([]byte, net.IPv4len+2) if _, err = io.ReadFull(r, addr); err != nil { return rep, fmt.Errorf("failed to get request, %v", err) } - rep.BndAddress.IP = net.IPv4(addr[0], addr[1], addr[2], addr[3]) - rep.BndAddress.Port = buildPort(addr[net.IPv4len], addr[net.IPv4len+1]) + rep.BndAddr.IP = net.IPv4(addr[0], addr[1], addr[2], addr[3]) + rep.BndAddr.Port = int(binary.BigEndian.Uint16(addr[net.IPv4len:])) case ATYPIPv6: addr := make([]byte, net.IPv6len+2) if _, err = io.ReadFull(r, addr); err != nil { return rep, fmt.Errorf("failed to get request, %v", err) } - rep.BndAddress.IP = addr[:net.IPv6len] - rep.BndAddress.Port = buildPort(addr[net.IPv6len], addr[net.IPv6len+1]) + rep.BndAddr.IP = addr[:net.IPv6len] + rep.BndAddr.Port = int(binary.BigEndian.Uint16(addr[net.IPv6len:])) default: return rep, ErrUnrecognizedAddrType }