use buffer pool interface

This commit is contained in:
mo 2020-08-06 10:12:46 +08:00
parent d83167e33c
commit 9935a8baca
13 changed files with 74 additions and 70 deletions

3
.gitignore vendored
View File

@ -6,6 +6,7 @@
# Folders
_obj
_test
.idea
# Architecture specific extensions/prefixes
*.[568vq]
@ -20,4 +21,4 @@ _cgo_export.*
_testmain.go
*.exe
.idea

View File

@ -1,14 +1,10 @@
package main
import (
"io"
"log"
"net"
"os"
"time"
"github.com/thinkgos/go-socks5"
"github.com/thinkgos/go-socks5/client"
)
func handleErr(err error) {
@ -18,41 +14,41 @@ func handleErr(err error) {
}
func main() {
// Create a local listener
l, err := net.Listen("tcp", "127.0.0.1:0")
handleErr(err)
go func() {
conn, err := l.Accept()
handleErr(err)
defer conn.Close()
buf := make([]byte, 4)
_, err = io.ReadAtLeast(conn, buf, 4)
handleErr(err)
log.Printf("server: %+v", string(buf))
conn.Write([]byte("pong"))
}()
lAddr := l.Addr().(*net.TCPAddr)
go func() {
time.Sleep(time.Second)
c, err := client.NewClient("127.0.0.1:1080")
handleErr(err)
con, err := c.Dial("tcp", lAddr.String())
handleErr(err)
con.Write([]byte("ping"))
out := make([]byte, 4)
_ = con.SetDeadline(time.Now().Add(time.Second)) // nolint: errcheck
_, err = io.ReadFull(con, out)
log.Printf("client: %+v", string(out))
}()
// // Create a local listener
// l, err := net.Listen("tcp", "127.0.0.1:0")
// handleErr(err)
//
// go func() {
// conn, err := l.Accept()
// handleErr(err)
// defer conn.Close()
//
// buf := make([]byte, 4)
// _, err = io.ReadAtLeast(conn, buf, 4)
// handleErr(err)
// log.Printf("server: %+v", string(buf))
// conn.Write([]byte("pong"))
// }()
// lAddr := l.Addr().(*net.TCPAddr)
//
// go func() {
// time.Sleep(time.Second)
// c, err := client.NewClient("127.0.0.1:1080")
// handleErr(err)
// con, err := c.Dial("tcp", lAddr.String())
// handleErr(err)
// con.Write([]byte("ping"))
// out := make([]byte, 4)
// _ = con.SetDeadline(time.Now().Add(time.Second)) // nolint: errcheck
// _, err = io.ReadFull(con, out)
// log.Printf("client: %+v", string(out))
// }()
// Create a SOCKS5 server
server := socks5.NewServer(socks5.WithLogger(socks5.NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags))))
// Create SOCKS5 proxy on localhost port 8000
if err := server.ListenAndServe("tcp", "127.0.0.1:1080"); err != nil {
if err := server.ListenAndServe("tcp", ":10800"); err != nil {
panic(err)
}
}

View File

@ -150,7 +150,7 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, request *R
}
// handleBind is used to handle a connect command
func (s *Server) handleBind(_ context.Context, writer io.Writer, request *Request) error {
func (s *Server) handleBind(_ context.Context, writer io.Writer, _ *Request) error {
// TODO: Support bind
if err := SendReply(writer, statute.RepCommandNotSupported, nil); err != nil {
return fmt.Errorf("failed to send reply: %v", err)

View File

@ -49,7 +49,7 @@ func TestRequest_Connect(t *testing.T) {
rules: NewPermitAll(),
resolver: DNSResolver{},
logger: NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags)),
bufferPool: newPool(32 * 1024),
bufferPool: NewPool(32 * 1024),
}
// Create the connect request
@ -106,7 +106,7 @@ func TestRequest_Connect_RuleFail(t *testing.T) {
rules: NewPermitNone(),
resolver: DNSResolver{},
logger: NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags)),
bufferPool: newPool(32 * 1024),
bufferPool: NewPool(32 * 1024),
}
// Create the connect request

View File

