use buffer pool interface

This commit is contained in:
mo 2020-08-06 10:12:46 +08:00
parent d83167e33c
commit 9935a8baca
13 changed files with 74 additions and 70 deletions

3
.gitignore vendored

@ -6,6 +6,7 @@
# Folders # Folders
_obj _obj
_test _test
.idea
# Architecture specific extensions/prefixes # Architecture specific extensions/prefixes
*.[568vq] *.[568vq]
@ -20,4 +21,4 @@ _cgo_export.*
_testmain.go _testmain.go
*.exe *.exe
.idea

@ -1,14 +1,10 @@
package main package main
import ( import (
"io"
"log" "log"
"net"
"os" "os"
"time"
"github.com/thinkgos/go-socks5" "github.com/thinkgos/go-socks5"
"github.com/thinkgos/go-socks5/client"
) )
func handleErr(err error) { func handleErr(err error) {
@ -18,41 +14,41 @@ func handleErr(err error) {
} }
func main() { func main() {
// Create a local listener // // Create a local listener
l, err := net.Listen("tcp", "127.0.0.1:0") // l, err := net.Listen("tcp", "127.0.0.1:0")
handleErr(err) // handleErr(err)
//
go func() { // go func() {
conn, err := l.Accept() // conn, err := l.Accept()
handleErr(err) // handleErr(err)
defer conn.Close() // defer conn.Close()
//
buf := make([]byte, 4) // buf := make([]byte, 4)
_, err = io.ReadAtLeast(conn, buf, 4) // _, err = io.ReadAtLeast(conn, buf, 4)
handleErr(err) // handleErr(err)
log.Printf("server: %+v", string(buf)) // log.Printf("server: %+v", string(buf))
conn.Write([]byte("pong")) // conn.Write([]byte("pong"))
}() // }()
lAddr := l.Addr().(*net.TCPAddr) // lAddr := l.Addr().(*net.TCPAddr)
//
go func() { // go func() {
time.Sleep(time.Second) // time.Sleep(time.Second)
c, err := client.NewClient("127.0.0.1:1080") // c, err := client.NewClient("127.0.0.1:1080")
handleErr(err) // handleErr(err)
con, err := c.Dial("tcp", lAddr.String()) // con, err := c.Dial("tcp", lAddr.String())
handleErr(err) // handleErr(err)
con.Write([]byte("ping")) // con.Write([]byte("ping"))
out := make([]byte, 4) // out := make([]byte, 4)
_ = con.SetDeadline(time.Now().Add(time.Second)) // nolint: errcheck // _ = con.SetDeadline(time.Now().Add(time.Second)) // nolint: errcheck
_, err = io.ReadFull(con, out) // _, err = io.ReadFull(con, out)
log.Printf("client: %+v", string(out)) // log.Printf("client: %+v", string(out))
}() // }()
// Create a SOCKS5 server // Create a SOCKS5 server
server := socks5.NewServer(socks5.WithLogger(socks5.NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags)))) server := socks5.NewServer(socks5.WithLogger(socks5.NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags))))
// Create SOCKS5 proxy on localhost port 8000 // Create SOCKS5 proxy on localhost port 8000
if err := server.ListenAndServe("tcp", "127.0.0.1:1080"); err != nil { if err := server.ListenAndServe("tcp", ":10800"); err != nil {
panic(err) panic(err)
} }
} }

@ -150,7 +150,7 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, request *R
} }
// handleBind is used to handle a connect command // handleBind is used to handle a connect command
func (s *Server) handleBind(_ context.Context, writer io.Writer, request *Request) error { func (s *Server) handleBind(_ context.Context, writer io.Writer, _ *Request) error {
// TODO: Support bind // TODO: Support bind
if err := SendReply(writer, statute.RepCommandNotSupported, nil); err != nil { if err := SendReply(writer, statute.RepCommandNotSupported, nil); err != nil {
return fmt.Errorf("failed to send reply: %v", err) return fmt.Errorf("failed to send reply: %v", err)

@ -49,7 +49,7 @@ func TestRequest_Connect(t *testing.T) {
rules: NewPermitAll(), rules: NewPermitAll(),
resolver: DNSResolver{}, resolver: DNSResolver{},
logger: NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags)), logger: NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags)),
bufferPool: newPool(32 * 1024), bufferPool: NewPool(32 * 1024),
} }
// Create the connect request // Create the connect request
@ -106,7 +106,7 @@ func TestRequest_Connect_RuleFail(t *testing.T) {
rules: NewPermitNone(), rules: NewPermitNone(),
resolver: DNSResolver{}, resolver: DNSResolver{},
logger: NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags)), logger: NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags)),
bufferPool: newPool(32 * 1024), bufferPool: NewPool(32 * 1024),
} }
// Create the connect request // Create the connect request

