clear invalid code

rename address to addr
fix example bug
This commit is contained in:
mo 2020-08-07 08:38:36 +08:00
parent a7f924fe7a
commit 3f8f48c598
7 changed files with 73 additions and 77 deletions

@ -171,9 +171,9 @@ func (sf *Client) handshake(command byte, addr string) (string, error) {
return "", err return "", err
} }
reqHead := statute.Request{ reqHead := statute.Request{
Version: statute.VersionSocks5, Version: statute.VersionSocks5,
Command: command, Command: command,
DstAddress: a, DstAddr: a,
} }
if _, err := sf.proxyConn.Write(reqHead.Bytes()); err != nil { if _, err := sf.proxyConn.Write(reqHead.Bytes()); err != nil {
return "", err return "", err
@ -186,7 +186,7 @@ func (sf *Client) handshake(command byte, addr string) (string, error) {
if rspHead.Response != statute.RepSuccess { if rspHead.Response != statute.RepSuccess {
return "", errors.New("host unreachable") return "", errors.New("host unreachable")
} }
return rspHead.BndAddress.String(), nil return rspHead.BndAddr.String(), nil
} }
// SetKeepAlive sets whether the operating system should send // SetKeepAlive sets whether the operating system should send

@ -41,7 +41,7 @@ func ParseRequest(bufConn io.Reader) (*Request, error) {
} }
return &Request{ return &Request{
Request: hd, Request: hd,
RawDestAddr: &hd.DstAddress, RawDestAddr: &hd.DstAddr,
Reader: bufConn, Reader: bufConn,
}, nil }, nil
} }
@ -294,7 +294,7 @@ func SendReply(w io.Writer, rep uint8, bindAddr net.Addr) error {
rsp := statute.Reply{ rsp := statute.Reply{
Version: statute.VersionSocks5, Version: statute.VersionSocks5,
Response: rep, Response: rep,
BndAddress: statute.AddrSpec{ BndAddr: statute.AddrSpec{
AddrType: statute.ATYPIPv4, AddrType: statute.ATYPIPv4,
IP: net.IPv4zero, IP: net.IPv4zero,
Port: 0, Port: 0,
@ -303,19 +303,19 @@ func SendReply(w io.Writer, rep uint8, bindAddr net.Addr) error {
if rsp.Response == statute.RepSuccess { if rsp.Response == statute.RepSuccess {
if tcpAddr, ok := bindAddr.(*net.TCPAddr); ok && tcpAddr != nil { if tcpAddr, ok := bindAddr.(*net.TCPAddr); ok && tcpAddr != nil {
rsp.BndAddress.IP = tcpAddr.IP rsp.BndAddr.IP = tcpAddr.IP
rsp.BndAddress.Port = tcpAddr.Port rsp.BndAddr.Port = tcpAddr.Port
} else if udpAddr, ok := bindAddr.(*net.UDPAddr); ok && udpAddr != nil { } else if udpAddr, ok := bindAddr.(*net.UDPAddr); ok && udpAddr != nil {
rsp.BndAddress.IP = udpAddr.IP rsp.BndAddr.IP = udpAddr.IP
rsp.BndAddress.Port = udpAddr.Port rsp.BndAddr.Port = udpAddr.Port
} else { } else {
rsp.Response = statute.RepAddrTypeNotSupported rsp.Response = statute.RepAddrTypeNotSupported
} }
if rsp.BndAddress.IP.To4() != nil { if rsp.BndAddr.IP.To4() != nil {
rsp.BndAddress.AddrType = statute.ATYPIPv4 rsp.BndAddr.AddrType = statute.ATYPIPv4
} else if rsp.BndAddress.IP.To16() != nil { } else if rsp.BndAddr.IP.To16() != nil {
rsp.BndAddress.AddrType = statute.ATYPIPv6 rsp.BndAddr.AddrType = statute.ATYPIPv6
} }
} }
// Send the message // Send the message

@ -64,7 +64,7 @@ func TestSOCKS5_Connect(t *testing.T) {
Version: statute.VersionSocks5, Version: statute.VersionSocks5,
Command: statute.CommandConnect, Command: statute.CommandConnect,
Reserved: 0, Reserved: 0,
DstAddress: statute.AddrSpec{ DstAddr: statute.AddrSpec{
FQDN: "", FQDN: "",
IP: net.ParseIP("127.0.0.1"), IP: net.ParseIP("127.0.0.1"),
Port: lAddr.Port, Port: lAddr.Port,
@ -87,7 +87,7 @@ func TestSOCKS5_Connect(t *testing.T) {
Version: statute.VersionSocks5, Version: statute.VersionSocks5,
Command: statute.RepSuccess, Command: statute.RepSuccess,
Reserved: 0, Reserved: 0,
DstAddress: statute.AddrSpec{ DstAddr: statute.AddrSpec{
FQDN: "", FQDN: "",
IP: net.ParseIP("127.0.0.1"), IP: net.ParseIP("127.0.0.1"),
Port: 0, Port: 0,
@ -156,7 +156,7 @@ func TestSOCKS5_Associate(t *testing.T) {
Version: statute.VersionSocks5, Version: statute.VersionSocks5,
Command: statute.CommandAssociate, Command: statute.CommandAssociate,
Reserved: 0, Reserved: 0,
DstAddress: statute.AddrSpec{ DstAddr: statute.AddrSpec{
FQDN: "", FQDN: "",
IP: locIP, IP: locIP,
Port: lAddr.Port, Port: lAddr.Port,
@ -185,11 +185,11 @@ func TestSOCKS5_Associate(t *testing.T) {
require.Equal(t, statute.VersionSocks5, rspHead.Version) require.Equal(t, statute.VersionSocks5, rspHead.Version)
require.Equal(t, statute.RepSuccess, rspHead.Response) require.Equal(t, statute.RepSuccess, rspHead.Response)
// t.Logf("proxy bind listen port: %d", rspHead.BndAddress.Port) // t.Logf("proxy bind listen port: %d", rspHead.BndAddr.Port)
udpConn, err := net.DialUDP("udp", nil, &net.UDPAddr{ udpConn, err := net.DialUDP("udp", nil, &net.UDPAddr{
IP: locIP, IP: locIP,
Port: rspHead.BndAddress.Port, Port: rspHead.BndAddr.Port,
}) })
require.NoError(t, err) require.NoError(t, err)
// Send a ping // Send a ping

@ -18,22 +18,22 @@ type AddrSpec struct {
// String returns a string suitable to dial; prefer returning IP-based // String returns a string suitable to dial; prefer returning IP-based
// address, fallback to FQDN // address, fallback to FQDN
func (a *AddrSpec) String() string { func (sf *AddrSpec) String() string {
if 0 != len(a.IP) { if 0 != len(sf.IP) {
return net.JoinHostPort(a.IP.String(), strconv.Itoa(a.Port)) return net.JoinHostPort(sf.IP.String(), strconv.Itoa(sf.Port))
} }
return net.JoinHostPort(a.FQDN, strconv.Itoa(a.Port)) return net.JoinHostPort(sf.FQDN, strconv.Itoa(sf.Port))
} }
// Address returns a string which may be specified // Address returns a string which may be specified
// if IPv4/IPv6 will return < ip:port > // if IPv4/IPv6 will return < ip:port >
// if FQDN will return < domain ip:port > // if FQDN will return < domain ip:port >
// Note: do not used to dial, Please use String // Note: do not used to dial, Please use String
func (a AddrSpec) Address() string { func (sf AddrSpec) Address() string {
if a.FQDN != "" { if sf.FQDN != "" {
return fmt.Sprintf("%s (%s):%d", a.FQDN, a.IP, a.Port) return fmt.Sprintf("%s (%s):%d", sf.FQDN, sf.IP, sf.Port)
} }
return fmt.Sprintf("%s:%d", a.IP, a.Port) return fmt.Sprintf("%s:%d", sf.IP, sf.Port)
} }
// ParseAddrSpec parse address to the AddrSpec address // ParseAddrSpec parse address to the AddrSpec address
@ -55,6 +55,3 @@ func ParseAddrSpec(address string) (as AddrSpec, err error) {
as.Port, err = strconv.Atoi(port) as.Port, err = strconv.Atoi(port)
return return
} }
func buildPort(hi, lo byte) int { return (int(hi) << 8) | int(lo) }
func breakPort(port int) (hi, lo byte) { return byte(port >> 8), byte(port) }

@ -1,6 +1,7 @@
package statute package statute
import ( import (
"encoding/binary"
"errors" "errors"
"math" "math"
"net" "net"
@ -49,7 +50,7 @@ func ParseDatagram(b []byte) (da Datagram, err error) {
case ATYPIPv4: case ATYPIPv4:
headLen += net.IPv4len + 2 headLen += net.IPv4len + 2
da.DstAddr.IP = net.IPv4(b[4], b[5], b[6], b[7]) da.DstAddr.IP = net.IPv4(b[4], b[5], b[6], b[7])
da.DstAddr.Port = buildPort(b[4+net.IPv4len], b[4+net.IPv4len+1]) da.DstAddr.Port = int(binary.BigEndian.Uint16((b[4+net.IPv4len:])))
case ATYPIPv6: case ATYPIPv6:
headLen += net.IPv6len + 2 headLen += net.IPv6len + 2
if len(b) <= headLen { if len(b) <= headLen {
@ -58,7 +59,7 @@ func ParseDatagram(b []byte) (da Datagram, err error) {
} }
da.DstAddr.IP = b[4 : 4+net.IPv6len] da.DstAddr.IP = b[4 : 4+net.IPv6len]
da.DstAddr.Port = buildPort(b[4+net.IPv6len], b[4+net.IPv6len+1]) da.DstAddr.Port = int(binary.BigEndian.Uint16(b[4+net.IPv6len:]))
case ATYPDomain: case ATYPDomain:
addrLen := int(b[4]) addrLen := int(b[4])
headLen += 1 + addrLen + 2 headLen += 1 + addrLen + 2
@ -69,7 +70,7 @@ func ParseDatagram(b []byte) (da Datagram, err error) {
str := make([]byte, addrLen) str := make([]byte, addrLen)
copy(str, b[5:5+addrLen]) copy(str, b[5:5+addrLen])
da.DstAddr.FQDN = string(str) da.DstAddr.FQDN = string(str)
da.DstAddr.Port = buildPort(b[5+addrLen], b[5+addrLen+1]) da.DstAddr.Port = int(binary.BigEndian.Uint16(b[5+addrLen:]))
default: default:
err = ErrUnrecognizedAddrType err = ErrUnrecognizedAddrType
return return
@ -114,8 +115,7 @@ func (sf *Datagram) values(hasData bool) (bs []byte) {
bs = append(bs, byte(len(sf.DstAddr.FQDN))) bs = append(bs, byte(len(sf.DstAddr.FQDN)))
} }
bs = append(bs, addr...) bs = append(bs, addr...)
hi, lo := breakPort(sf.DstAddr.Port) bs = append(bs, byte(sf.DstAddr.Port>>8), byte(sf.DstAddr.Port))
bs = append(bs, hi, lo)
if hasData { if hasData {
bs = append(bs, sf.Data...) bs = append(bs, sf.Data...)
} }

@ -1,6 +1,7 @@
package statute package statute
import ( import (
"encoding/binary"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -20,8 +21,8 @@ type Request struct {
Command byte Command byte
// Reserved byte // Reserved byte
Reserved byte Reserved byte
// DstAddress in socks message // DstAddr in socks message
DstAddress AddrSpec DstAddr AddrSpec
} }
// ParseRequest to request from io.Reader // ParseRequest to request from io.Reader
@ -40,23 +41,23 @@ func ParseRequest(r io.Reader) (req Request, err error) {
if _, err = io.ReadFull(r, tmp); err != nil { if _, err = io.ReadFull(r, tmp); err != nil {
return req, fmt.Errorf("failed to get request RSV and address type, %v", err) return req, fmt.Errorf("failed to get request RSV and address type, %v", err)
} }
req.Reserved, req.DstAddress.AddrType = tmp[0], tmp[1] req.Reserved, req.DstAddr.AddrType = tmp[0], tmp[1]
switch req.DstAddress.AddrType { switch req.DstAddr.AddrType {
case ATYPIPv4: case ATYPIPv4:
addr := make([]byte, net.IPv4len+2) addr := make([]byte, net.IPv4len+2)
if _, err = io.ReadFull(r, addr); err != nil { if _, err = io.ReadFull(r, addr); err != nil {
return req, fmt.Errorf("failed to get request, %v", err) return req, fmt.Errorf("failed to get request, %v", err)
} }
req.DstAddress.IP = net.IPv4(addr[0], addr[1], addr[2], addr[3]) req.DstAddr.IP = net.IPv4(addr[0], addr[1], addr[2], addr[3])
req.DstAddress.Port = buildPort(addr[net.IPv4len], addr[net.IPv4len+1]) req.DstAddr.Port = int(binary.BigEndian.Uint16(addr[net.IPv4len:]))
case ATYPIPv6: case ATYPIPv6:
addr := make([]byte, net.IPv6len+2) addr := make([]byte, net.IPv6len+2)
if _, err = io.ReadFull(r, addr); err != nil { if _, err = io.ReadFull(r, addr); err != nil {
return req, fmt.Errorf("failed to get request, %v", err) return req, fmt.Errorf("failed to get request, %v", err)
} }
req.DstAddress.IP = addr[:net.IPv6len] req.DstAddr.IP = addr[:net.IPv6len]
req.DstAddress.Port = buildPort(addr[net.IPv6len], addr[net.IPv6len+1]) req.DstAddr.Port = int(binary.BigEndian.Uint16(addr[net.IPv6len:]))
case ATYPDomain: case ATYPDomain:
if _, err = io.ReadFull(r, tmp[:1]); err != nil { if _, err = io.ReadFull(r, tmp[:1]); err != nil {
return req, fmt.Errorf("failed to get request, %v", err) return req, fmt.Errorf("failed to get request, %v", err)
@ -66,8 +67,8 @@ func ParseRequest(r io.Reader) (req Request, err error) {
if _, err = io.ReadFull(r, addr); err != nil { if _, err = io.ReadFull(r, addr); err != nil {
return req, fmt.Errorf("failed to get request, %v", err) return req, fmt.Errorf("failed to get request, %v", err)
} }
req.DstAddress.FQDN = string(addr[:domainLen]) req.DstAddr.FQDN = string(addr[:domainLen])
req.DstAddress.Port = buildPort(addr[domainLen], addr[domainLen+1]) req.DstAddr.Port = int(binary.BigEndian.Uint16(addr[domainLen:]))
default: default:
return req, ErrUnrecognizedAddrType return req, ErrUnrecognizedAddrType
} }
@ -79,25 +80,24 @@ func (h Request) Bytes() (b []byte) {
var addr []byte var addr []byte
length := 6 length := 6
if h.DstAddress.AddrType == ATYPIPv4 { if h.DstAddr.AddrType == ATYPIPv4 {
length += net.IPv4len length += net.IPv4len
addr = h.DstAddress.IP.To4() addr = h.DstAddr.IP.To4()
} else if h.DstAddress.AddrType == ATYPIPv6 { } else if h.DstAddr.AddrType == ATYPIPv6 {
length += net.IPv6len length += net.IPv6len
addr = h.DstAddress.IP.To16() addr = h.DstAddr.IP.To16()
} else { //ATYPDomain } else { //ATYPDomain
length += 1 + len(h.DstAddress.FQDN) length += 1 + len(h.DstAddr.FQDN)
addr = []byte(h.DstAddress.FQDN) addr = []byte(h.DstAddr.FQDN)
} }
b = make([]byte, 0, length) b = make([]byte, 0, length)
b = append(b, h.Version, h.Command, h.Reserved, h.DstAddress.AddrType) b = append(b, h.Version, h.Command, h.Reserved, h.DstAddr.AddrType)
if h.DstAddress.AddrType == ATYPDomain { if h.DstAddr.AddrType == ATYPDomain {
b = append(b, byte(len(h.DstAddress.FQDN))) b = append(b, byte(len(h.DstAddr.FQDN)))
} }
b = append(b, addr...) b = append(b, addr...)
hiPort, loPort := breakPort(h.DstAddress.Port) b = append(b, byte(h.DstAddr.Port>>8), byte(h.DstAddr.Port))
b = append(b, hiPort, loPort)
return b return b
} }
@ -116,33 +116,32 @@ type Reply struct {
// Reserved byte // Reserved byte
Reserved byte Reserved byte
// Bind Address in socks message // Bind Address in socks message
BndAddress AddrSpec BndAddr AddrSpec
} }
// Bytes returns a slice of request // Bytes returns a slice of request
func (h Reply) Bytes() (b []byte) { func (sf Reply) Bytes() (b []byte) {
var addr []byte var addr []byte
length := 6 length := 6
if h.BndAddress.AddrType == ATYPIPv4 { if sf.BndAddr.AddrType == ATYPIPv4 {
length += net.IPv4len length += net.IPv4len
addr = h.BndAddress.IP.To4() addr = sf.BndAddr.IP.To4()
} else if h.BndAddress.AddrType == ATYPIPv6 { } else if sf.BndAddr.AddrType == ATYPIPv6 {
length += net.IPv6len length += net.IPv6len
addr = h.BndAddress.IP.To16() addr = sf.BndAddr.IP.To16()
} else { //ATYPDomain } else { //ATYPDomain
length += 1 + len(h.BndAddress.FQDN) length += 1 + len(sf.BndAddr.FQDN)
addr = []byte(h.BndAddress.FQDN) addr = []byte(sf.BndAddr.FQDN)
} }
b = make([]byte, 0, length) b = make([]byte, 0, length)
b = append(b, h.Version, h.Response, h.Reserved, h.BndAddress.AddrType) b = append(b, sf.Version, sf.Response, sf.Reserved, sf.BndAddr.AddrType)
if h.BndAddress.AddrType == ATYPDomain { if sf.BndAddr.AddrType == ATYPDomain {
b = append(b, byte(len(h.BndAddress.FQDN))) b = append(b, byte(len(sf.BndAddr.FQDN)))
} }
b = append(b, addr...) b = append(b, addr...)
hiPort, loPort := breakPort(h.BndAddress.Port) b = append(b, byte(sf.BndAddr.Port>>8), byte(sf.BndAddr.Port))
b = append(b, hiPort, loPort)
return b return b
} }
@ -161,9 +160,9 @@ func ParseReply(r io.Reader) (rep Reply, err error) {
if _, err = io.ReadFull(r, tmp); err != nil { if _, err = io.ReadFull(r, tmp); err != nil {
return rep, fmt.Errorf("failed to get request RSV and address type, %v", err) return rep, fmt.Errorf("failed to get request RSV and address type, %v", err)
} }
rep.Reserved, rep.BndAddress.AddrType = tmp[0], tmp[1] rep.Reserved, rep.BndAddr.AddrType = tmp[0], tmp[1]
switch rep.BndAddress.AddrType { switch rep.BndAddr.AddrType {
case ATYPDomain: case ATYPDomain:
if _, err = io.ReadFull(r, tmp[:1]); err != nil { if _, err = io.ReadFull(r, tmp[:1]); err != nil {
return rep, fmt.Errorf("failed to get request, %v", err) return rep, fmt.Errorf("failed to get request, %v", err)
@ -173,22 +172,22 @@ func ParseReply(r io.Reader) (rep Reply, err error) {
if _, err = io.ReadFull(r, addr); err != nil { if _, err = io.ReadFull(r, addr); err != nil {
return rep, fmt.Errorf("failed to get request, %v", err) return rep, fmt.Errorf("failed to get request, %v", err)
} }
rep.BndAddress.FQDN = string(addr[:domainLen]) rep.BndAddr.FQDN = string(addr[:domainLen])
rep.BndAddress.Port = buildPort(addr[domainLen], addr[domainLen+1]) rep.BndAddr.Port = int(binary.BigEndian.Uint16(addr[domainLen:]))
case ATYPIPv4: case ATYPIPv4:
addr := make([]byte, net.IPv4len+2) addr := make([]byte, net.IPv4len+2)
if _, err = io.ReadFull(r, addr); err != nil { if _, err = io.ReadFull(r, addr); err != nil {
return rep, fmt.Errorf("failed to get request, %v", err) return rep, fmt.Errorf("failed to get request, %v", err)
} }
rep.BndAddress.IP = net.IPv4(addr[0], addr[1], addr[2], addr[3]) rep.BndAddr.IP = net.IPv4(addr[0], addr[1], addr[2], addr[3])
rep.BndAddress.Port = buildPort(addr[net.IPv4len], addr[net.IPv4len+1]) rep.BndAddr.Port = int(binary.BigEndian.Uint16(addr[net.IPv4len:]))
case ATYPIPv6: case ATYPIPv6:
addr := make([]byte, net.IPv6len+2) addr := make([]byte, net.IPv6len+2)
if _, err = io.ReadFull(r, addr); err != nil { if _, err = io.ReadFull(r, addr); err != nil {
return rep, fmt.Errorf("failed to get request, %v", err) return rep, fmt.Errorf("failed to get request, %v", err)
} }
rep.BndAddress.IP = addr[:net.IPv6len] rep.BndAddr.IP = addr[:net.IPv6len]
rep.BndAddress.Port = buildPort(addr[net.IPv6len], addr[net.IPv6len+1]) rep.BndAddr.Port = int(binary.BigEndian.Uint16(addr[net.IPv6len:]))
default: default:
return rep, ErrUnrecognizedAddrType return rep, ErrUnrecognizedAddrType
} }