implement raw udp

This commit is contained in:
mo 2020-04-20 16:19:51 +08:00
parent 32060b0332
commit be8e44e47d
5 changed files with 221 additions and 68 deletions

21
auth.go

@ -6,12 +6,13 @@ import (
) )
const ( const (
NoAuth = uint8(0) NoAuth = uint8(0)
noAcceptable = uint8(255) GSSAPI = uint8(1)
UserPassAuth = uint8(2) UserPassAuth = uint8(2)
userAuthVersion = uint8(1) NoAcceptable = uint8(255)
authSuccess = uint8(0) userPassAuthVersion = uint8(1)
authFailure = uint8(1) authSuccess = uint8(0)
authFailure = uint8(1)
) )
var ( var (
@ -70,7 +71,7 @@ func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer)
} }
// Ensure we are compatible // Ensure we are compatible
if header[0] != userAuthVersion { if header[0] != userPassAuthVersion {
return nil, fmt.Errorf("Unsupported auth version: %v", header[0]) return nil, fmt.Errorf("Unsupported auth version: %v", header[0])
} }
@ -95,11 +96,11 @@ func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer)
// Verify the password // Verify the password
if a.Credentials.Valid(string(user), string(pass)) { if a.Credentials.Valid(string(user), string(pass)) {
if _, err := writer.Write([]byte{userAuthVersion, authSuccess}); err != nil { if _, err := writer.Write([]byte{userPassAuthVersion, authSuccess}); err != nil {
return nil, err return nil, err
} }
} else { } else {
if _, err := writer.Write([]byte{userAuthVersion, authFailure}); err != nil { if _, err := writer.Write([]byte{userPassAuthVersion, authFailure}); err != nil {
return nil, err return nil, err
} }
return nil, UserAuthFailed return nil, UserAuthFailed
@ -132,7 +133,7 @@ func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) (*AuthContext,
// noAcceptableAuth is used to handle when we have no eligible // noAcceptableAuth is used to handle when we have no eligible
// authentication mechanism // authentication mechanism
func noAcceptableAuth(conn io.Writer) error { func noAcceptableAuth(conn io.Writer) error {
conn.Write([]byte{socks5Version, noAcceptable}) conn.Write([]byte{socks5Version, NoAcceptable})
return NoSupportedAuth return NoSupportedAuth
} }

@ -113,7 +113,7 @@ func TestNoSupportedAuth(t *testing.T) {
} }
out := resp.Bytes() out := resp.Bytes()
if !bytes.Equal(out, []byte{socks5Version, noAcceptable}) { if !bytes.Equal(out, []byte{socks5Version, NoAcceptable}) {
t.Fatalf("bad: %v", out) t.Fatalf("bad: %v", out)
} }
} }

