use buffer pool interface
This commit is contained in:
parent
d83167e33c
commit
9935a8baca
3
.gitignore
vendored
3
.gitignore
vendored
@ -6,6 +6,7 @@
|
|||||||
# Folders
|
# Folders
|
||||||
_obj
|
_obj
|
||||||
_test
|
_test
|
||||||
|
.idea
|
||||||
|
|
||||||
# Architecture specific extensions/prefixes
|
# Architecture specific extensions/prefixes
|
||||||
*.[568vq]
|
*.[568vq]
|
||||||
@ -20,4 +21,4 @@ _cgo_export.*
|
|||||||
_testmain.go
|
_testmain.go
|
||||||
|
|
||||||
*.exe
|
*.exe
|
||||||
.idea
|
|
||||||
|
@ -1,14 +1,10 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io"
|
|
||||||
"log"
|
"log"
|
||||||
"net"
|
|
||||||
"os"
|
"os"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/thinkgos/go-socks5"
|
"github.com/thinkgos/go-socks5"
|
||||||
"github.com/thinkgos/go-socks5/client"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func handleErr(err error) {
|
func handleErr(err error) {
|
||||||
@ -18,41 +14,41 @@ func handleErr(err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
// Create a local listener
|
// // Create a local listener
|
||||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
// l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
handleErr(err)
|
// handleErr(err)
|
||||||
|
//
|
||||||
go func() {
|
// go func() {
|
||||||
conn, err := l.Accept()
|
// conn, err := l.Accept()
|
||||||
handleErr(err)
|
// handleErr(err)
|
||||||
defer conn.Close()
|
// defer conn.Close()
|
||||||
|
//
|
||||||
buf := make([]byte, 4)
|
// buf := make([]byte, 4)
|
||||||
_, err = io.ReadAtLeast(conn, buf, 4)
|
// _, err = io.ReadAtLeast(conn, buf, 4)
|
||||||
handleErr(err)
|
// handleErr(err)
|
||||||
log.Printf("server: %+v", string(buf))
|
// log.Printf("server: %+v", string(buf))
|
||||||
conn.Write([]byte("pong"))
|
// conn.Write([]byte("pong"))
|
||||||
}()
|
// }()
|
||||||
lAddr := l.Addr().(*net.TCPAddr)
|
// lAddr := l.Addr().(*net.TCPAddr)
|
||||||
|
//
|
||||||
go func() {
|
// go func() {
|
||||||
time.Sleep(time.Second)
|
// time.Sleep(time.Second)
|
||||||
c, err := client.NewClient("127.0.0.1:1080")
|
// c, err := client.NewClient("127.0.0.1:1080")
|
||||||
handleErr(err)
|
// handleErr(err)
|
||||||
con, err := c.Dial("tcp", lAddr.String())
|
// con, err := c.Dial("tcp", lAddr.String())
|
||||||
handleErr(err)
|
// handleErr(err)
|
||||||
con.Write([]byte("ping"))
|
// con.Write([]byte("ping"))
|
||||||
out := make([]byte, 4)
|
// out := make([]byte, 4)
|
||||||
_ = con.SetDeadline(time.Now().Add(time.Second)) // nolint: errcheck
|
// _ = con.SetDeadline(time.Now().Add(time.Second)) // nolint: errcheck
|
||||||
_, err = io.ReadFull(con, out)
|
// _, err = io.ReadFull(con, out)
|
||||||
log.Printf("client: %+v", string(out))
|
// log.Printf("client: %+v", string(out))
|
||||||
}()
|
// }()
|
||||||
|
|
||||||
// Create a SOCKS5 server
|
// Create a SOCKS5 server
|
||||||
server := socks5.NewServer(socks5.WithLogger(socks5.NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags))))
|
server := socks5.NewServer(socks5.WithLogger(socks5.NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags))))
|
||||||
|
|
||||||
// Create SOCKS5 proxy on localhost port 8000
|
// Create SOCKS5 proxy on localhost port 8000
|
||||||
if err := server.ListenAndServe("tcp", "127.0.0.1:1080"); err != nil {
|
if err := server.ListenAndServe("tcp", ":10800"); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -150,7 +150,7 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, request *R
|
|||||||
}
|
}
|
||||||
|
|
||||||
// handleBind is used to handle a connect command
|
// handleBind is used to handle a connect command
|
||||||
func (s *Server) handleBind(_ context.Context, writer io.Writer, request *Request) error {
|
func (s *Server) handleBind(_ context.Context, writer io.Writer, _ *Request) error {
|
||||||
// TODO: Support bind
|
// TODO: Support bind
|
||||||
if err := SendReply(writer, statute.RepCommandNotSupported, nil); err != nil {
|
if err := SendReply(writer, statute.RepCommandNotSupported, nil); err != nil {
|
||||||
return fmt.Errorf("failed to send reply: %v", err)
|
return fmt.Errorf("failed to send reply: %v", err)
|
||||||
|
@ -49,7 +49,7 @@ func TestRequest_Connect(t *testing.T) {
|
|||||||
rules: NewPermitAll(),
|
rules: NewPermitAll(),
|
||||||
resolver: DNSResolver{},
|
resolver: DNSResolver{},
|
||||||
logger: NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags)),
|
logger: NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags)),
|
||||||
bufferPool: newPool(32 * 1024),
|
bufferPool: NewPool(32 * 1024),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the connect request
|
// Create the connect request
|
||||||
@ -106,7 +106,7 @@ func TestRequest_Connect_RuleFail(t *testing.T) {
|
|||||||
rules: NewPermitNone(),
|
rules: NewPermitNone(),
|
||||||
resolver: DNSResolver{},
|
resolver: DNSResolver{},
|
||||||
logger: NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags)),
|
logger: NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags)),
|
||||||
bufferPool: newPool(32 * 1024),
|
bufferPool: NewPool(32 * 1024),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the connect request
|
// Create the connect request
|
||||||
|
36
option.go
36
option.go
@ -9,6 +9,14 @@ import (
|
|||||||
// Option user's option
|
// Option user's option
|
||||||
type Option func(s *Server)
|
type Option func(s *Server)
|
||||||
|
|
||||||
|
// WithBufferPool can be provided to implement custom buffer pool
|
||||||
|
// By default, buffer pool use size is 32k
|
||||||
|
func WithBufferPool(bufferPool BufferPool) Option {
|
||||||
|
return func(s *Server) {
|
||||||
|
s.bufferPool = bufferPool
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WithAuthMethods can be provided to implement custom authentication
|
// WithAuthMethods can be provided to implement custom authentication
|
||||||
// By default, "auth-less" mode is enabled.
|
// By default, "auth-less" mode is enabled.
|
||||||
// For password-based auth use UserPassAuthenticator.
|
// For password-based auth use UserPassAuthenticator.
|
||||||
@ -26,9 +34,7 @@ func WithAuthMethods(authMethods []Authenticator) Option {
|
|||||||
// and AUthMethods is nil, then "auth-less" mode is enabled.
|
// and AUthMethods is nil, then "auth-less" mode is enabled.
|
||||||
func WithCredential(cs CredentialStore) Option {
|
func WithCredential(cs CredentialStore) Option {
|
||||||
return func(s *Server) {
|
return func(s *Server) {
|
||||||
if cs != nil {
|
s.credentials = cs
|
||||||
s.credentials = cs
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -36,9 +42,7 @@ func WithCredential(cs CredentialStore) Option {
|
|||||||
// Defaults to DNSResolver if not provided.
|
// Defaults to DNSResolver if not provided.
|
||||||
func WithResolver(res NameResolver) Option {
|
func WithResolver(res NameResolver) Option {
|
||||||
return func(s *Server) {
|
return func(s *Server) {
|
||||||
if res != nil {
|
s.resolver = res
|
||||||
s.resolver = res
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -46,9 +50,7 @@ func WithResolver(res NameResolver) Option {
|
|||||||
// various commands. If not provided, NewPermitAll is used.
|
// various commands. If not provided, NewPermitAll is used.
|
||||||
func WithRule(rule RuleSet) Option {
|
func WithRule(rule RuleSet) Option {
|
||||||
return func(s *Server) {
|
return func(s *Server) {
|
||||||
if rule != nil {
|
s.rules = rule
|
||||||
s.rules = rule
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -57,9 +59,7 @@ func WithRule(rule RuleSet) Option {
|
|||||||
// Defaults to NoRewrite.
|
// Defaults to NoRewrite.
|
||||||
func WithRewriter(rew AddressRewriter) Option {
|
func WithRewriter(rew AddressRewriter) Option {
|
||||||
return func(s *Server) {
|
return func(s *Server) {
|
||||||
if rew != nil {
|
s.rewriter = rew
|
||||||
s.rewriter = rew
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -77,27 +77,21 @@ func WithBindIP(ip net.IP) Option {
|
|||||||
// Defaults to ioutil.Discard.
|
// Defaults to ioutil.Discard.
|
||||||
func WithLogger(l Logger) Option {
|
func WithLogger(l Logger) Option {
|
||||||
return func(s *Server) {
|
return func(s *Server) {
|
||||||
if l != nil {
|
s.logger = l
|
||||||
s.logger = l
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithDial Optional function for dialing out
|
// WithDial Optional function for dialing out
|
||||||
func WithDial(dial func(ctx context.Context, network, addr string) (net.Conn, error)) Option {
|
func WithDial(dial func(ctx context.Context, network, addr string) (net.Conn, error)) Option {
|
||||||
return func(s *Server) {
|
return func(s *Server) {
|
||||||
if dial != nil {
|
s.dial = dial
|
||||||
s.dial = dial
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithGPool can be provided to do custom goroutine pool.
|
// WithGPool can be provided to do custom goroutine pool.
|
||||||
func WithGPool(pool GPool) Option {
|
func WithGPool(pool GPool) Option {
|
||||||
return func(s *Server) {
|
return func(s *Server) {
|
||||||
if pool != nil {
|
s.gPool = pool
|
||||||
s.gPool = pool
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
14
pool.go
14
pool.go
@ -4,22 +4,34 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// A BufferPool is an interface for getting and returning temporary
|
||||||
|
// byte slices for use by io.CopyBuffer.
|
||||||
|
type BufferPool interface {
|
||||||
|
Get() []byte
|
||||||
|
Put([]byte)
|
||||||
|
}
|
||||||
|
|
||||||
type pool struct {
|
type pool struct {
|
||||||
size int
|
size int
|
||||||
pool *sync.Pool
|
pool *sync.Pool
|
||||||
}
|
}
|
||||||
|
|
||||||
func newPool(size int) *pool {
|
// NewPool new buffer pool for getting and returning temporary
|
||||||
|
// byte slices for use by io.CopyBuffer.
|
||||||
|
func NewPool(size int) BufferPool {
|
||||||
return &pool{
|
return &pool{
|
||||||
size,
|
size,
|
||||||
&sync.Pool{
|
&sync.Pool{
|
||||||
New: func() interface{} { return make([]byte, 0, size) }},
|
New: func() interface{} { return make([]byte, 0, size) }},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get implement interface BufferPool
|
||||||
func (sf *pool) Get() []byte {
|
func (sf *pool) Get() []byte {
|
||||||
return sf.pool.Get().([]byte)
|
return sf.pool.Get().([]byte)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Put implement interface BufferPool
|
||||||
func (sf *pool) Put(b []byte) {
|
func (sf *pool) Put(b []byte) {
|
||||||
if cap(b) != sf.size {
|
if cap(b) != sf.size {
|
||||||
panic("invalid buffer size that's put into leaky buffer")
|
panic("invalid buffer size that's put into leaky buffer")
|
||||||
|
@ -8,7 +8,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestPool(t *testing.T) {
|
func TestPool(t *testing.T) {
|
||||||
p := newPool(2048)
|
p := NewPool(2048)
|
||||||
b := p.Get()
|
b := p.Get()
|
||||||
bs := b[0:cap(b)]
|
bs := b[0:cap(b)]
|
||||||
require.Equal(t, cap(b), len(bs))
|
require.Equal(t, cap(b), len(bs))
|
||||||
@ -21,7 +21,7 @@ func TestPool(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkSyncPool(b *testing.B) {
|
func BenchmarkSyncPool(b *testing.B) {
|
||||||
p := newPool(32 * 1024)
|
p := NewPool(32 * 1024)
|
||||||
wg := new(sync.WaitGroup)
|
wg := new(sync.WaitGroup)
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
|
@ -48,7 +48,7 @@ type Server struct {
|
|||||||
// Optional function for dialing out
|
// Optional function for dialing out
|
||||||
dial func(ctx context.Context, network, addr string) (net.Conn, error)
|
dial func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||||
// buffer pool
|
// buffer pool
|
||||||
bufferPool *pool
|
bufferPool BufferPool
|
||||||
// goroutine pool
|
// goroutine pool
|
||||||
gPool GPool
|
gPool GPool
|
||||||
// user's handle
|
// user's handle
|
||||||
@ -62,7 +62,7 @@ func NewServer(opts ...Option) *Server {
|
|||||||
server := &Server{
|
server := &Server{
|
||||||
authMethods: make(map[uint8]Authenticator),
|
authMethods: make(map[uint8]Authenticator),
|
||||||
authCustomMethods: []Authenticator{&NoAuthAuthenticator{}},
|
authCustomMethods: []Authenticator{&NoAuthAuthenticator{}},
|
||||||
bufferPool: newPool(2 * 1024),
|
bufferPool: NewPool(32 * 1024),
|
||||||
resolver: DNSResolver{},
|
resolver: DNSResolver{},
|
||||||
rules: NewPermitAll(),
|
rules: NewPermitAll(),
|
||||||
logger: NewLogger(log.New(ioutil.Discard, "socks5: ", log.LstdFlags)),
|
logger: NewLogger(log.New(ioutil.Discard, "socks5: ", log.LstdFlags)),
|
||||||
|
@ -6,13 +6,13 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Datagram udp packet
|
||||||
// The SOCKS UDP request/response is formed as follows:
|
// The SOCKS UDP request/response is formed as follows:
|
||||||
// +-----+------+-------+----------+----------+----------+
|
// +-----+------+-------+----------+----------+----------+
|
||||||
// | RSV | FRAG | ATYP | DST.ADDR | DST.PORT | DATA |
|
// | RSV | FRAG | ATYP | DST.ADDR | DST.PORT | DATA |
|
||||||
// +-----+------+-------+----------+----------+----------+
|
// +-----+------+-------+----------+----------+----------+
|
||||||
// | 2 | 1 | X'00' | Variable | 2 | Variable |
|
// | 2 | 1 | X'00' | Variable | 2 | Variable |
|
||||||
// +-----+------+-------+----------+----------+----------+
|
// +-----+------+-------+----------+----------+----------+
|
||||||
// Datagram udp packet
|
|
||||||
type Datagram struct {
|
type Datagram struct {
|
||||||
RSV uint16
|
RSV uint16
|
||||||
Frag byte
|
Frag byte
|
||||||
@ -34,7 +34,7 @@ func NewDatagram(destAddr string, data []byte) (p Datagram, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseRequest parse to datagram
|
// ParseDatagram parse to datagram from bytes
|
||||||
func ParseDatagram(b []byte) (da Datagram, err error) {
|
func ParseDatagram(b []byte) (da Datagram, err error) {
|
||||||
if len(b) < 4+net.IPv4len+2 { // no enough data
|
if len(b) < 4+net.IPv4len+2 { // no enough data
|
||||||
err = errors.New("datagram to short")
|
err = errors.New("datagram to short")
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// error defined
|
||||||
var (
|
var (
|
||||||
ErrUnrecognizedAddrType = errors.New("Unrecognized address type")
|
ErrUnrecognizedAddrType = errors.New("Unrecognized address type")
|
||||||
ErrNotSupportVersion = errors.New("not support version")
|
ErrNotSupportVersion = errors.New("not support version")
|
||||||
|
@ -146,7 +146,7 @@ func (h Reply) Bytes() (b []byte) {
|
|||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseRequest to request from io.Reader
|
// ParseReply parse to reply from io.Reader
|
||||||
func ParseReply(r io.Reader) (rep Reply, err error) {
|
func ParseReply(r io.Reader) (rep Reply, err error) {
|
||||||
// Read the version and command
|
// Read the version and command
|
||||||
tmp := []byte{0, 0}
|
tmp := []byte{0, 0}
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
package statute
|
package statute
|
||||||
|
|
||||||
// socks protocol version
|
// VersionSocks5 socks protocol version
|
||||||
const VersionSocks5 = byte(0x05)
|
const VersionSocks5 = byte(0x05)
|
||||||
|
|
||||||
// request command defined
|
// request command defined
|
||||||
|
@ -16,7 +16,7 @@ type AddrSpec struct {
|
|||||||
AddrType byte
|
AddrType byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// DstAddress returns a string suitable to dial; prefer returning IP-based
|
// String returns a string suitable to dial; prefer returning IP-based
|
||||||
// address, fallback to FQDN
|
// address, fallback to FQDN
|
||||||
func (a *AddrSpec) String() string {
|
func (a *AddrSpec) String() string {
|
||||||
if 0 != len(a.IP) {
|
if 0 != len(a.IP) {
|
||||||
@ -25,7 +25,7 @@ func (a *AddrSpec) String() string {
|
|||||||
return net.JoinHostPort(a.FQDN, strconv.Itoa(a.Port))
|
return net.JoinHostPort(a.FQDN, strconv.Itoa(a.Port))
|
||||||
}
|
}
|
||||||
|
|
||||||
// DstAddress returns a string which may be specified
|
// Address returns a string which may be specified
|
||||||
// if IPv4/IPv6 will return < ip:port >
|
// if IPv4/IPv6 will return < ip:port >
|
||||||
// if FQDN will return < domain ip:port >
|
// if FQDN will return < domain ip:port >
|
||||||
// Note: do not used to dial, Please use String
|
// Note: do not used to dial, Please use String
|
||||||
|
Loading…
Reference in New Issue
Block a user