add testing example
This commit is contained in:
mo 2020-08-06 11:11:30 +08:00
parent 9935a8baca
commit 880274af1f
6 changed files with 137 additions and 112 deletions

@ -2,10 +2,12 @@ go-socks5
=========
[![Build Status](https://travis-ci.org/thinkgos/go-socks5.svg?branch=master)](https://travis-ci.org/thinkgos/go-socks5)
[![Go.Dev reference](https://img.shields.io/badge/go.dev-reference-blue?logo=go&logoColor=white)](https://pkg.go.dev/github.com/thinkgos/go-socks5?tab=doc)
[![codecov](https://codecov.io/gh/thinkgos/go-socks5/branch/master/graph/badge.svg)](https://codecov.io/gh/thinkgos/go-socks5)
![Action Status](https://github.com/thinkgos/go-socks5/workflows/Go/badge.svg)
[![Go Report Card](https://goreportcard.com/badge/github.com/thinkgos/go-socks5)](https://goreportcard.com/report/github.com/thinkgos/go-socks5)
[![License](https://img.shields.io/github/license/thinkgos/go-socks5)](https://github.com/thinkgos/go-socks5/raw/master/LICENSE)
[![Tag](https://img.shields.io/github/v/tag/thinkgos/go-socks5)](https://github.com/thinkgos/go-socks5/tags)
Provides the `socks5` package that implements a [SOCKS5 server](http://en.wikipedia.org/wiki/SOCKS).
SOCKS (Secure Sockets) is used to route traffic between a client and server through
@ -15,16 +17,16 @@ Feature
=======
The package has the following features:
* Unit tests
* "No Auth" mode
* User/Password authentication optional user addr limit
* Support for the CONNECT command
* Support for the ASSOCIATE command
* Rules to do granular filtering of commands
* Custom DNS resolution
* Unit tests
* Custom goroutine pool
* buffer pool design and optional custom buffer pool
* Custom logger
* buffer pool design
TODO
====
@ -42,11 +44,11 @@ Below is a simple example of usage
server := socks5.NewServer()
// Create SOCKS5 proxy on localhost port 8000
if err := server.ListenAndServe("tcp", "127.0.0.1:8000"); err != nil {
if err := server.ListenAndServe("tcp", ":8000"); err != nil {
panic(err)
}
```
# Reference
- [rfc1928](https://www.ietf.org/rfc/rfc1928.txt)
- original armon go-sock5 [go-sock5](https://github.com/armon/go-socks5)
- original armon [go-sock5](https://github.com/armon/go-socks5)

@ -7,43 +7,7 @@ import (
"github.com/thinkgos/go-socks5"
)
func handleErr(err error) {
if err != nil {
panic(err)
}
}
func main() {
// // Create a local listener
// l, err := net.Listen("tcp", "127.0.0.1:0")
// handleErr(err)
//
// go func() {
// conn, err := l.Accept()
// handleErr(err)
// defer conn.Close()
//
// buf := make([]byte, 4)
// _, err = io.ReadAtLeast(conn, buf, 4)
// handleErr(err)
// log.Printf("server: %+v", string(buf))
// conn.Write([]byte("pong"))
// }()
// lAddr := l.Addr().(*net.TCPAddr)
//
// go func() {
// time.Sleep(time.Second)
// c, err := client.NewClient("127.0.0.1:1080")
// handleErr(err)
// con, err := c.Dial("tcp", lAddr.String())
// handleErr(err)
// con.Write([]byte("ping"))
// out := make([]byte, 4)
// _ = con.SetDeadline(time.Now().Add(time.Second)) // nolint: errcheck
// _, err = io.ReadFull(con, out)
// log.Printf("client: %+v", string(out))
// }()
// Create a SOCKS5 server
server := socks5.NewServer(socks5.WithLogger(socks5.NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags))))

60
_testing/tcp/main.go Normal file

@ -0,0 +1,60 @@
package main
import (
"io"
"log"
"net"
"os"
"time"
"github.com/thinkgos/go-socks5"
"github.com/thinkgos/go-socks5/client"
)
func handleErr(err error) {
if err != nil {
panic(err)
}
}
func main() {
// Create a local listener
l, err := net.Listen("tcp", "127.0.0.1:0")
handleErr(err)
go func() {
conn, err := l.Accept()
handleErr(err)
defer conn.Close()
buf := make([]byte, 4)
_, err = io.ReadAtLeast(conn, buf, 4)
handleErr(err)
log.Printf("server: %+v", string(buf))
conn.Write([]byte("pong"))
}()
lAddr := l.Addr().(*net.TCPAddr)
go func() {
time.Sleep(time.Second * 1)
c, err := client.NewClient("127.0.0.1:10809")
handleErr(err)
con, err := c.Dial("tcp", lAddr.String())
handleErr(err)
con.Write([]byte("ping"))
out := make([]byte, 4)
_ = con.SetDeadline(time.Now().Add(time.Second)) // nolint: errcheck
_, err = io.ReadFull(con, out)
log.Printf("client: %+v", string(out))
}()
// Create a SOCKS5 server
server := socks5.NewServer(socks5.WithLogger(socks5.NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags))))
// Create SOCKS5 proxy on localhost port 8000
if err := server.ListenAndServe("tcp", "127.0.0.1:10809"); err != nil {
panic(err)
}
}

@ -47,14 +47,14 @@ func ParseRequest(bufConn io.Reader) (*Request, error) {
}
// handleRequest is used for request processing after authentication
func (s *Server) handleRequest(write io.Writer, req *Request) error {
func (sf *Server) handleRequest(write io.Writer, req *Request) error {
var err error
ctx := context.Background()
// Resolve the address if we have a FQDN
dest := req.RawDestAddr
if dest.FQDN != "" {
ctx, dest.IP, err = s.resolver.Resolve(ctx, dest.FQDN)
ctx, dest.IP, err = sf.resolver.Resolve(ctx, dest.FQDN)
if err != nil {
if err := SendReply(write, statute.RepHostUnreachable, nil); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
@ -65,13 +65,13 @@ func (s *Server) handleRequest(write io.Writer, req *Request) error {
// Apply any address rewrites
req.DestAddr = req.RawDestAddr
if s.rewriter != nil {
ctx, req.DestAddr = s.rewriter.Rewrite(ctx, req)
if sf.rewriter != nil {
ctx, req.DestAddr = sf.rewriter.Rewrite(ctx, req)
}
// Check if this is allowed
var ok bool
ctx, ok = s.rules.Allow(ctx, req)
ctx, ok = sf.rules.Allow(ctx, req)
if !ok {
if err := SendReply(write, statute.RepRuleFailure, nil); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
@ -82,20 +82,20 @@ func (s *Server) handleRequest(write io.Writer, req *Request) error {
// Switch on the command
switch req.Command {
case statute.CommandConnect:
if s.userConnectHandle != nil {
return s.userConnectHandle(ctx, write, req)
if sf.userConnectHandle != nil {
return sf.userConnectHandle(ctx, write, req)
}
return s.handleConnect(ctx, write, req)
return sf.handleConnect(ctx, write, req)
case statute.CommandBind:
if s.userBindHandle != nil {
return s.userBindHandle(ctx, write, req)
if sf.userBindHandle != nil {
return sf.userBindHandle(ctx, write, req)
}
return s.handleBind(ctx, write, req)
return sf.handleBind(ctx, write, req)
case statute.CommandAssociate:
if s.userAssociateHandle != nil {
return s.userAssociateHandle(ctx, write, req)
if sf.userAssociateHandle != nil {
return sf.userAssociateHandle(ctx, write, req)
}
return s.handleAssociate(ctx, write, req)
return sf.handleAssociate(ctx, write, req)
default:
if err := SendReply(write, statute.RepCommandNotSupported, nil); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
@ -105,9 +105,9 @@ func (s *Server) handleRequest(write io.Writer, req *Request) error {
}
// handleConnect is used to handle a connect command
func (s *Server) handleConnect(ctx context.Context, writer io.Writer, request *Request) error {
func (sf *Server) handleConnect(ctx context.Context, writer io.Writer, request *Request) error {
// Attempt to connect
dial := s.dial
dial := sf.dial
if dial == nil {
dial = func(ctx context.Context, net_, addr string) (net.Conn, error) {
return net.Dial(net_, addr)
@ -136,8 +136,8 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, request *R
// Start proxying
errCh := make(chan error, 2)
s.submit(func() { errCh <- s.Proxy(target, request.Reader) })
s.submit(func() { errCh <- s.Proxy(writer, target) })
sf.submit(func() { errCh <- sf.Proxy(target, request.Reader) })
sf.submit(func() { errCh <- sf.Proxy(writer, target) })
// Wait
for i := 0; i < 2; i++ {
e := <-errCh
@ -150,7 +150,7 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, request *R
}
// handleBind is used to handle a connect command
func (s *Server) handleBind(_ context.Context, writer io.Writer, _ *Request) error {
func (sf *Server) handleBind(_ context.Context, writer io.Writer, _ *Request) error {
// TODO: Support bind
if err := SendReply(writer, statute.RepCommandNotSupported, nil); err != nil {
return fmt.Errorf("failed to send reply: %v", err)
@ -159,14 +159,15 @@ func (s *Server) handleBind(_ context.Context, writer io.Writer, _ *Request) err
}
// handleAssociate is used to handle a connect command
func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, request *Request) error {
func (sf *Server) handleAssociate(ctx context.Context, writer io.Writer, request *Request) error {
// Attempt to connect
dial := s.dial
dial := sf.dial
if dial == nil {
dial = func(ctx context.Context, net_, addr string) (net.Conn, error) {
return net.Dial(net_, addr)
}
}
target, err := dial(ctx, "udp", request.DestAddr.String())
if err != nil {
msg := err.Error()
@ -200,26 +201,26 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, request
}
defer bindLn.Close()
s.logger.Errorf("target addr %v, listen addr: %s", targetUDP.RemoteAddr(), bindLn.LocalAddr())
// send BND.ADDR and BND.PORT, client must
sf.logger.Errorf("target addr %v, listen addr: %s", targetUDP.RemoteAddr(), bindLn.LocalAddr())
// send BND.ADDR and BND.PORT, client used
if err = SendReply(writer, statute.RepSuccess, bindLn.LocalAddr()); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
s.submit(func() {
sf.submit(func() {
// read from client and write to remote server
conns := sync.Map{}
bufPool := s.bufferPool.Get()
bufPool := sf.bufferPool.Get()
defer func() {
targetUDP.Close()
bindLn.Close()
s.bufferPool.Put(bufPool)
sf.bufferPool.Put(bufPool)
}()
for {
n, srcAddr, err := bindLn.ReadFrom(bufPool[:cap(bufPool)])
if err != nil {
if strings.Contains(err.Error(), "use of closed network connection") {
s.logger.Errorf("read data from bind listen address %s failed, %v", bindLn.LocalAddr(), err)
sf.logger.Errorf("read data from bind listen address %s failed, %v", bindLn.LocalAddr(), err)
return
}
continue
@ -231,20 +232,20 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, request
}
if _, ok := conns.LoadOrStore(srcAddr.String(), struct{}{}); !ok {
s.submit(func() {
sf.submit(func() {
// read from remote server and write to client
bufPool := s.bufferPool.Get()
bufPool := sf.bufferPool.Get()
defer func() {
targetUDP.Close()
bindLn.Close()
s.bufferPool.Put(bufPool)
sf.bufferPool.Put(bufPool)
}()
for {
buf := bufPool[:cap(bufPool)]
n, remote, err := targetUDP.ReadFrom(buf)
if err != nil {
s.logger.Errorf("read data from remote %s failed, %v", targetUDP.RemoteAddr(), err)
sf.logger.Errorf("read data from remote %s failed, %v", targetUDP.RemoteAddr(), err)
return
}
@ -252,32 +253,31 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, request
if err != nil {
continue
}
tmpBufPool := s.bufferPool.Get()
tmpBufPool := sf.bufferPool.Get()
proBuf := tmpBufPool
proBuf = append(proBuf, pkb.Header()...)
proBuf = append(proBuf, pkb.Data...)
if _, err := bindLn.WriteTo(proBuf, srcAddr); err != nil {
s.bufferPool.Put(tmpBufPool)
s.logger.Errorf("write data to client %s failed, %v", bindLn.LocalAddr(), err)
sf.bufferPool.Put(tmpBufPool)
sf.logger.Errorf("write data to client %s failed, %v", bindLn.LocalAddr(), err)
return
}
s.bufferPool.Put(tmpBufPool)
sf.bufferPool.Put(tmpBufPool)
}
})
}
// 把消息写给remote sever
if _, err := targetUDP.Write(pk.Data); err != nil {
s.logger.Errorf("write data to remote %s failed, %v", targetUDP.RemoteAddr(), err)
sf.logger.Errorf("write data to remote %s failed, %v", targetUDP.RemoteAddr(), err)
return
}
}
})
buf := s.bufferPool.Get()
defer func() {
s.bufferPool.Put(buf)
}()
buf := sf.bufferPool.Get()
defer sf.bufferPool.Put(buf)
for {
_, err := request.Reader.Read(buf[:cap(buf)])
if err != nil {
@ -329,9 +329,9 @@ type closeWriter interface {
// Proxy is used to suffle data from src to destination, and sends errors
// down a dedicated channel
func (s *Server) Proxy(dst io.Writer, src io.Reader) error {
buf := s.bufferPool.Get()
defer s.bufferPool.Put(buf)
func (sf *Server) Proxy(dst io.Writer, src io.Reader) error {
buf := sf.bufferPool.Get()
defer sf.bufferPool.Put(buf)
_, err := io.CopyBuffer(dst, src, buf[:cap(buf)])
if tcpConn, ok := dst.(closeWriter); ok {
tcpConn.CloseWrite() // nolint: errcheck

@ -88,31 +88,31 @@ func NewServer(opts ...Option) *Server {
}
// ListenAndServe is used to create a listener and serve on it
func (s *Server) ListenAndServe(network, addr string) error {
func (sf *Server) ListenAndServe(network, addr string) error {
l, err := net.Listen(network, addr)
if err != nil {
return err
}
return s.Serve(l)
return sf.Serve(l)
}
// Serve is used to serve connections from a listener
func (s *Server) Serve(l net.Listener) error {
func (sf *Server) Serve(l net.Listener) error {
for {
conn, err := l.Accept()
if err != nil {
return err
}
s.submit(func() {
if err := s.ServeConn(conn); err != nil {
s.logger.Errorf("server conn %v", err)
sf.submit(func() {
if err := sf.ServeConn(conn); err != nil {
sf.logger.Errorf("server conn %v", err)
}
})
}
}
// ServeConn is used to serve a single connection.
func (s *Server) ServeConn(conn net.Conn) error {
func (sf *Server) ServeConn(conn net.Conn) error {
var authContext *AuthContext
defer conn.Close()
@ -128,7 +128,7 @@ func (s *Server) ServeConn(conn net.Conn) error {
}
// Authenticate the connection
authContext, err = s.authenticate(conn, bufConn, conn.RemoteAddr().String(), mr.Methods)
authContext, err = sf.authenticate(conn, bufConn, conn.RemoteAddr().String(), mr.Methods)
if err != nil {
return fmt.Errorf("failed to authenticate: %w", err)
}
@ -157,14 +157,14 @@ func (s *Server) ServeConn(conn net.Conn) error {
request.LocalAddr = conn.LocalAddr()
request.RemoteAddr = conn.RemoteAddr()
// Process the client request
return s.handleRequest(conn, request)
return sf.handleRequest(conn, request)
}
// authenticate is used to handle connection authentication
func (s *Server) authenticate(conn io.Writer, bufConn io.Reader, userAddr string, methods []byte) (*AuthContext, error) {
func (sf *Server) authenticate(conn io.Writer, bufConn io.Reader, userAddr string, methods []byte) (*AuthContext, error) {
// Select a usable method
for _, method := range methods {
if cator, found := s.authMethods[method]; found {
if cator, found := sf.authMethods[method]; found {
return cator.Authenticate(bufConn, conn, userAddr)
}
}
@ -173,8 +173,8 @@ func (s *Server) authenticate(conn io.Writer, bufConn io.Reader, userAddr string
return nil, statute.ErrNoSupportedAuth
}
func (s *Server) submit(f func()) {
if s.gPool == nil || s.gPool.Submit(f) != nil {
func (sf *Server) submit(f func()) {
if sf.gPool == nil || sf.gPool.Submit(f) != nil {
go f()
}
}

@ -110,10 +110,7 @@ func TestSOCKS5_Connect(t *testing.T) {
func TestSOCKS5_Associate(t *testing.T) {
locIP := net.ParseIP("127.0.0.1")
// Create a local listener
lAddr := &net.UDPAddr{
IP: locIP,
Port: 12398,
}
lAddr := &net.UDPAddr{IP: locIP, Port: 12398}
l, err := net.ListenUDP("udp", lAddr)
require.NoError(t, err)
defer l.Close()
@ -133,13 +130,13 @@ func TestSOCKS5_Associate(t *testing.T) {
// Create a socks server
cator := UserPassAuthenticator{StaticCredentials{"foo": "bar"}}
srv := NewServer(
proxySrv := NewServer(
WithAuthMethods([]Authenticator{cator}),
WithLogger(NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags))),
)
// Start listening
go func() {
err := srv.ListenAndServe("tcp", "127.0.0.1:12355")
err := proxySrv.ListenAndServe("tcp", "127.0.0.1:12355")
require.NoError(t, err)
}()
time.Sleep(10 * time.Millisecond)
@ -149,9 +146,11 @@ func TestSOCKS5_Associate(t *testing.T) {
require.NoError(t, err)
// Connect, auth and connec to local
req := new(bytes.Buffer)
req.Write([]byte{statute.VersionSocks5, 2, statute.MethodNoAuth, statute.MethodUserPassAuth})
req.Write([]byte{statute.UserPassAuthVersion, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'})
req := bytes.NewBuffer(
[]byte{
statute.VersionSocks5, 2, statute.MethodNoAuth, statute.MethodUserPassAuth,
statute.UserPassAuthVersion, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r',
})
reqHead := statute.Request{
Version: statute.VersionSocks5,
Command: statute.CommandAssociate,
@ -179,25 +178,25 @@ func TestSOCKS5_Associate(t *testing.T) {
require.NoError(t, err)
require.Equal(t, expected, out)
rspHead, err := statute.ParseRequest(conn)
rspHead, err := statute.ParseReply(conn)
require.NoError(t, err)
require.Equal(t, statute.VersionSocks5, rspHead.Version)
require.Equal(t, statute.RepSuccess, rspHead.Command)
require.Equal(t, statute.RepSuccess, rspHead.Response)
t.Logf("proxy bind listen port: %d", rspHead.DstAddress.Port)
t.Logf("proxy bind listen port: %d", rspHead.BndAddress.Port)
udpConn, err := net.DialUDP("udp", nil, &net.UDPAddr{
IP: locIP,
Port: rspHead.DstAddress.Port,
Port: rspHead.BndAddress.Port,
})
require.NoError(t, err)
// Send a ping
udpConn.Write(append([]byte{0, 0, 0, statute.ATYPIPv4, 0, 0, 0, 0, 0, 0}, []byte("ping")...)) // nolint: errcheck
response := make([]byte, 1024)
n, _, err := udpConn.ReadFrom(response)
if err != nil || !bytes.Equal(response[n-4:n], []byte("pong")) {
t.Fatalf("bad udp read: %v", string(response[:n]))
}
require.NoError(t, err)
assert.Equal(t, []byte("pong"), response[n-4:n])
time.Sleep(time.Second * 1)
}