@ -6,6 +6,7 @@ import (
"io" "io"
"net" "net"
"strings" "strings"
"sync"
) )
var ( var (
@ -287,7 +288,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
} }
lAddr, _ := net.ResolveUDPAddr("udp", ":0") lAddr, _ := net.ResolveUDPAddr("udp", ":0")
listen, err := net.ListenUDP("udp", lAddr) bindLn, err := net.ListenUDP("udp4", lAddr)
if err != nil { if err != nil {
if err := sendReply(writer, req.Header, serverFailure); err != nil { if err := sendReply(writer, req.Header, serverFailure); err != nil {
return fmt.Errorf("failed to send reply, %v", err) return fmt.Errorf("failed to send reply, %v", err)
@ -295,58 +296,62 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
return fmt.Errorf("listen udp failed, %v", err) return fmt.Errorf("listen udp failed, %v", err)
} }
s.config.Logger.Errorf("target addr %v, listen addr: %s", targetUdp.RemoteAddr(), bindLn.LocalAddr())
// send BND.ADDR and BND.PORT, client must // send BND.ADDR and BND.PORT, client must
if err = sendReply(writer, req.Header, successReply, listen.LocalAddr()); err != nil { if err = sendReply(writer, req.Header, successReply, bindLn.LocalAddr()); err != nil {
return fmt.Errorf("failed to send reply, %v", err) return fmt.Errorf("failed to send reply, %v", err)
} }
go func() { go func() {
// read from client and write to remote server // read from client and write to remote server
conns := sync.Map{}
buf := s.bufferPool.Get() buf := s.bufferPool.Get()
defer s.bufferPool.Put(buf) defer s.bufferPool.Put(buf)
for { for {
n, _, err := listen.ReadFrom(buf[:cap(buf)]) n, srcAddr, err := bindLn.ReadFrom(buf[:cap(buf)])
if err != nil { if err != nil {
s.config.Logger.Errorf("read data from %s failed, %v", listen.LocalAddr(), err) s.config.Logger.Errorf("read data from bind listen address %s failed, %v", bindLn.LocalAddr(), err)
return return
} }
// 把消息写给remote sever
if _, err := targetUdp.Write(buf[:n]); err != nil { if _, err := targetUdp.Write(buf[:n]); err != nil {
s.config.Logger.Errorf("write data to remote %s failed, %v", targetUdp.RemoteAddr(), err) s.config.Logger.Errorf("write data to remote %s failed, %v", targetUdp.RemoteAddr(), err)
return return
} }
}
}()
go func() { if _, ok := conns.LoadOrStore(srcAddr.String(), struct{}{}); !ok {
// read from remote server and write to client go func() {
buf := s.bufferPool.Get() // read from remote server and write to client
defer s.bufferPool.Put(buf) buf := s.bufferPool.Get()
for { defer s.bufferPool.Put(buf)
n, _, err := targetUdp.ReadFromUDP(buf[:cap(buf)]) for {
if err != nil { n, _, err := targetUdp.ReadFrom(buf[:cap(buf)])
s.config.Logger.Errorf("read data from remote %s failed, %v", targetUdp.RemoteAddr(), err) if err != nil {
return s.config.Logger.Errorf("read data from remote %s failed, %v", targetUdp.RemoteAddr(), err)
} return
if _, err := listen.Write(buf[:n]); err != nil { }
s.config.Logger.Errorf("write data to remote %s failed, %v", targetUdp.RemoteAddr(), err)
return if _, err := bindLn.WriteTo(buf[:n], srcAddr); err != nil {
s.config.Logger.Errorf("write data to client %s failed, %v", bindLn.LocalAddr(), err)
return
}
}
}()
} }
} }
}() }()
buf := s.bufferPool.Get() buf := s.bufferPool.Get()
defer s.bufferPool.Put(buf) defer func() {
s.bufferPool.Put(buf)
}()
for { for {
_, err := req.bufConn.Read(buf) _, err := req.bufConn.Read(buf)
if err != nil { if err != nil {
return err return err
} }
} }
//// TODO: Support associate
//if err := sendReply(writer, req.Header, commandNotSupported); err != nil {
// return fmt.Errorf("failed to send reply, %v", err)
//}
//return nil
} }
// sendReply is used to send a reply message // sendReply is used to send a reply message
@ -366,12 +371,18 @@ func sendReply(w io.Writer, head Header, resp uint8, bindAddr ...net.Addr) error
head.addrType = ipv4Address head.addrType = ipv4Address
head.Address.IP = []byte{0, 0, 0, 0} head.Address.IP = []byte{0, 0, 0, 0}
head.Address.Port = 0 head.Address.Port = 0
} else if tcpAddr, ok := bindAddr[0].(*net.TCPAddr); !ok || tcpAddr == nil {
head.addrType = ipv4Address
head.Address.IP = []byte{0, 0, 0, 0}
head.Address.Port = 0
} else { } else {
addrSpec := AddrSpec{IP: tcpAddr.IP, Port: tcpAddr.Port} addrSpec := AddrSpec{}
if tcpAddr, ok := bindAddr[0].(*net.TCPAddr); ok && tcpAddr != nil {
addrSpec.IP = tcpAddr.IP
addrSpec.Port = tcpAddr.Port
} else if udpAddr, ok := bindAddr[0].(*net.UDPAddr); ok && udpAddr != nil {
addrSpec.IP = udpAddr.IP
addrSpec.Port = udpAddr.Port
} else {
addrSpec.IP = []byte{0, 0, 0, 0}
addrSpec.Port = 0
}
switch { switch {
case addrSpec.FQDN != "": case addrSpec.FQDN != "":
head.addrType = fqdnAddress head.addrType = fqdnAddress
@ -388,8 +399,8 @@ func sendReply(w io.Writer, head Header, resp uint8, bindAddr ...net.Addr) error
default: default:
return fmt.Errorf("failed to format address[%v]", bindAddr) return fmt.Errorf("failed to format address[%v]", bindAddr)
} }
}
}
// Send the message // Send the message
_, err := w.Write(head.Bytes()) _, err := w.Write(head.Bytes())
return err return err