@ -9,6 +9,14 @@ import (
// Option user's option
type Option func(s *Server)
// WithBufferPool can be provided to implement custom buffer pool
// By default, buffer pool use size is 32k
func WithBufferPool(bufferPool BufferPool) Option {
return func(s *Server) {
s.bufferPool = bufferPool
}
}
// WithAuthMethods can be provided to implement custom authentication
// By default, "auth-less" mode is enabled.
// For password-based auth use UserPassAuthenticator.
@ -26,9 +34,7 @@ func WithAuthMethods(authMethods []Authenticator) Option {
// and AUthMethods is nil, then "auth-less" mode is enabled.
func WithCredential(cs CredentialStore) Option {
return func(s *Server) {
if cs != nil {
s.credentials = cs
}
s.credentials = cs
}
}
@ -36,9 +42,7 @@ func WithCredential(cs CredentialStore) Option {
// Defaults to DNSResolver if not provided.
func WithResolver(res NameResolver) Option {
return func(s *Server) {
if res != nil {
s.resolver = res
}
s.resolver = res
}
}
@ -46,9 +50,7 @@ func WithResolver(res NameResolver) Option {
// various commands. If not provided, NewPermitAll is used.
func WithRule(rule RuleSet) Option {
return func(s *Server) {
if rule != nil {
s.rules = rule
}
s.rules = rule
}
}
@ -57,9 +59,7 @@ func WithRule(rule RuleSet) Option {
// Defaults to NoRewrite.
func WithRewriter(rew AddressRewriter) Option {
return func(s *Server) {
if rew != nil {
s.rewriter = rew
}
s.rewriter = rew
}
}
@ -77,27 +77,21 @@ func WithBindIP(ip net.IP) Option {
// Defaults to ioutil.Discard.
func WithLogger(l Logger) Option {
return func(s *Server) {
if l != nil {
s.logger = l
}
s.logger = l
}
}
// WithDial Optional function for dialing out
func WithDial(dial func(ctx context.Context, network, addr string) (net.Conn, error)) Option {
return func(s *Server) {
if dial != nil {
s.dial = dial
}
s.dial = dial
}
}
// WithGPool can be provided to do custom goroutine pool.
func WithGPool(pool GPool) Option {
return func(s *Server) {
if pool != nil {
s.gPool = pool
}
s.gPool = pool
}
}

14
pool.go
View File

@ -4,22 +4,34 @@ import (
"sync"
)
// A BufferPool is an interface for getting and returning temporary
// byte slices for use by io.CopyBuffer.
type BufferPool interface {
Get() []byte
Put([]byte)
}
type pool struct {
size int
pool *sync.Pool
}
func newPool(size int) *pool {
// NewPool new buffer pool for getting and returning temporary
// byte slices for use by io.CopyBuffer.
func NewPool(size int) BufferPool {
return &pool{
size,
&sync.Pool{
New: func() interface{} { return make([]byte, 0, size) }},
}
}
// Get implement interface BufferPool
func (sf *pool) Get() []byte {
return sf.pool.Get().([]byte)
}
// Put implement interface BufferPool
func (sf *pool) Put(b []byte) {
if cap(b) != sf.size {
panic("invalid buffer size that's put into leaky buffer")

View File

@ -8,7 +8,7 @@ import (
)
func TestPool(t *testing.T) {
p := newPool(2048)
p := NewPool(2048)
b := p.Get()
bs := b[0:cap(b)]
require.Equal(t, cap(b), len(bs))
@ -21,7 +21,7 @@ func TestPool(t *testing.T) {
}
func BenchmarkSyncPool(b *testing.B) {
p := newPool(32 * 1024)
p := NewPool(32 * 1024)
wg := new(sync.WaitGroup)
b.ResetTimer()

View File

@ -48,7 +48,7 @@ type Server struct {
// Optional function for dialing out
dial func(ctx context.Context, network, addr string) (net.Conn, error)
// buffer pool
bufferPool *pool
bufferPool BufferPool
// goroutine pool
gPool GPool
// user's handle
@ -62,7 +62,7 @@ func NewServer(opts ...Option) *Server {
server := &Server{
authMethods: make(map[uint8]Authenticator),
authCustomMethods: []Authenticator{&NoAuthAuthenticator{}},
bufferPool: newPool(2 * 1024),
bufferPool: NewPool(32 * 1024),
resolver: DNSResolver{},
rules: NewPermitAll(),
logger: NewLogger(log.New(ioutil.Discard, "socks5: ", log.LstdFlags)),

View File

@ -6,13 +6,13 @@ import (
"net"
)
// Datagram udp packet
// The SOCKS UDP request/response is formed as follows:
// +-----+------+-------+----------+----------+----------+
// | RSV | FRAG | ATYP | DST.ADDR | DST.PORT | DATA |
// +-----+------+-------+----------+----------+----------+
// | 2 | 1 | X'00' | Variable | 2 | Variable |
// +-----+------+-------+----------+----------+----------+
// Datagram udp packet
type Datagram struct {
RSV uint16
Frag byte
@ -34,7 +34,7 @@ func NewDatagram(destAddr string, data []byte) (p Datagram, err error) {
return
}
// ParseRequest parse to datagram
// ParseDatagram parse to datagram from bytes
func ParseDatagram(b []byte) (da Datagram, err error) {
if len(b) < 4+net.IPv4len+2 { // no enough data
err = errors.New("datagram to short")

View File

@ -4,6 +4,7 @@ import (
"errors"
)
// error defined
var (
ErrUnrecognizedAddrType = errors.New("Unrecognized address type")
ErrNotSupportVersion = errors.New("not support version")

View File

@ -146,7 +146,7 @@ func (h Reply) Bytes() (b []byte) {
return b
}
// ParseRequest to request from io.Reader
// ParseReply parse to reply from io.Reader
func ParseReply(r io.Reader) (rep Reply, err error) {
// Read the version and command
tmp := []byte{0, 0}

View File

@ -1,6 +1,6 @@
package statute
// socks protocol version
// VersionSocks5 socks protocol version
const VersionSocks5 = byte(0x05)
// request command defined

View File

@ -16,7 +16,7 @@ type AddrSpec struct {
AddrType byte
}
// DstAddress returns a string suitable to dial; prefer returning IP-based
// String returns a string suitable to dial; prefer returning IP-based
// address, fallback to FQDN
func (a *AddrSpec) String() string {
if 0 != len(a.IP) {
@ -25,7 +25,7 @@ func (a *AddrSpec) String() string {
return net.JoinHostPort(a.FQDN, strconv.Itoa(a.Port))
}
// DstAddress returns a string which may be specified
// Address returns a string which may be specified
// if IPv4/IPv6 will return < ip:port >
// if FQDN will return < domain ip:port >
// Note: do not used to dial, Please use String