use buffer pool interface
This commit is contained in:
parent
d83167e33c
commit
9935a8baca
|
@ -6,6 +6,7 @@
|
|||
# Folders
|
||||
_obj
|
||||
_test
|
||||
.idea
|
||||
|
||||
# Architecture specific extensions/prefixes
|
||||
*.[568vq]
|
||||
|
@ -20,4 +21,4 @@ _cgo_export.*
|
|||
_testmain.go
|
||||
|
||||
*.exe
|
||||
.idea
|
||||
|
||||
|
|
|
@ -1,14 +1,10 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/thinkgos/go-socks5"
|
||||
"github.com/thinkgos/go-socks5/client"
|
||||
)
|
||||
|
||||
func handleErr(err error) {
|
||||
|
@ -18,41 +14,41 @@ func handleErr(err error) {
|
|||
}
|
||||
|
||||
func main() {
|
||||
// Create a local listener
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
handleErr(err)
|
||||
|
||||
go func() {
|
||||
conn, err := l.Accept()
|
||||
handleErr(err)
|
||||
defer conn.Close()
|
||||
|
||||
buf := make([]byte, 4)
|
||||
_, err = io.ReadAtLeast(conn, buf, 4)
|
||||
handleErr(err)
|
||||
log.Printf("server: %+v", string(buf))
|
||||
conn.Write([]byte("pong"))
|
||||
}()
|
||||
lAddr := l.Addr().(*net.TCPAddr)
|
||||
|
||||
go func() {
|
||||
time.Sleep(time.Second)
|
||||
c, err := client.NewClient("127.0.0.1:1080")
|
||||
handleErr(err)
|
||||
con, err := c.Dial("tcp", lAddr.String())
|
||||
handleErr(err)
|
||||
con.Write([]byte("ping"))
|
||||
out := make([]byte, 4)
|
||||
_ = con.SetDeadline(time.Now().Add(time.Second)) // nolint: errcheck
|
||||
_, err = io.ReadFull(con, out)
|
||||
log.Printf("client: %+v", string(out))
|
||||
}()
|
||||
// // Create a local listener
|
||||
// l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
// handleErr(err)
|
||||
//
|
||||
// go func() {
|
||||
// conn, err := l.Accept()
|
||||
// handleErr(err)
|
||||
// defer conn.Close()
|
||||
//
|
||||
// buf := make([]byte, 4)
|
||||
// _, err = io.ReadAtLeast(conn, buf, 4)
|
||||
// handleErr(err)
|
||||
// log.Printf("server: %+v", string(buf))
|
||||
// conn.Write([]byte("pong"))
|
||||
// }()
|
||||
// lAddr := l.Addr().(*net.TCPAddr)
|
||||
//
|
||||
// go func() {
|
||||
// time.Sleep(time.Second)
|
||||
// c, err := client.NewClient("127.0.0.1:1080")
|
||||
// handleErr(err)
|
||||
// con, err := c.Dial("tcp", lAddr.String())
|
||||
// handleErr(err)
|
||||
// con.Write([]byte("ping"))
|
||||
// out := make([]byte, 4)
|
||||
// _ = con.SetDeadline(time.Now().Add(time.Second)) // nolint: errcheck
|
||||
// _, err = io.ReadFull(con, out)
|
||||
// log.Printf("client: %+v", string(out))
|
||||
// }()
|
||||
|
||||
// Create a SOCKS5 server
|
||||
server := socks5.NewServer(socks5.WithLogger(socks5.NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags))))
|
||||
|
||||
// Create SOCKS5 proxy on localhost port 8000
|
||||
if err := server.ListenAndServe("tcp", "127.0.0.1:1080"); err != nil {
|
||||
if err := server.ListenAndServe("tcp", ":10800"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -150,7 +150,7 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, request *R
|
|||
}
|
||||
|
||||
// handleBind is used to handle a connect command
|
||||
func (s *Server) handleBind(_ context.Context, writer io.Writer, request *Request) error {
|
||||
func (s *Server) handleBind(_ context.Context, writer io.Writer, _ *Request) error {
|
||||
// TODO: Support bind
|
||||
if err := SendReply(writer, statute.RepCommandNotSupported, nil); err != nil {
|
||||
return fmt.Errorf("failed to send reply: %v", err)
|
||||
|
|
|
@ -49,7 +49,7 @@ func TestRequest_Connect(t *testing.T) {
|
|||
rules: NewPermitAll(),
|
||||
resolver: DNSResolver{},
|
||||
logger: NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags)),
|
||||
bufferPool: newPool(32 * 1024),
|
||||
bufferPool: NewPool(32 * 1024),
|
||||
}
|
||||
|
||||
// Create the connect request
|
||||
|
@ -106,7 +106,7 @@ func TestRequest_Connect_RuleFail(t *testing.T) {
|
|||
rules: NewPermitNone(),
|
||||
resolver: DNSResolver{},
|
||||
logger: NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags)),
|
||||
bufferPool: newPool(32 * 1024),
|
||||
bufferPool: NewPool(32 * 1024),
|
||||
}
|
||||
|
||||
// Create the connect request
|
||||
|
|
36
option.go
36
option.go
|
@ -9,6 +9,14 @@ import (
|
|||
// Option user's option
|
||||
type Option func(s *Server)
|
||||
|
||||
// WithBufferPool can be provided to implement custom buffer pool
|
||||
// By default, buffer pool use size is 32k
|
||||
func WithBufferPool(bufferPool BufferPool) Option {
|
||||
return func(s *Server) {
|
||||
s.bufferPool = bufferPool
|
||||
}
|
||||
}
|
||||
|
||||
// WithAuthMethods can be provided to implement custom authentication
|
||||
// By default, "auth-less" mode is enabled.
|
||||
// For password-based auth use UserPassAuthenticator.
|
||||
|
@ -26,9 +34,7 @@ func WithAuthMethods(authMethods []Authenticator) Option {
|
|||
// and AUthMethods is nil, then "auth-less" mode is enabled.
|
||||
func WithCredential(cs CredentialStore) Option {
|
||||
return func(s *Server) {
|
||||
if cs != nil {
|
||||
s.credentials = cs
|
||||
}
|
||||
s.credentials = cs
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -36,9 +42,7 @@ func WithCredential(cs CredentialStore) Option {
|
|||
// Defaults to DNSResolver if not provided.
|
||||
func WithResolver(res NameResolver) Option {
|
||||
return func(s *Server) {
|
||||
if res != nil {
|
||||
s.resolver = res
|
||||
}
|
||||
s.resolver = res
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -46,9 +50,7 @@ func WithResolver(res NameResolver) Option {
|
|||
// various commands. If not provided, NewPermitAll is used.
|
||||
func WithRule(rule RuleSet) Option {
|
||||
return func(s *Server) {
|
||||
if rule != nil {
|
||||
s.rules = rule
|
||||
}
|
||||
s.rules = rule
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -57,9 +59,7 @@ func WithRule(rule RuleSet) Option {
|
|||
// Defaults to NoRewrite.
|
||||
func WithRewriter(rew AddressRewriter) Option {
|
||||
return func(s *Server) {
|
||||
if rew != nil {
|
||||
s.rewriter = rew
|
||||
}
|
||||
s.rewriter = rew
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -77,27 +77,21 @@ func WithBindIP(ip net.IP) Option {
|
|||
// Defaults to ioutil.Discard.
|
||||
func WithLogger(l Logger) Option {
|
||||
return func(s *Server) {
|
||||
if l != nil {
|
||||
s.logger = l
|
||||
}
|
||||
s.logger = l
|
||||
}
|
||||
}
|
||||
|
||||
// WithDial Optional function for dialing out
|
||||
func WithDial(dial func(ctx context.Context, network, addr string) (net.Conn, error)) Option {
|
||||
return func(s *Server) {
|
||||
if dial != nil {
|
||||
s.dial = dial
|
||||
}
|
||||
s.dial = dial
|
||||
}
|
||||
}
|
||||
|
||||
// WithGPool can be provided to do custom goroutine pool.
|
||||
func WithGPool(pool GPool) Option {
|
||||
return func(s *Server) {
|
||||
if pool != nil {
|
||||
s.gPool = pool
|
||||
}
|
||||
s.gPool = pool
|
||||
}
|
||||
}
|
||||
|
||||
|
|
14
pool.go
14
pool.go
|
@ -4,22 +4,34 @@ import (
|
|||
"sync"
|
||||
)
|
||||
|
||||
// A BufferPool is an interface for getting and returning temporary
|
||||
// byte slices for use by io.CopyBuffer.
|
||||
type BufferPool interface {
|
||||
Get() []byte
|
||||
Put([]byte)
|
||||
}
|
||||
|
||||
type pool struct {
|
||||
size int
|
||||
pool *sync.Pool
|
||||
}
|
||||
|
||||
func newPool(size int) *pool {
|
||||
// NewPool new buffer pool for getting and returning temporary
|
||||
// byte slices for use by io.CopyBuffer.
|
||||
func NewPool(size int) BufferPool {
|
||||
return &pool{
|
||||
size,
|
||||
&sync.Pool{
|
||||
New: func() interface{} { return make([]byte, 0, size) }},
|
||||
}
|
||||
}
|
||||
|
||||
// Get implement interface BufferPool
|
||||
func (sf *pool) Get() []byte {
|
||||
return sf.pool.Get().([]byte)
|
||||
}
|
||||
|
||||
// Put implement interface BufferPool
|
||||
func (sf *pool) Put(b []byte) {
|
||||
if cap(b) != sf.size {
|
||||
panic("invalid buffer size that's put into leaky buffer")
|
||||
|
|
|
@ -8,7 +8,7 @@ import (
|
|||
)
|
||||
|
||||
func TestPool(t *testing.T) {
|
||||
p := newPool(2048)
|
||||
p := NewPool(2048)
|
||||
b := p.Get()
|
||||
bs := b[0:cap(b)]
|
||||
require.Equal(t, cap(b), len(bs))
|
||||
|
@ -21,7 +21,7 @@ func TestPool(t *testing.T) {
|
|||
}
|
||||
|
||||
func BenchmarkSyncPool(b *testing.B) {
|
||||
p := newPool(32 * 1024)
|
||||
p := NewPool(32 * 1024)
|
||||
wg := new(sync.WaitGroup)
|
||||
|
||||
b.ResetTimer()
|
||||
|
|
|
@ -48,7 +48,7 @@ type Server struct {
|
|||
// Optional function for dialing out
|
||||
dial func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
// buffer pool
|
||||
bufferPool *pool
|
||||
bufferPool BufferPool
|
||||
// goroutine pool
|
||||
gPool GPool
|
||||
// user's handle
|
||||
|
@ -62,7 +62,7 @@ func NewServer(opts ...Option) *Server {
|
|||
server := &Server{
|
||||
authMethods: make(map[uint8]Authenticator),
|
||||
authCustomMethods: []Authenticator{&NoAuthAuthenticator{}},
|
||||
bufferPool: newPool(2 * 1024),
|
||||
bufferPool: NewPool(32 * 1024),
|
||||
resolver: DNSResolver{},
|
||||
rules: NewPermitAll(),
|
||||
logger: NewLogger(log.New(ioutil.Discard, "socks5: ", log.LstdFlags)),
|
||||
|
|
|
@ -6,13 +6,13 @@ import (
|
|||
"net"
|
||||
)
|
||||
|
||||
// Datagram udp packet
|
||||
// The SOCKS UDP request/response is formed as follows:
|
||||
// +-----+------+-------+----------+----------+----------+
|
||||
// | RSV | FRAG | ATYP | DST.ADDR | DST.PORT | DATA |
|
||||
// +-----+------+-------+----------+----------+----------+
|
||||
// | 2 | 1 | X'00' | Variable | 2 | Variable |
|
||||
// +-----+------+-------+----------+----------+----------+
|
||||
// Datagram udp packet
|
||||
type Datagram struct {
|
||||
RSV uint16
|
||||
Frag byte
|
||||
|
@ -34,7 +34,7 @@ func NewDatagram(destAddr string, data []byte) (p Datagram, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
// ParseRequest parse to datagram
|
||||
// ParseDatagram parse to datagram from bytes
|
||||
func ParseDatagram(b []byte) (da Datagram, err error) {
|
||||
if len(b) < 4+net.IPv4len+2 { // no enough data
|
||||
err = errors.New("datagram to short")
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"errors"
|
||||
)
|
||||
|
||||
// error defined
|
||||
var (
|
||||
ErrUnrecognizedAddrType = errors.New("Unrecognized address type")
|
||||
ErrNotSupportVersion = errors.New("not support version")
|
||||
|
|
|
@ -146,7 +146,7 @@ func (h Reply) Bytes() (b []byte) {
|
|||
return b
|
||||
}
|
||||
|
||||
// ParseRequest to request from io.Reader
|
||||
// ParseReply parse to reply from io.Reader
|
||||
func ParseReply(r io.Reader) (rep Reply, err error) {
|
||||
// Read the version and command
|
||||
tmp := []byte{0, 0}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
package statute
|
||||
|
||||
// socks protocol version
|
||||
// VersionSocks5 socks protocol version
|
||||
const VersionSocks5 = byte(0x05)
|
||||
|
||||
// request command defined
|
||||
|
|
|
@ -16,7 +16,7 @@ type AddrSpec struct {
|
|||
AddrType byte
|
||||
}
|
||||
|
||||
// DstAddress returns a string suitable to dial; prefer returning IP-based
|
||||
// String returns a string suitable to dial; prefer returning IP-based
|
||||
// address, fallback to FQDN
|
||||
func (a *AddrSpec) String() string {
|
||||
if 0 != len(a.IP) {
|
||||
|
@ -25,7 +25,7 @@ func (a *AddrSpec) String() string {
|
|||
return net.JoinHostPort(a.FQDN, strconv.Itoa(a.Port))
|
||||
}
|
||||
|
||||
// DstAddress returns a string which may be specified
|
||||
// Address returns a string which may be specified
|
||||
// if IPv4/IPv6 will return < ip:port >
|
||||
// if FQDN will return < domain ip:port >
|
||||
// Note: do not used to dial, Please use String
|
||||
|
|
Loading…
Reference in New Issue