From 9935a8bacaa9d4b38b7010f2ea43be8b669173cd Mon Sep 17 00:00:00 2001 From: mo Date: Thu, 6 Aug 2020 10:12:46 +0800 Subject: [PATCH] use buffer pool interface --- .gitignore | 3 ++- _example/main.go | 64 +++++++++++++++++++++------------------------ handle.go | 2 +- handle_test.go | 4 +-- option.go | 36 +++++++++++-------------- pool.go | 14 +++++++++- pool_test.go | 4 +-- server.go | 4 +-- statute/datagram.go | 4 +-- statute/errors.go | 1 + statute/message.go | 2 +- statute/statute.go | 2 +- statute/util.go | 4 +-- 13 files changed, 74 insertions(+), 70 deletions(-) diff --git a/.gitignore b/.gitignore index 2870498..9022ed5 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ # Folders _obj _test +.idea # Architecture specific extensions/prefixes *.[568vq] @@ -20,4 +21,4 @@ _cgo_export.* _testmain.go *.exe -.idea + diff --git a/_example/main.go b/_example/main.go index dfd1b31..863d3bd 100644 --- a/_example/main.go +++ b/_example/main.go @@ -1,14 +1,10 @@ package main import ( - "io" "log" - "net" "os" - "time" "github.com/thinkgos/go-socks5" - "github.com/thinkgos/go-socks5/client" ) func handleErr(err error) { @@ -18,41 +14,41 @@ func handleErr(err error) { } func main() { - // Create a local listener - l, err := net.Listen("tcp", "127.0.0.1:0") - handleErr(err) - - go func() { - conn, err := l.Accept() - handleErr(err) - defer conn.Close() - - buf := make([]byte, 4) - _, err = io.ReadAtLeast(conn, buf, 4) - handleErr(err) - log.Printf("server: %+v", string(buf)) - conn.Write([]byte("pong")) - }() - lAddr := l.Addr().(*net.TCPAddr) - - go func() { - time.Sleep(time.Second) - c, err := client.NewClient("127.0.0.1:1080") - handleErr(err) - con, err := c.Dial("tcp", lAddr.String()) - handleErr(err) - con.Write([]byte("ping")) - out := make([]byte, 4) - _ = con.SetDeadline(time.Now().Add(time.Second)) // nolint: errcheck - _, err = io.ReadFull(con, out) - log.Printf("client: %+v", string(out)) - }() + // // Create a local listener + // l, err := net.Listen("tcp", "127.0.0.1:0") + // handleErr(err) + // + // go func() { + // conn, err := l.Accept() + // handleErr(err) + // defer conn.Close() + // + // buf := make([]byte, 4) + // _, err = io.ReadAtLeast(conn, buf, 4) + // handleErr(err) + // log.Printf("server: %+v", string(buf)) + // conn.Write([]byte("pong")) + // }() + // lAddr := l.Addr().(*net.TCPAddr) + // + // go func() { + // time.Sleep(time.Second) + // c, err := client.NewClient("127.0.0.1:1080") + // handleErr(err) + // con, err := c.Dial("tcp", lAddr.String()) + // handleErr(err) + // con.Write([]byte("ping")) + // out := make([]byte, 4) + // _ = con.SetDeadline(time.Now().Add(time.Second)) // nolint: errcheck + // _, err = io.ReadFull(con, out) + // log.Printf("client: %+v", string(out)) + // }() // Create a SOCKS5 server server := socks5.NewServer(socks5.WithLogger(socks5.NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags)))) // Create SOCKS5 proxy on localhost port 8000 - if err := server.ListenAndServe("tcp", "127.0.0.1:1080"); err != nil { + if err := server.ListenAndServe("tcp", ":10800"); err != nil { panic(err) } } diff --git a/handle.go b/handle.go index 051ad59..b1839a6 100644 --- a/handle.go +++ b/handle.go @@ -150,7 +150,7 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, request *R } // handleBind is used to handle a connect command -func (s *Server) handleBind(_ context.Context, writer io.Writer, request *Request) error { +func (s *Server) handleBind(_ context.Context, writer io.Writer, _ *Request) error { // TODO: Support bind if err := SendReply(writer, statute.RepCommandNotSupported, nil); err != nil { return fmt.Errorf("failed to send reply: %v", err) diff --git a/handle_test.go b/handle_test.go index de2b1c3..5e61bf8 100644 --- a/handle_test.go +++ b/handle_test.go @@ -49,7 +49,7 @@ func TestRequest_Connect(t *testing.T) { rules: NewPermitAll(), resolver: DNSResolver{}, logger: NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags)), - bufferPool: newPool(32 * 1024), + bufferPool: NewPool(32 * 1024), } // Create the connect request @@ -106,7 +106,7 @@ func TestRequest_Connect_RuleFail(t *testing.T) { rules: NewPermitNone(), resolver: DNSResolver{}, logger: NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags)), - bufferPool: newPool(32 * 1024), + bufferPool: NewPool(32 * 1024), } // Create the connect request diff --git a/option.go b/option.go index 9ad1b63..8d89d4e 100644 --- a/option.go +++ b/option.go @@ -9,6 +9,14 @@ import ( // Option user's option type Option func(s *Server) +// WithBufferPool can be provided to implement custom buffer pool +// By default, buffer pool use size is 32k +func WithBufferPool(bufferPool BufferPool) Option { + return func(s *Server) { + s.bufferPool = bufferPool + } +} + // WithAuthMethods can be provided to implement custom authentication // By default, "auth-less" mode is enabled. // For password-based auth use UserPassAuthenticator. @@ -26,9 +34,7 @@ func WithAuthMethods(authMethods []Authenticator) Option { // and AUthMethods is nil, then "auth-less" mode is enabled. func WithCredential(cs CredentialStore) Option { return func(s *Server) { - if cs != nil { - s.credentials = cs - } + s.credentials = cs } } @@ -36,9 +42,7 @@ func WithCredential(cs CredentialStore) Option { // Defaults to DNSResolver if not provided. func WithResolver(res NameResolver) Option { return func(s *Server) { - if res != nil { - s.resolver = res - } + s.resolver = res } } @@ -46,9 +50,7 @@ func WithResolver(res NameResolver) Option { // various commands. If not provided, NewPermitAll is used. func WithRule(rule RuleSet) Option { return func(s *Server) { - if rule != nil { - s.rules = rule - } + s.rules = rule } } @@ -57,9 +59,7 @@ func WithRule(rule RuleSet) Option { // Defaults to NoRewrite. func WithRewriter(rew AddressRewriter) Option { return func(s *Server) { - if rew != nil { - s.rewriter = rew - } + s.rewriter = rew } } @@ -77,27 +77,21 @@ func WithBindIP(ip net.IP) Option { // Defaults to ioutil.Discard. func WithLogger(l Logger) Option { return func(s *Server) { - if l != nil { - s.logger = l - } + s.logger = l } } // WithDial Optional function for dialing out func WithDial(dial func(ctx context.Context, network, addr string) (net.Conn, error)) Option { return func(s *Server) { - if dial != nil { - s.dial = dial - } + s.dial = dial } } // WithGPool can be provided to do custom goroutine pool. func WithGPool(pool GPool) Option { return func(s *Server) { - if pool != nil { - s.gPool = pool - } + s.gPool = pool } } diff --git a/pool.go b/pool.go index bdcf44b..3b9d7c7 100644 --- a/pool.go +++ b/pool.go @@ -4,22 +4,34 @@ import ( "sync" ) +// A BufferPool is an interface for getting and returning temporary +// byte slices for use by io.CopyBuffer. +type BufferPool interface { + Get() []byte + Put([]byte) +} + type pool struct { size int pool *sync.Pool } -func newPool(size int) *pool { +// NewPool new buffer pool for getting and returning temporary +// byte slices for use by io.CopyBuffer. +func NewPool(size int) BufferPool { return &pool{ size, &sync.Pool{ New: func() interface{} { return make([]byte, 0, size) }}, } } + +// Get implement interface BufferPool func (sf *pool) Get() []byte { return sf.pool.Get().([]byte) } +// Put implement interface BufferPool func (sf *pool) Put(b []byte) { if cap(b) != sf.size { panic("invalid buffer size that's put into leaky buffer") diff --git a/pool_test.go b/pool_test.go index ccb2e19..fa42b0f 100644 --- a/pool_test.go +++ b/pool_test.go @@ -8,7 +8,7 @@ import ( ) func TestPool(t *testing.T) { - p := newPool(2048) + p := NewPool(2048) b := p.Get() bs := b[0:cap(b)] require.Equal(t, cap(b), len(bs)) @@ -21,7 +21,7 @@ func TestPool(t *testing.T) { } func BenchmarkSyncPool(b *testing.B) { - p := newPool(32 * 1024) + p := NewPool(32 * 1024) wg := new(sync.WaitGroup) b.ResetTimer() diff --git a/server.go b/server.go index 365bb3f..7ed95e0 100644 --- a/server.go +++ b/server.go @@ -48,7 +48,7 @@ type Server struct { // Optional function for dialing out dial func(ctx context.Context, network, addr string) (net.Conn, error) // buffer pool - bufferPool *pool + bufferPool BufferPool // goroutine pool gPool GPool // user's handle @@ -62,7 +62,7 @@ func NewServer(opts ...Option) *Server { server := &Server{ authMethods: make(map[uint8]Authenticator), authCustomMethods: []Authenticator{&NoAuthAuthenticator{}}, - bufferPool: newPool(2 * 1024), + bufferPool: NewPool(32 * 1024), resolver: DNSResolver{}, rules: NewPermitAll(), logger: NewLogger(log.New(ioutil.Discard, "socks5: ", log.LstdFlags)), diff --git a/statute/datagram.go b/statute/datagram.go index 4981f8f..af7e671 100644 --- a/statute/datagram.go +++ b/statute/datagram.go @@ -6,13 +6,13 @@ import ( "net" ) +// Datagram udp packet // The SOCKS UDP request/response is formed as follows: // +-----+------+-------+----------+----------+----------+ // | RSV | FRAG | ATYP | DST.ADDR | DST.PORT | DATA | // +-----+------+-------+----------+----------+----------+ // | 2 | 1 | X'00' | Variable | 2 | Variable | // +-----+------+-------+----------+----------+----------+ -// Datagram udp packet type Datagram struct { RSV uint16 Frag byte @@ -34,7 +34,7 @@ func NewDatagram(destAddr string, data []byte) (p Datagram, err error) { return } -// ParseRequest parse to datagram +// ParseDatagram parse to datagram from bytes func ParseDatagram(b []byte) (da Datagram, err error) { if len(b) < 4+net.IPv4len+2 { // no enough data err = errors.New("datagram to short") diff --git a/statute/errors.go b/statute/errors.go index 4ba1136..8905feb 100644 --- a/statute/errors.go +++ b/statute/errors.go @@ -4,6 +4,7 @@ import ( "errors" ) +// error defined var ( ErrUnrecognizedAddrType = errors.New("Unrecognized address type") ErrNotSupportVersion = errors.New("not support version") diff --git a/statute/message.go b/statute/message.go index 0fe1ad2..986a0ef 100644 --- a/statute/message.go +++ b/statute/message.go @@ -146,7 +146,7 @@ func (h Reply) Bytes() (b []byte) { return b } -// ParseRequest to request from io.Reader +// ParseReply parse to reply from io.Reader func ParseReply(r io.Reader) (rep Reply, err error) { // Read the version and command tmp := []byte{0, 0} diff --git a/statute/statute.go b/statute/statute.go index 98ddd35..833dbdc 100644 --- a/statute/statute.go +++ b/statute/statute.go @@ -1,6 +1,6 @@ package statute -// socks protocol version +// VersionSocks5 socks protocol version const VersionSocks5 = byte(0x05) // request command defined diff --git a/statute/util.go b/statute/util.go index cec99af..2d18bc4 100644 --- a/statute/util.go +++ b/statute/util.go @@ -16,7 +16,7 @@ type AddrSpec struct { AddrType byte } -// DstAddress returns a string suitable to dial; prefer returning IP-based +// String returns a string suitable to dial; prefer returning IP-based // address, fallback to FQDN func (a *AddrSpec) String() string { if 0 != len(a.IP) { @@ -25,7 +25,7 @@ func (a *AddrSpec) String() string { return net.JoinHostPort(a.FQDN, strconv.Itoa(a.Port)) } -// DstAddress returns a string which may be specified +// Address returns a string which may be specified // if IPv4/IPv6 will return < ip:port > // if FQDN will return < domain ip:port > // Note: do not used to dial, Please use String