implement raw udp

This commit is contained in:
mo 2020-04-20 16:19:51 +08:00
parent 32060b0332
commit be8e44e47d
5 changed files with 221 additions and 68 deletions

21
auth.go

@ -6,12 +6,13 @@ import (
)
const (
NoAuth = uint8(0)
noAcceptable = uint8(255)
UserPassAuth = uint8(2)
userAuthVersion = uint8(1)
authSuccess = uint8(0)
authFailure = uint8(1)
NoAuth = uint8(0)
GSSAPI = uint8(1)
UserPassAuth = uint8(2)
NoAcceptable = uint8(255)
userPassAuthVersion = uint8(1)
authSuccess = uint8(0)
authFailure = uint8(1)
)
var (
@ -70,7 +71,7 @@ func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer)
}
// Ensure we are compatible
if header[0] != userAuthVersion {
if header[0] != userPassAuthVersion {
return nil, fmt.Errorf("Unsupported auth version: %v", header[0])
}
@ -95,11 +96,11 @@ func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer)
// Verify the password
if a.Credentials.Valid(string(user), string(pass)) {
if _, err := writer.Write([]byte{userAuthVersion, authSuccess}); err != nil {
if _, err := writer.Write([]byte{userPassAuthVersion, authSuccess}); err != nil {
return nil, err
}
} else {
if _, err := writer.Write([]byte{userAuthVersion, authFailure}); err != nil {
if _, err := writer.Write([]byte{userPassAuthVersion, authFailure}); err != nil {
return nil, err
}
return nil, UserAuthFailed
@ -132,7 +133,7 @@ func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) (*AuthContext,
// noAcceptableAuth is used to handle when we have no eligible
// authentication mechanism
func noAcceptableAuth(conn io.Writer) error {
conn.Write([]byte{socks5Version, noAcceptable})
conn.Write([]byte{socks5Version, NoAcceptable})
return NoSupportedAuth
}

@ -113,7 +113,7 @@ func TestNoSupportedAuth(t *testing.T) {
}
out := resp.Bytes()
if !bytes.Equal(out, []byte{socks5Version, noAcceptable}) {
if !bytes.Equal(out, []byte{socks5Version, NoAcceptable}) {
t.Fatalf("bad: %v", out)
}
}

@ -6,6 +6,7 @@ import (
"io"
"net"
"strings"
"sync"
)
var (
@ -287,7 +288,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
}
lAddr, _ := net.ResolveUDPAddr("udp", ":0")
listen, err := net.ListenUDP("udp", lAddr)
bindLn, err := net.ListenUDP("udp4", lAddr)
if err != nil {
if err := sendReply(writer, req.Header, serverFailure); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
@ -295,58 +296,62 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
return fmt.Errorf("listen udp failed, %v", err)
}
s.config.Logger.Errorf("target addr %v, listen addr: %s", targetUdp.RemoteAddr(), bindLn.LocalAddr())
// send BND.ADDR and BND.PORT, client must
if err = sendReply(writer, req.Header, successReply, listen.LocalAddr()); err != nil {
if err = sendReply(writer, req.Header, successReply, bindLn.LocalAddr()); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
go func() {
// read from client and write to remote server
conns := sync.Map{}
buf := s.bufferPool.Get()
defer s.bufferPool.Put(buf)
for {
n, _, err := listen.ReadFrom(buf[:cap(buf)])
n, srcAddr, err := bindLn.ReadFrom(buf[:cap(buf)])
if err != nil {
s.config.Logger.Errorf("read data from %s failed, %v", listen.LocalAddr(), err)
s.config.Logger.Errorf("read data from bind listen address %s failed, %v", bindLn.LocalAddr(), err)
return
}
// 把消息写给remote sever
if _, err := targetUdp.Write(buf[:n]); err != nil {
s.config.Logger.Errorf("write data to remote %s failed, %v", targetUdp.RemoteAddr(), err)
return
}
}
}()
go func() {
// read from remote server and write to client
buf := s.bufferPool.Get()
defer s.bufferPool.Put(buf)
for {
n, _, err := targetUdp.ReadFromUDP(buf[:cap(buf)])
if err != nil {
s.config.Logger.Errorf("read data from remote %s failed, %v", targetUdp.RemoteAddr(), err)
return
}
if _, err := listen.Write(buf[:n]); err != nil {
s.config.Logger.Errorf("write data to remote %s failed, %v", targetUdp.RemoteAddr(), err)
return
if _, ok := conns.LoadOrStore(srcAddr.String(), struct{}{}); !ok {
go func() {
// read from remote server and write to client
buf := s.bufferPool.Get()
defer s.bufferPool.Put(buf)
for {
n, _, err := targetUdp.ReadFrom(buf[:cap(buf)])
if err != nil {
s.config.Logger.Errorf("read data from remote %s failed, %v", targetUdp.RemoteAddr(), err)
return
}
if _, err := bindLn.WriteTo(buf[:n], srcAddr); err != nil {
s.config.Logger.Errorf("write data to client %s failed, %v", bindLn.LocalAddr(), err)
return
}
}
}()
}
}
}()
buf := s.bufferPool.Get()
defer s.bufferPool.Put(buf)
defer func() {
s.bufferPool.Put(buf)
}()
for {
_, err := req.bufConn.Read(buf)
if err != nil {
return err
}
}
//// TODO: Support associate
//if err := sendReply(writer, req.Header, commandNotSupported); err != nil {
// return fmt.Errorf("failed to send reply, %v", err)
//}
//return nil
}
// sendReply is used to send a reply message
@ -366,12 +371,18 @@ func sendReply(w io.Writer, head Header, resp uint8, bindAddr ...net.Addr) error
head.addrType = ipv4Address
head.Address.IP = []byte{0, 0, 0, 0}
head.Address.Port = 0
} else if tcpAddr, ok := bindAddr[0].(*net.TCPAddr); !ok || tcpAddr == nil {
head.addrType = ipv4Address
head.Address.IP = []byte{0, 0, 0, 0}
head.Address.Port = 0
} else {
addrSpec := AddrSpec{IP: tcpAddr.IP, Port: tcpAddr.Port}
addrSpec := AddrSpec{}
if tcpAddr, ok := bindAddr[0].(*net.TCPAddr); ok && tcpAddr != nil {
addrSpec.IP = tcpAddr.IP
addrSpec.Port = tcpAddr.Port
} else if udpAddr, ok := bindAddr[0].(*net.UDPAddr); ok && udpAddr != nil {
addrSpec.IP = udpAddr.IP
addrSpec.Port = udpAddr.Port
} else {
addrSpec.IP = []byte{0, 0, 0, 0}
addrSpec.Port = 0
}
switch {
case addrSpec.FQDN != "":
head.addrType = fqdnAddress
@ -388,8 +399,8 @@ func sendReply(w io.Writer, head Header, resp uint8, bindAddr ...net.Addr) error
default:
return fmt.Errorf("failed to format address[%v]", bindAddr)
}
}
}
// Send the message
_, err := w.Write(head.Bytes())
return err

@ -79,9 +79,15 @@ func New(conf *Config) (*Server, error) {
conf.Logger = NewLogger(log.New(ioutil.Discard, "socks5: ", log.LstdFlags))
}
if conf.Dial == nil {
conf.Dial = func(ctx context.Context, net_, addr string) (net.Conn, error) {
return net.Dial(net_, addr)
}
}
server := &Server{
config: conf,
bufferPool: newPool(32 * 1024),
bufferPool: newPool(2 * 1024),
}
server.authMethods = make(map[uint8]Authenticator)

@ -2,7 +2,6 @@ package socks5
import (
"bytes"
"encoding/binary"
"io"
"log"
"net"
@ -37,10 +36,9 @@ func TestSOCKS5_Connect(t *testing.T) {
lAddr := l.Addr().(*net.TCPAddr)
// Create a socks server
creds := StaticCredentials{
"foo": "bar",
cator := UserPassAuthenticator{
Credentials: StaticCredentials{"foo": "bar"},
}
cator := UserPassAuthenticator{Credentials: creds}
conf := &Config{
AuthMethods: []Authenticator{cator},
Logger: NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags)),
@ -65,16 +63,21 @@ func TestSOCKS5_Connect(t *testing.T) {
}
// Connect, auth and connec to local
req := bytes.NewBuffer(nil)
req.Write([]byte{5})
req.Write([]byte{2, NoAuth, UserPassAuth})
req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'})
req.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1})
port := []byte{0, 0}
binary.BigEndian.PutUint16(port, uint16(lAddr.Port))
req.Write(port)
req := new(bytes.Buffer)
req.Write([]byte{socks5Version, 2, NoAuth, UserPassAuth})
req.Write([]byte{userPassAuthVersion, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'})
reqHead := Header{
Version: socks5Version,
Command: ConnectCommand,
Reserved: 0,
Address: AddrSpec{
"",
net.ParseIP("127.0.0.1"),
lAddr.Port,
},
addrType: ipv4Address,
}
req.Write(reqHead.Bytes())
// Send a ping
req.Write([]byte("ping"))
@ -83,23 +86,31 @@ func TestSOCKS5_Connect(t *testing.T) {
// Verify response
expected := []byte{
socks5Version, UserPassAuth,
1, authSuccess,
5,
0,
0,
1,
127, 0, 0, 1,
0, 0,
'p', 'o', 'n', 'g',
socks5Version, UserPassAuth, // use user password auth
userPassAuthVersion, authSuccess, // response auth success
}
out := make([]byte, len(expected))
rspHead := Header{
Version: socks5Version,
Command: successReply,
Reserved: 0,
Address: AddrSpec{
"",
net.ParseIP("127.0.0.1"),
0, // Ignore the port
},
addrType: ipv4Address,
}
expected = append(expected, rspHead.Bytes()...)
expected = append(expected, []byte("pong")...)
out := make([]byte, len(expected))
conn.SetDeadline(time.Now().Add(time.Second))
if _, err := io.ReadAtLeast(conn, out, len(out)); err != nil {
if _, err := io.ReadFull(conn, out); err != nil {
t.Fatalf("err: %v", err)
}
t.Logf("proxy bind port: %d", buildPort(out[12], out[13]))
// Ignore the port
out[12] = 0
out[13] = 0
@ -108,3 +119,127 @@ func TestSOCKS5_Connect(t *testing.T) {
t.Fatalf("bad: %v", out)
}
}
func TestSOCKS5_Associate(t *testing.T) {
locIP := net.ParseIP("127.0.0.1")
// Create a local listener
lAddr := &net.UDPAddr{
IP: locIP,
Port: 12398,
}
l, err := net.ListenUDP("udp4", lAddr)
if err != nil {
t.Fatalf("err: %v", err)
}
defer l.Close()
go func() {
buf := make([]byte, 2048)
for {
n, remote, err := l.ReadFrom(buf)
if err != nil {
return
}
if !bytes.Equal(buf[:n], []byte("ping")) {
t.Fatalf("bad: %v", buf)
}
l.WriteTo([]byte("pong"), remote)
}
}()
// Create a socks server
cator := UserPassAuthenticator{Credentials: StaticCredentials{"foo": "bar"}}
conf := &Config{
AuthMethods: []Authenticator{cator},
Logger: NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags)),
}
serv, err := New(conf)
if err != nil {
t.Fatalf("err: %v", err)
}
// Start listening
go func() {
if err := serv.ListenAndServe("tcp", "127.0.0.1:12355"); err != nil {
t.Fatalf("err: %v", err)
}
}()
time.Sleep(10 * time.Millisecond)
// Get a local conn
conn, err := net.Dial("tcp", "127.0.0.1:12355")
if err != nil {
t.Fatalf("err: %v", err)
}
// Connect, auth and connec to local
req := new(bytes.Buffer)
req.Write([]byte{socks5Version, 2, NoAuth, UserPassAuth})
req.Write([]byte{userPassAuthVersion, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'})
reqHead := Header{
Version: socks5Version,
Command: AssociateCommand,
Reserved: 0,
Address: AddrSpec{
"",
locIP,
lAddr.Port,
},
addrType: ipv4Address,
}
req.Write(reqHead.Bytes())
// Send all the bytes
conn.Write(req.Bytes())
// Verify response
expected := []byte{
socks5Version, UserPassAuth, // use user password auth
userPassAuthVersion, authSuccess, // response auth success
}
rspHead := Header{
Version: socks5Version,
Command: successReply,
Reserved: 0,
Address: AddrSpec{
"",
net.ParseIP("0.0.0.0"),
0, // Ignore the port
},
addrType: ipv4Address,
}
expected = append(expected, rspHead.Bytes()...)
out := make([]byte, len(expected))
conn.SetDeadline(time.Now().Add(time.Second))
if _, err := io.ReadFull(conn, out); err != nil {
t.Fatalf("err: %v", err)
}
// Ignore the port
proxyBindPort := buildPort(out[12], out[13])
out[12] = 0
out[13] = 0
t.Logf("proxy bind listen port: %d", proxyBindPort)
if !bytes.Equal(out, expected) {
t.Fatalf("bad: %v", out)
}
udpConn, err := net.DialUDP("udp4", nil, &net.UDPAddr{
IP: locIP,
//Port: lAddr.Port,
Port: proxyBindPort,
})
if err != nil {
t.Fatalf("bad dial: %v", err)
}
// Send a ping
udpConn.Write([]byte("ping"))
response := make([]byte, 1024)
n, _, err := udpConn.ReadFrom(response)
if !bytes.Equal(response[:n], []byte("pong")) {
t.Fatalf("bad udp read: %v", string(response[:n]))
}
time.Sleep(time.Second * 1)
}