@ -79,9 +79,15 @@ func New(conf *Config) (*Server, error) {
conf.Logger = NewLogger(log.New(ioutil.Discard, "socks5: ", log.LstdFlags)) conf.Logger = NewLogger(log.New(ioutil.Discard, "socks5: ", log.LstdFlags))
} }
if conf.Dial == nil {
conf.Dial = func(ctx context.Context, net_, addr string) (net.Conn, error) {
return net.Dial(net_, addr)
}
}
server := &Server{ server := &Server{
config: conf, config: conf,
bufferPool: newPool(32 * 1024), bufferPool: newPool(2 * 1024),
} }
server.authMethods = make(map[uint8]Authenticator) server.authMethods = make(map[uint8]Authenticator)

@ -2,7 +2,6 @@ package socks5
import ( import (
"bytes" "bytes"
"encoding/binary"
"io" "io"
"log" "log"
"net" "net"
@ -37,10 +36,9 @@ func TestSOCKS5_Connect(t *testing.T) {
lAddr := l.Addr().(*net.TCPAddr) lAddr := l.Addr().(*net.TCPAddr)
// Create a socks server // Create a socks server
creds := StaticCredentials{ cator := UserPassAuthenticator{
"foo": "bar", Credentials: StaticCredentials{"foo": "bar"},
} }
cator := UserPassAuthenticator{Credentials: creds}
conf := &Config{ conf := &Config{
AuthMethods: []Authenticator{cator}, AuthMethods: []Authenticator{cator},
Logger: NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags)), Logger: NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags)),
@ -65,16 +63,21 @@ func TestSOCKS5_Connect(t *testing.T) {
} }
// Connect, auth and connec to local // Connect, auth and connec to local
req := bytes.NewBuffer(nil) req := new(bytes.Buffer)
req.Write([]byte{5}) req.Write([]byte{socks5Version, 2, NoAuth, UserPassAuth})
req.Write([]byte{2, NoAuth, UserPassAuth}) req.Write([]byte{userPassAuthVersion, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'})
req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'}) reqHead := Header{
req.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1}) Version: socks5Version,
Command: ConnectCommand,
port := []byte{0, 0} Reserved: 0,
binary.BigEndian.PutUint16(port, uint16(lAddr.Port)) Address: AddrSpec{
req.Write(port) "",
net.ParseIP("127.0.0.1"),
lAddr.Port,
},
addrType: ipv4Address,
}
req.Write(reqHead.Bytes())
// Send a ping // Send a ping
req.Write([]byte("ping")) req.Write([]byte("ping"))
@ -83,23 +86,31 @@ func TestSOCKS5_Connect(t *testing.T) {
// Verify response // Verify response
expected := []byte{ expected := []byte{
socks5Version, UserPassAuth, socks5Version, UserPassAuth, // use user password auth
1, authSuccess, userPassAuthVersion, authSuccess, // response auth success
5,
0,
0,
1,
127, 0, 0, 1,
0, 0,
'p', 'o', 'n', 'g',
} }
out := make([]byte, len(expected)) rspHead := Header{
Version: socks5Version,
Command: successReply,
Reserved: 0,
Address: AddrSpec{
"",
net.ParseIP("127.0.0.1"),
0, // Ignore the port
},
addrType: ipv4Address,
}
expected = append(expected, rspHead.Bytes()...)
expected = append(expected, []byte("pong")...)
out := make([]byte, len(expected))
conn.SetDeadline(time.Now().Add(time.Second)) conn.SetDeadline(time.Now().Add(time.Second))
if _, err := io.ReadAtLeast(conn, out, len(out)); err != nil { if _, err := io.ReadFull(conn, out); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
t.Logf("proxy bind port: %d", buildPort(out[12], out[13]))
// Ignore the port // Ignore the port
out[12] = 0 out[12] = 0
out[13] = 0 out[13] = 0
@ -108,3 +119,127 @@ func TestSOCKS5_Connect(t *testing.T) {
t.Fatalf("bad: %v", out) t.Fatalf("bad: %v", out)
} }
} }
func TestSOCKS5_Associate(t *testing.T) {
locIP := net.ParseIP("127.0.0.1")
// Create a local listener
lAddr := &net.UDPAddr{
IP: locIP,
Port: 12398,
}
l, err := net.ListenUDP("udp4", lAddr)
if err != nil {
t.Fatalf("err: %v", err)
}
defer l.Close()
go func() {
buf := make([]byte, 2048)
for {
n, remote, err := l.ReadFrom(buf)
if err != nil {
return
}
if !bytes.Equal(buf[:n], []byte("ping")) {
t.Fatalf("bad: %v", buf)
}
l.WriteTo([]byte("pong"), remote)
}
}()
// Create a socks server
cator := UserPassAuthenticator{Credentials: StaticCredentials{"foo": "bar"}}
conf := &Config{
AuthMethods: []Authenticator{cator},
Logger: NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags)),
}
serv, err := New(conf)
if err != nil {
t.Fatalf("err: %v", err)
}
// Start listening
go func() {
if err := serv.ListenAndServe("tcp", "127.0.0.1:12355"); err != nil {
t.Fatalf("err: %v", err)
}
}()
time.Sleep(10 * time.Millisecond)
// Get a local conn
conn, err := net.Dial("tcp", "127.0.0.1:12355")
if err != nil {
t.Fatalf("err: %v", err)
}
// Connect, auth and connec to local
req := new(bytes.Buffer)
req.Write([]byte{socks5Version, 2, NoAuth, UserPassAuth})
req.Write([]byte{userPassAuthVersion, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'})
reqHead := Header{
Version: socks5Version,
Command: AssociateCommand,
Reserved: 0,
Address: AddrSpec{
"",
locIP,
lAddr.Port,
},
addrType: ipv4Address,
}
req.Write(reqHead.Bytes())
// Send all the bytes
conn.Write(req.Bytes())
// Verify response
expected := []byte{
socks5Version, UserPassAuth, // use user password auth
userPassAuthVersion, authSuccess, // response auth success
}
rspHead := Header{
Version: socks5Version,
Command: successReply,
Reserved: 0,
Address: AddrSpec{
"",
net.ParseIP("0.0.0.0"),
0, // Ignore the port
},
addrType: ipv4Address,
}
expected = append(expected, rspHead.Bytes()...)
out := make([]byte, len(expected))
conn.SetDeadline(time.Now().Add(time.Second))
if _, err := io.ReadFull(conn, out); err != nil {
t.Fatalf("err: %v", err)
}
// Ignore the port
proxyBindPort := buildPort(out[12], out[13])
out[12] = 0
out[13] = 0
t.Logf("proxy bind listen port: %d", proxyBindPort)
if !bytes.Equal(out, expected) {
t.Fatalf("bad: %v", out)
}
udpConn, err := net.DialUDP("udp4", nil, &net.UDPAddr{
IP: locIP,
//Port: lAddr.Port,
Port: proxyBindPort,
})
if err != nil {
t.Fatalf("bad dial: %v", err)
}
// Send a ping
udpConn.Write([]byte("ping"))
response := make([]byte, 1024)
n, _, err := udpConn.ReadFrom(response)
if !bytes.Equal(response[:n], []byte("pong")) {
t.Fatalf("bad udp read: %v", string(response[:n]))
}
time.Sleep(time.Second * 1)
}