@ -9,6 +9,14 @@ import (
// Option user's option // Option user's option
type Option func(s *Server) type Option func(s *Server)
// WithBufferPool can be provided to implement custom buffer pool
// By default, buffer pool use size is 32k
func WithBufferPool(bufferPool BufferPool) Option {
return func(s *Server) {
s.bufferPool = bufferPool
}
}
// WithAuthMethods can be provided to implement custom authentication // WithAuthMethods can be provided to implement custom authentication
// By default, "auth-less" mode is enabled. // By default, "auth-less" mode is enabled.
// For password-based auth use UserPassAuthenticator. // For password-based auth use UserPassAuthenticator.
@ -26,9 +34,7 @@ func WithAuthMethods(authMethods []Authenticator) Option {
// and AUthMethods is nil, then "auth-less" mode is enabled. // and AUthMethods is nil, then "auth-less" mode is enabled.
func WithCredential(cs CredentialStore) Option { func WithCredential(cs CredentialStore) Option {
return func(s *Server) { return func(s *Server) {
if cs != nil { s.credentials = cs
s.credentials = cs
}
} }
} }
@ -36,9 +42,7 @@ func WithCredential(cs CredentialStore) Option {
// Defaults to DNSResolver if not provided. // Defaults to DNSResolver if not provided.
func WithResolver(res NameResolver) Option { func WithResolver(res NameResolver) Option {
return func(s *Server) { return func(s *Server) {
if res != nil { s.resolver = res
s.resolver = res
}
} }
} }
@ -46,9 +50,7 @@ func WithResolver(res NameResolver) Option {
// various commands. If not provided, NewPermitAll is used. // various commands. If not provided, NewPermitAll is used.
func WithRule(rule RuleSet) Option { func WithRule(rule RuleSet) Option {
return func(s *Server) { return func(s *Server) {
if rule != nil { s.rules = rule
s.rules = rule
}
} }
} }
@ -57,9 +59,7 @@ func WithRule(rule RuleSet) Option {
// Defaults to NoRewrite. // Defaults to NoRewrite.
func WithRewriter(rew AddressRewriter) Option { func WithRewriter(rew AddressRewriter) Option {
return func(s *Server) { return func(s *Server) {
if rew != nil { s.rewriter = rew
s.rewriter = rew
}
} }
} }
@ -77,27 +77,21 @@ func WithBindIP(ip net.IP) Option {
// Defaults to ioutil.Discard. // Defaults to ioutil.Discard.
func WithLogger(l Logger) Option { func WithLogger(l Logger) Option {
return func(s *Server) { return func(s *Server) {
if l != nil { s.logger = l
s.logger = l
}
} }
} }
// WithDial Optional function for dialing out // WithDial Optional function for dialing out
func WithDial(dial func(ctx context.Context, network, addr string) (net.Conn, error)) Option { func WithDial(dial func(ctx context.Context, network, addr string) (net.Conn, error)) Option {
return func(s *Server) { return func(s *Server) {
if dial != nil { s.dial = dial
s.dial = dial
}
} }
} }
// WithGPool can be provided to do custom goroutine pool. // WithGPool can be provided to do custom goroutine pool.
func WithGPool(pool GPool) Option { func WithGPool(pool GPool) Option {
return func(s *Server) { return func(s *Server) {
if pool != nil { s.gPool = pool
s.gPool = pool
}
} }
} }

14
pool.go

@ -4,22 +4,34 @@ import (
"sync" "sync"
) )
// A BufferPool is an interface for getting and returning temporary
// byte slices for use by io.CopyBuffer.
type BufferPool interface {
Get() []byte
Put([]byte)
}
type pool struct { type pool struct {
size int size int
pool *sync.Pool pool *sync.Pool
} }
func newPool(size int) *pool { // NewPool new buffer pool for getting and returning temporary
// byte slices for use by io.CopyBuffer.
func NewPool(size int) BufferPool {
return &pool{ return &pool{
size, size,
&sync.Pool{ &sync.Pool{
New: func() interface{} { return make([]byte, 0, size) }}, New: func() interface{} { return make([]byte, 0, size) }},
} }
} }
// Get implement interface BufferPool
func (sf *pool) Get() []byte { func (sf *pool) Get() []byte {
return sf.pool.Get().([]byte) return sf.pool.Get().([]byte)
} }
// Put implement interface BufferPool
func (sf *pool) Put(b []byte) { func (sf *pool) Put(b []byte) {
if cap(b) != sf.size { if cap(b) != sf.size {
panic("invalid buffer size that's put into leaky buffer") panic("invalid buffer size that's put into leaky buffer")

@ -8,7 +8,7 @@ import (
) )
func TestPool(t *testing.T) { func TestPool(t *testing.T) {
p := newPool(2048) p := NewPool(2048)
b := p.Get() b := p.Get()
bs := b[0:cap(b)] bs := b[0:cap(b)]
require.Equal(t, cap(b), len(bs)) require.Equal(t, cap(b), len(bs))
@ -21,7 +21,7 @@ func TestPool(t *testing.T) {
} }
func BenchmarkSyncPool(b *testing.B) { func BenchmarkSyncPool(b *testing.B) {
p := newPool(32 * 1024) p := NewPool(32 * 1024)
wg := new(sync.WaitGroup) wg := new(sync.WaitGroup)
b.ResetTimer() b.ResetTimer()

@ -48,7 +48,7 @@ type Server struct {
// Optional function for dialing out // Optional function for dialing out
dial func(ctx context.Context, network, addr string) (net.Conn, error) dial func(ctx context.Context, network, addr string) (net.Conn, error)
// buffer pool // buffer pool
bufferPool *pool bufferPool BufferPool
// goroutine pool // goroutine pool
gPool GPool gPool GPool
// user's handle // user's handle
@ -62,7 +62,7 @@ func NewServer(opts ...Option) *Server {
server := &Server{ server := &Server{
authMethods: make(map[uint8]Authenticator), authMethods: make(map[uint8]Authenticator),
authCustomMethods: []Authenticator{&NoAuthAuthenticator{}}, authCustomMethods: []Authenticator{&NoAuthAuthenticator{}},
bufferPool: newPool(2 * 1024), bufferPool: NewPool(32 * 1024),
resolver: DNSResolver{}, resolver: DNSResolver{},
rules: NewPermitAll(), rules: NewPermitAll(),
logger: NewLogger(log.New(ioutil.Discard, "socks5: ", log.LstdFlags)), logger: NewLogger(log.New(ioutil.Discard, "socks5: ", log.LstdFlags)),

@ -6,13 +6,13 @@ import (
"net" "net"
) )
// Datagram udp packet
// The SOCKS UDP request/response is formed as follows: // The SOCKS UDP request/response is formed as follows:
// +-----+------+-------+----------+----------+----------+ // +-----+------+-------+----------+----------+----------+
// | RSV | FRAG | ATYP | DST.ADDR | DST.PORT | DATA | // | RSV | FRAG | ATYP | DST.ADDR | DST.PORT | DATA |
// +-----+------+-------+----------+----------+----------+ // +-----+------+-------+----------+----------+----------+
// | 2 | 1 | X'00' | Variable | 2 | Variable | // | 2 | 1 | X'00' | Variable | 2 | Variable |
// +-----+------+-------+----------+----------+----------+ // +-----+------+-------+----------+----------+----------+
// Datagram udp packet
type Datagram struct { type Datagram struct {
RSV uint16 RSV uint16
Frag byte Frag byte
@ -34,7 +34,7 @@ func NewDatagram(destAddr string, data []byte) (p Datagram, err error) {
return return
} }
// ParseRequest parse to datagram // ParseDatagram parse to datagram from bytes
func ParseDatagram(b []byte) (da Datagram, err error) { func ParseDatagram(b []byte) (da Datagram, err error) {
if len(b) < 4+net.IPv4len+2 { // no enough data if len(b) < 4+net.IPv4len+2 { // no enough data
err = errors.New("datagram to short") err = errors.New("datagram to short")

@ -4,6 +4,7 @@ import (
"errors" "errors"
) )
// error defined
var ( var (
ErrUnrecognizedAddrType = errors.New("Unrecognized address type") ErrUnrecognizedAddrType = errors.New("Unrecognized address type")
ErrNotSupportVersion = errors.New("not support version") ErrNotSupportVersion = errors.New("not support version")

@ -146,7 +146,7 @@ func (h Reply) Bytes() (b []byte) {
return b return b
} }
// ParseRequest to request from io.Reader // ParseReply parse to reply from io.Reader
func ParseReply(r io.Reader) (rep Reply, err error) { func ParseReply(r io.Reader) (rep Reply, err error) {
// Read the version and command // Read the version and command
tmp := []byte{0, 0} tmp := []byte{0, 0}

@ -1,6 +1,6 @@
package statute package statute
// socks protocol version // VersionSocks5 socks protocol version
const VersionSocks5 = byte(0x05) const VersionSocks5 = byte(0x05)
// request command defined // request command defined

@ -16,7 +16,7 @@ type AddrSpec struct {
AddrType byte AddrType byte
} }
// DstAddress returns a string suitable to dial; prefer returning IP-based // String returns a string suitable to dial; prefer returning IP-based
// address, fallback to FQDN // address, fallback to FQDN
func (a *AddrSpec) String() string { func (a *AddrSpec) String() string {
if 0 != len(a.IP) { if 0 != len(a.IP) {
@ -25,7 +25,7 @@ func (a *AddrSpec) String() string {
return net.JoinHostPort(a.FQDN, strconv.Itoa(a.Port)) return net.JoinHostPort(a.FQDN, strconv.Itoa(a.Port))
} }
// DstAddress returns a string which may be specified // Address returns a string which may be specified
// if IPv4/IPv6 will return < ip:port > // if IPv4/IPv6 will return < ip:port >
// if FQDN will return < domain ip:port > // if FQDN will return < domain ip:port >
// Note: do not used to dial, Please use String // Note: do not used to dial, Please use String