implement raw udp
This commit is contained in:
parent
32060b0332
commit
be8e44e47d
21
auth.go
21
auth.go
@ -6,12 +6,13 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
NoAuth = uint8(0)
|
NoAuth = uint8(0)
|
||||||
noAcceptable = uint8(255)
|
GSSAPI = uint8(1)
|
||||||
UserPassAuth = uint8(2)
|
UserPassAuth = uint8(2)
|
||||||
userAuthVersion = uint8(1)
|
NoAcceptable = uint8(255)
|
||||||
authSuccess = uint8(0)
|
userPassAuthVersion = uint8(1)
|
||||||
authFailure = uint8(1)
|
authSuccess = uint8(0)
|
||||||
|
authFailure = uint8(1)
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -70,7 +71,7 @@ func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Ensure we are compatible
|
// Ensure we are compatible
|
||||||
if header[0] != userAuthVersion {
|
if header[0] != userPassAuthVersion {
|
||||||
return nil, fmt.Errorf("Unsupported auth version: %v", header[0])
|
return nil, fmt.Errorf("Unsupported auth version: %v", header[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -95,11 +96,11 @@ func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer)
|
|||||||
|
|
||||||
// Verify the password
|
// Verify the password
|
||||||
if a.Credentials.Valid(string(user), string(pass)) {
|
if a.Credentials.Valid(string(user), string(pass)) {
|
||||||
if _, err := writer.Write([]byte{userAuthVersion, authSuccess}); err != nil {
|
if _, err := writer.Write([]byte{userPassAuthVersion, authSuccess}); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if _, err := writer.Write([]byte{userAuthVersion, authFailure}); err != nil {
|
if _, err := writer.Write([]byte{userPassAuthVersion, authFailure}); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return nil, UserAuthFailed
|
return nil, UserAuthFailed
|
||||||
@ -132,7 +133,7 @@ func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) (*AuthContext,
|
|||||||
// noAcceptableAuth is used to handle when we have no eligible
|
// noAcceptableAuth is used to handle when we have no eligible
|
||||||
// authentication mechanism
|
// authentication mechanism
|
||||||
func noAcceptableAuth(conn io.Writer) error {
|
func noAcceptableAuth(conn io.Writer) error {
|
||||||
conn.Write([]byte{socks5Version, noAcceptable})
|
conn.Write([]byte{socks5Version, NoAcceptable})
|
||||||
return NoSupportedAuth
|
return NoSupportedAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -113,7 +113,7 @@ func TestNoSupportedAuth(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
out := resp.Bytes()
|
out := resp.Bytes()
|
||||||
if !bytes.Equal(out, []byte{socks5Version, noAcceptable}) {
|
if !bytes.Equal(out, []byte{socks5Version, NoAcceptable}) {
|
||||||
t.Fatalf("bad: %v", out)
|
t.Fatalf("bad: %v", out)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
73
request.go
73
request.go
@ -6,6 +6,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -287,7 +288,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
|
|||||||
}
|
}
|
||||||
|
|
||||||
lAddr, _ := net.ResolveUDPAddr("udp", ":0")
|
lAddr, _ := net.ResolveUDPAddr("udp", ":0")
|
||||||
listen, err := net.ListenUDP("udp", lAddr)
|
bindLn, err := net.ListenUDP("udp4", lAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err := sendReply(writer, req.Header, serverFailure); err != nil {
|
if err := sendReply(writer, req.Header, serverFailure); err != nil {
|
||||||
return fmt.Errorf("failed to send reply, %v", err)
|
return fmt.Errorf("failed to send reply, %v", err)
|
||||||
@ -295,58 +296,62 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
|
|||||||
return fmt.Errorf("listen udp failed, %v", err)
|
return fmt.Errorf("listen udp failed, %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.config.Logger.Errorf("target addr %v, listen addr: %s", targetUdp.RemoteAddr(), bindLn.LocalAddr())
|
||||||
// send BND.ADDR and BND.PORT, client must
|
// send BND.ADDR and BND.PORT, client must
|
||||||
if err = sendReply(writer, req.Header, successReply, listen.LocalAddr()); err != nil {
|
if err = sendReply(writer, req.Header, successReply, bindLn.LocalAddr()); err != nil {
|
||||||
return fmt.Errorf("failed to send reply, %v", err)
|
return fmt.Errorf("failed to send reply, %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
// read from client and write to remote server
|
// read from client and write to remote server
|
||||||
|
conns := sync.Map{}
|
||||||
buf := s.bufferPool.Get()
|
buf := s.bufferPool.Get()
|
||||||
defer s.bufferPool.Put(buf)
|
defer s.bufferPool.Put(buf)
|
||||||
for {
|
for {
|
||||||
n, _, err := listen.ReadFrom(buf[:cap(buf)])
|
n, srcAddr, err := bindLn.ReadFrom(buf[:cap(buf)])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.config.Logger.Errorf("read data from %s failed, %v", listen.LocalAddr(), err)
|
s.config.Logger.Errorf("read data from bind listen address %s failed, %v", bindLn.LocalAddr(), err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 把消息写给remote sever
|
||||||
if _, err := targetUdp.Write(buf[:n]); err != nil {
|
if _, err := targetUdp.Write(buf[:n]); err != nil {
|
||||||
s.config.Logger.Errorf("write data to remote %s failed, %v", targetUdp.RemoteAddr(), err)
|
s.config.Logger.Errorf("write data to remote %s failed, %v", targetUdp.RemoteAddr(), err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
if _, ok := conns.LoadOrStore(srcAddr.String(), struct{}{}); !ok {
|
||||||
// read from remote server and write to client
|
go func() {
|
||||||
buf := s.bufferPool.Get()
|
// read from remote server and write to client
|
||||||
defer s.bufferPool.Put(buf)
|
buf := s.bufferPool.Get()
|
||||||
for {
|
defer s.bufferPool.Put(buf)
|
||||||
n, _, err := targetUdp.ReadFromUDP(buf[:cap(buf)])
|
for {
|
||||||
if err != nil {
|
n, _, err := targetUdp.ReadFrom(buf[:cap(buf)])
|
||||||
s.config.Logger.Errorf("read data from remote %s failed, %v", targetUdp.RemoteAddr(), err)
|
if err != nil {
|
||||||
return
|
s.config.Logger.Errorf("read data from remote %s failed, %v", targetUdp.RemoteAddr(), err)
|
||||||
}
|
return
|
||||||
if _, err := listen.Write(buf[:n]); err != nil {
|
}
|
||||||
s.config.Logger.Errorf("write data to remote %s failed, %v", targetUdp.RemoteAddr(), err)
|
|
||||||
return
|
if _, err := bindLn.WriteTo(buf[:n], srcAddr); err != nil {
|
||||||
|
s.config.Logger.Errorf("write data to client %s failed, %v", bindLn.LocalAddr(), err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
buf := s.bufferPool.Get()
|
buf := s.bufferPool.Get()
|
||||||
defer s.bufferPool.Put(buf)
|
defer func() {
|
||||||
|
s.bufferPool.Put(buf)
|
||||||
|
}()
|
||||||
for {
|
for {
|
||||||
_, err := req.bufConn.Read(buf)
|
_, err := req.bufConn.Read(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
//// TODO: Support associate
|
|
||||||
//if err := sendReply(writer, req.Header, commandNotSupported); err != nil {
|
|
||||||
// return fmt.Errorf("failed to send reply, %v", err)
|
|
||||||
//}
|
|
||||||
//return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// sendReply is used to send a reply message
|
// sendReply is used to send a reply message
|
||||||
@ -366,12 +371,18 @@ func sendReply(w io.Writer, head Header, resp uint8, bindAddr ...net.Addr) error
|
|||||||
head.addrType = ipv4Address
|
head.addrType = ipv4Address
|
||||||
head.Address.IP = []byte{0, 0, 0, 0}
|
head.Address.IP = []byte{0, 0, 0, 0}
|
||||||
head.Address.Port = 0
|
head.Address.Port = 0
|
||||||
} else if tcpAddr, ok := bindAddr[0].(*net.TCPAddr); !ok || tcpAddr == nil {
|
|
||||||
head.addrType = ipv4Address
|
|
||||||
head.Address.IP = []byte{0, 0, 0, 0}
|
|
||||||
head.Address.Port = 0
|
|
||||||
} else {
|
} else {
|
||||||
addrSpec := AddrSpec{IP: tcpAddr.IP, Port: tcpAddr.Port}
|
addrSpec := AddrSpec{}
|
||||||
|
if tcpAddr, ok := bindAddr[0].(*net.TCPAddr); ok && tcpAddr != nil {
|
||||||
|
addrSpec.IP = tcpAddr.IP
|
||||||
|
addrSpec.Port = tcpAddr.Port
|
||||||
|
} else if udpAddr, ok := bindAddr[0].(*net.UDPAddr); ok && udpAddr != nil {
|
||||||
|
addrSpec.IP = udpAddr.IP
|
||||||
|
addrSpec.Port = udpAddr.Port
|
||||||
|
} else {
|
||||||
|
addrSpec.IP = []byte{0, 0, 0, 0}
|
||||||
|
addrSpec.Port = 0
|
||||||
|
}
|
||||||
switch {
|
switch {
|
||||||
case addrSpec.FQDN != "":
|
case addrSpec.FQDN != "":
|
||||||
head.addrType = fqdnAddress
|
head.addrType = fqdnAddress
|
||||||
@ -388,8 +399,8 @@ func sendReply(w io.Writer, head Header, resp uint8, bindAddr ...net.Addr) error
|
|||||||
default:
|
default:
|
||||||
return fmt.Errorf("failed to format address[%v]", bindAddr)
|
return fmt.Errorf("failed to format address[%v]", bindAddr)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
|
}
|
||||||
// Send the message
|
// Send the message
|
||||||
_, err := w.Write(head.Bytes())
|
_, err := w.Write(head.Bytes())
|
||||||
return err
|
return err
|
||||||
|
@ -79,9 +79,15 @@ func New(conf *Config) (*Server, error) {
|
|||||||
conf.Logger = NewLogger(log.New(ioutil.Discard, "socks5: ", log.LstdFlags))
|
conf.Logger = NewLogger(log.New(ioutil.Discard, "socks5: ", log.LstdFlags))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if conf.Dial == nil {
|
||||||
|
conf.Dial = func(ctx context.Context, net_, addr string) (net.Conn, error) {
|
||||||
|
return net.Dial(net_, addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
server := &Server{
|
server := &Server{
|
||||||
config: conf,
|
config: conf,
|
||||||
bufferPool: newPool(32 * 1024),
|
bufferPool: newPool(2 * 1024),
|
||||||
}
|
}
|
||||||
|
|
||||||
server.authMethods = make(map[uint8]Authenticator)
|
server.authMethods = make(map[uint8]Authenticator)
|
||||||
|
185
socks5_test.go
185
socks5_test.go
@ -2,7 +2,6 @@ package socks5
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
@ -37,10 +36,9 @@ func TestSOCKS5_Connect(t *testing.T) {
|
|||||||
lAddr := l.Addr().(*net.TCPAddr)
|
lAddr := l.Addr().(*net.TCPAddr)
|
||||||
|
|
||||||
// Create a socks server
|
// Create a socks server
|
||||||
creds := StaticCredentials{
|
cator := UserPassAuthenticator{
|
||||||
"foo": "bar",
|
Credentials: StaticCredentials{"foo": "bar"},
|
||||||
}
|
}
|
||||||
cator := UserPassAuthenticator{Credentials: creds}
|
|
||||||
conf := &Config{
|
conf := &Config{
|
||||||
AuthMethods: []Authenticator{cator},
|
AuthMethods: []Authenticator{cator},
|
||||||
Logger: NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags)),
|
Logger: NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags)),
|
||||||
@ -65,16 +63,21 @@ func TestSOCKS5_Connect(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Connect, auth and connec to local
|
// Connect, auth and connec to local
|
||||||
req := bytes.NewBuffer(nil)
|
req := new(bytes.Buffer)
|
||||||
req.Write([]byte{5})
|
req.Write([]byte{socks5Version, 2, NoAuth, UserPassAuth})
|
||||||
req.Write([]byte{2, NoAuth, UserPassAuth})
|
req.Write([]byte{userPassAuthVersion, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'})
|
||||||
req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'})
|
reqHead := Header{
|
||||||
req.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1})
|
Version: socks5Version,
|
||||||
|
Command: ConnectCommand,
|
||||||
port := []byte{0, 0}
|
Reserved: 0,
|
||||||
binary.BigEndian.PutUint16(port, uint16(lAddr.Port))
|
Address: AddrSpec{
|
||||||
req.Write(port)
|
"",
|
||||||
|
net.ParseIP("127.0.0.1"),
|
||||||
|
lAddr.Port,
|
||||||
|
},
|
||||||
|
addrType: ipv4Address,
|
||||||
|
}
|
||||||
|
req.Write(reqHead.Bytes())
|
||||||
// Send a ping
|
// Send a ping
|
||||||
req.Write([]byte("ping"))
|
req.Write([]byte("ping"))
|
||||||
|
|
||||||
@ -83,23 +86,31 @@ func TestSOCKS5_Connect(t *testing.T) {
|
|||||||
|
|
||||||
// Verify response
|
// Verify response
|
||||||
expected := []byte{
|
expected := []byte{
|
||||||
socks5Version, UserPassAuth,
|
socks5Version, UserPassAuth, // use user password auth
|
||||||
1, authSuccess,
|
userPassAuthVersion, authSuccess, // response auth success
|
||||||
5,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
1,
|
|
||||||
127, 0, 0, 1,
|
|
||||||
0, 0,
|
|
||||||
'p', 'o', 'n', 'g',
|
|
||||||
}
|
}
|
||||||
out := make([]byte, len(expected))
|
rspHead := Header{
|
||||||
|
Version: socks5Version,
|
||||||
|
Command: successReply,
|
||||||
|
Reserved: 0,
|
||||||
|
Address: AddrSpec{
|
||||||
|
"",
|
||||||
|
net.ParseIP("127.0.0.1"),
|
||||||
|
0, // Ignore the port
|
||||||
|
},
|
||||||
|
addrType: ipv4Address,
|
||||||
|
}
|
||||||
|
expected = append(expected, rspHead.Bytes()...)
|
||||||
|
expected = append(expected, []byte("pong")...)
|
||||||
|
|
||||||
|
out := make([]byte, len(expected))
|
||||||
conn.SetDeadline(time.Now().Add(time.Second))
|
conn.SetDeadline(time.Now().Add(time.Second))
|
||||||
if _, err := io.ReadAtLeast(conn, out, len(out)); err != nil {
|
if _, err := io.ReadFull(conn, out); err != nil {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
t.Logf("proxy bind port: %d", buildPort(out[12], out[13]))
|
||||||
|
|
||||||
// Ignore the port
|
// Ignore the port
|
||||||
out[12] = 0
|
out[12] = 0
|
||||||
out[13] = 0
|
out[13] = 0
|
||||||
@ -108,3 +119,127 @@ func TestSOCKS5_Connect(t *testing.T) {
|
|||||||
t.Fatalf("bad: %v", out)
|
t.Fatalf("bad: %v", out)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSOCKS5_Associate(t *testing.T) {
|
||||||
|
locIP := net.ParseIP("127.0.0.1")
|
||||||
|
// Create a local listener
|
||||||
|
lAddr := &net.UDPAddr{
|
||||||
|
IP: locIP,
|
||||||
|
Port: 12398,
|
||||||
|
}
|
||||||
|
l, err := net.ListenUDP("udp4", lAddr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %v", err)
|
||||||
|
}
|
||||||
|
defer l.Close()
|
||||||
|
go func() {
|
||||||
|
buf := make([]byte, 2048)
|
||||||
|
for {
|
||||||
|
n, remote, err := l.ReadFrom(buf)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !bytes.Equal(buf[:n], []byte("ping")) {
|
||||||
|
t.Fatalf("bad: %v", buf)
|
||||||
|
}
|
||||||
|
l.WriteTo([]byte("pong"), remote)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Create a socks server
|
||||||
|
cator := UserPassAuthenticator{Credentials: StaticCredentials{"foo": "bar"}}
|
||||||
|
conf := &Config{
|
||||||
|
AuthMethods: []Authenticator{cator},
|
||||||
|
Logger: NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags)),
|
||||||
|
}
|
||||||
|
serv, err := New(conf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start listening
|
||||||
|
go func() {
|
||||||
|
if err := serv.ListenAndServe("tcp", "127.0.0.1:12355"); err != nil {
|
||||||
|
t.Fatalf("err: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
|
// Get a local conn
|
||||||
|
conn, err := net.Dial("tcp", "127.0.0.1:12355")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect, auth and connec to local
|
||||||
|
req := new(bytes.Buffer)
|
||||||
|
req.Write([]byte{socks5Version, 2, NoAuth, UserPassAuth})
|
||||||
|
req.Write([]byte{userPassAuthVersion, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'})
|
||||||
|
reqHead := Header{
|
||||||
|
Version: socks5Version,
|
||||||
|
Command: AssociateCommand,
|
||||||
|
Reserved: 0,
|
||||||
|
Address: AddrSpec{
|
||||||
|
"",
|
||||||
|
locIP,
|
||||||
|
lAddr.Port,
|
||||||
|
},
|
||||||
|
addrType: ipv4Address,
|
||||||
|
}
|
||||||
|
req.Write(reqHead.Bytes())
|
||||||
|
// Send all the bytes
|
||||||
|
conn.Write(req.Bytes())
|
||||||
|
|
||||||
|
// Verify response
|
||||||
|
expected := []byte{
|
||||||
|
socks5Version, UserPassAuth, // use user password auth
|
||||||
|
userPassAuthVersion, authSuccess, // response auth success
|
||||||
|
}
|
||||||
|
rspHead := Header{
|
||||||
|
Version: socks5Version,
|
||||||
|
Command: successReply,
|
||||||
|
Reserved: 0,
|
||||||
|
Address: AddrSpec{
|
||||||
|
"",
|
||||||
|
net.ParseIP("0.0.0.0"),
|
||||||
|
0, // Ignore the port
|
||||||
|
},
|
||||||
|
addrType: ipv4Address,
|
||||||
|
}
|
||||||
|
expected = append(expected, rspHead.Bytes()...)
|
||||||
|
|
||||||
|
out := make([]byte, len(expected))
|
||||||
|
conn.SetDeadline(time.Now().Add(time.Second))
|
||||||
|
if _, err := io.ReadFull(conn, out); err != nil {
|
||||||
|
t.Fatalf("err: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ignore the port
|
||||||
|
proxyBindPort := buildPort(out[12], out[13])
|
||||||
|
|
||||||
|
out[12] = 0
|
||||||
|
out[13] = 0
|
||||||
|
|
||||||
|
t.Logf("proxy bind listen port: %d", proxyBindPort)
|
||||||
|
|
||||||
|
if !bytes.Equal(out, expected) {
|
||||||
|
t.Fatalf("bad: %v", out)
|
||||||
|
}
|
||||||
|
|
||||||
|
udpConn, err := net.DialUDP("udp4", nil, &net.UDPAddr{
|
||||||
|
IP: locIP,
|
||||||
|
//Port: lAddr.Port,
|
||||||
|
Port: proxyBindPort,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("bad dial: %v", err)
|
||||||
|
}
|
||||||
|
// Send a ping
|
||||||
|
udpConn.Write([]byte("ping"))
|
||||||
|
response := make([]byte, 1024)
|
||||||
|
n, _, err := udpConn.ReadFrom(response)
|
||||||
|
if !bytes.Equal(response[:n], []byte("pong")) {
|
||||||
|
t.Fatalf("bad udp read: %v", string(response[:n]))
|
||||||
|
}
|
||||||
|
time.Sleep(time.Second * 1)
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user