use options

This commit is contained in:
mo 2020-04-20 21:17:38 +08:00
parent 5518983a28
commit 1daccf05c2
7 changed files with 237 additions and 134 deletions

48
_example/udp/udp.go Normal file

@ -0,0 +1,48 @@
package main
import (
"bytes"
"log"
"net"
)
func main() {
lAddr := &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 12398,
}
l, err := net.ListenUDP("udp4", lAddr)
if err != nil {
panic(err)
}
defer l.Close()
go func() {
for {
buf := make([]byte, 2048)
n, remote, err := l.ReadFrom(buf)
if err != nil {
return
}
if !bytes.Equal(buf[:n], []byte("ping")) {
log.Println("bad: %v", buf)
}
l.WriteTo([]byte("pong"), remote)
}
}()
cli, err := net.DialUDP("udp", nil, lAddr)
if err != nil {
panic(err)
}
n, err := cli.Write([]byte("ping"))
if err != nil {
panic(err)
}
rsp := make([]byte, 1024)
n, _, err = cli.ReadFrom(rsp)
if err != nil {
panic(err)
}
log.Printf("%s", string(rsp[:n]))
}

@ -10,7 +10,7 @@ func TestNoAuth(t *testing.T) {
req.Write([]byte{1, NoAuth})
var resp bytes.Buffer
s, _ := New(&Config{})
s := New()
ctx, err := s.authenticate(&resp, req)
if err != nil {
t.Fatalf("err: %v", err)
@ -38,7 +38,7 @@ func TestPasswordAuth_Valid(t *testing.T) {
cator := UserPassAuthenticator{Credentials: cred}
s, _ := New(&Config{AuthMethods: []Authenticator{cator}})
s := New(WithAuthMethods([]Authenticator{cator}))
ctx, err := s.authenticate(&resp, req)
if err != nil {
@ -74,7 +74,7 @@ func TestPasswordAuth_Invalid(t *testing.T) {
"foo": "bar",
}
cator := UserPassAuthenticator{Credentials: cred}
s, _ := New(&Config{AuthMethods: []Authenticator{cator}})
s := New(WithAuthMethods([]Authenticator{cator}))
ctx, err := s.authenticate(&resp, req)
if err != UserAuthFailed {
@ -101,7 +101,7 @@ func TestNoSupportedAuth(t *testing.T) {
}
cator := UserPassAuthenticator{Credentials: cred}
s, _ := New(&Config{AuthMethods: []Authenticator{cator}})
s := New(WithAuthMethods([]Authenticator{cator}))
ctx, err := s.authenticate(&resp, req)
if err != NoSupportedAuth {

98
option.go Normal file

@ -0,0 +1,98 @@
package socks5
import (
"context"
"net"
)
type Option func(s *Server)
// AuthMethods can be provided to implement custom authentication
// By default, "auth-less" mode is enabled.
// For password-based auth use UserPassAuthenticator.
func WithAuthMethods(authMethods []Authenticator) Option {
return func(s *Server) {
if len(authMethods) != 0 {
s.authCustomMethods = make([]Authenticator, 0, len(authMethods))
s.authCustomMethods = append(s.authCustomMethods, authMethods...)
}
}
}
// If provided, username/password authentication is enabled,
// by appending a UserPassAuthenticator to AuthMethods. If not provided,
// and AUthMethods is nil, then "auth-less" mode is enabled.
func WithCredential(cs CredentialStore) Option {
return func(s *Server) {
if cs != nil {
s.credentials = cs
}
}
}
// resolver can be provided to do custom name resolution.
// Defaults to DNSResolver if not provided.
func WithResolver(res NameResolver) Option {
return func(s *Server) {
if res != nil {
s.resolver = res
}
}
}
// rules is provided to enable custom logic around permitting
// various commands. If not provided, PermitAll is used.
func WithRule(rule RuleSet) Option {
return func(s *Server) {
if rule != nil {
s.rules = rule
}
}
}
// rewriter can be used to transparently rewrite addresses.
// This is invoked before the RuleSet is invoked.
// Defaults to NoRewrite.
func WithRewriter(rew AddressRewriter) Option {
return func(s *Server) {
if rew != nil {
s.rewriter = rew
}
}
}
// bindIP is used for bind or udp associate
func WithBindIP(ip net.IP) Option {
return func(s *Server) {
if ip != nil {
}
}
}
// logger can be used to provide a custom log target.
// Defaults to ioutil.Discard.
func WithLogger(l Logger) Option {
return func(s *Server) {
if l != nil {
s.logger = l
}
}
}
// Optional function for dialing out
func WithDial(dial func(ctx context.Context, network, addr string) (net.Conn, error)) Option {
return func(s *Server) {
if dial != nil {
s.dial = dial
}
}
}
func WithGPool(pool GPool) Option {
return func(s *Server) {
if pool != nil {
s.gPool = pool
}
}
}

@ -68,7 +68,7 @@ func (s *Server) handleRequest(write io.Writer, req *Request) error {
// Resolve the address if we have a FQDN
dest := req.DestAddr
if dest.FQDN != "" {
ctx_, addr, err := s.config.Resolver.Resolve(ctx, dest.FQDN)
ctx_, addr, err := s.resolver.Resolve(ctx, dest.FQDN)
if err != nil {
if err := sendReply(write, req.Header, hostUnreachable); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
@ -81,8 +81,8 @@ func (s *Server) handleRequest(write io.Writer, req *Request) error {
// Apply any address rewrites
req.realDestAddr = req.DestAddr
if s.config.Rewriter != nil {
ctx, req.realDestAddr = s.config.Rewriter.Rewrite(ctx, req)
if s.rewriter != nil {
ctx, req.realDestAddr = s.rewriter.Rewrite(ctx, req)
}
// Switch on the command
@ -104,7 +104,7 @@ func (s *Server) handleRequest(write io.Writer, req *Request) error {
// handleConnect is used to handle a connect command
func (s *Server) handleConnect(ctx context.Context, writer io.Writer, req *Request) error {
// Check if this is allowed
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
if ctx_, ok := s.rules.Allow(ctx, req); !ok {
if err := sendReply(writer, req.Header, ruleFailure); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
@ -114,7 +114,7 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, req *Reque
}
// Attempt to connect
dial := s.config.Dial
dial := s.dial
if dial == nil {
dial = func(ctx context.Context, net_, addr string) (net.Conn, error) {
return net.Dial(net_, addr)
@ -161,7 +161,7 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, req *Reque
// handleBind is used to handle a connect command
func (s *Server) handleBind(ctx context.Context, writer io.Writer, req *Request) error {
// Check if this is allowed
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
if ctx_, ok := s.rules.Allow(ctx, req); !ok {
if err := sendReply(writer, req.Header, ruleFailure); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
@ -180,7 +180,7 @@ func (s *Server) handleBind(ctx context.Context, writer io.Writer, req *Request)
// handleAssociate is used to handle a connect command
func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Request) error {
// Check if this is allowed
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
if ctx_, ok := s.rules.Allow(ctx, req); !ok {
if err := sendReply(writer, req.Header, ruleFailure); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
@ -190,7 +190,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
}
// Attempt to connect
dial := s.config.Dial
dial := s.dial
if dial == nil {
dial = func(ctx context.Context, net_, addr string) (net.Conn, error) {
return net.Dial(net_, addr)
@ -229,7 +229,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
}
defer bindLn.Close()
s.config.Logger.Errorf("target addr %v, listen addr: %s", targetUdp.RemoteAddr(), bindLn.LocalAddr())
s.logger.Errorf("target addr %v, listen addr: %s", targetUdp.RemoteAddr(), bindLn.LocalAddr())
// send BND.ADDR and BND.PORT, client must
if err = sendReply(writer, req.Header, successReply, bindLn.LocalAddr()); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
@ -256,10 +256,10 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
buf := bufPool[:cap(bufPool)]
n, srcAddr, err := bindLn.ReadFrom(buf)
if err != nil {
s.config.Logger.Errorf("read data from bind listen address %s failed, %v", bindLn.LocalAddr(), err)
s.logger.Errorf("read data from bind listen address %s failed, %v", bindLn.LocalAddr(), err)
return
}
s.config.Logger.Errorf("data length: %d,%d", n, len(buf))
if n <= 4+net.IPv4len+2 { // no data
continue
}
@ -299,7 +299,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
// 把消息写给remote sever
if _, err := targetUdp.Write(buf[headLen:n]); err != nil {
s.config.Logger.Errorf("write data to remote %s failed, %v", targetUdp.RemoteAddr(), err)
s.logger.Errorf("write data to remote %s failed, %v", targetUdp.RemoteAddr(), err)
return
}
@ -317,7 +317,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
buf := bufPool[:cap(bufPool)]
n, remote, err := targetUdp.ReadFrom(buf)
if err != nil {
s.config.Logger.Errorf("read data from remote %s failed, %v", targetUdp.RemoteAddr(), err)
s.logger.Errorf("read data from remote %s failed, %v", targetUdp.RemoteAddr(), err)
return
}
@ -339,7 +339,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
proBuf = append(proBuf, buf[:n]...)
if _, err := bindLn.WriteTo(proBuf, srcAddr); err != nil {
s.bufferPool.Put(tmpBufPool)
s.config.Logger.Errorf("write data to client %s failed, %v", bindLn.LocalAddr(), err)
s.logger.Errorf("write data to client %s failed, %v", bindLn.LocalAddr(), err)
return
}
s.bufferPool.Put(tmpBufPool)

@ -50,11 +50,9 @@ func TestRequest_Connect(t *testing.T) {
// Make server
s := &Server{
config: &Config{
Rules: PermitAll(),
Resolver: DNSResolver{},
Logger: NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags)),
},
rules: PermitAll(),
resolver: DNSResolver{},
logger: NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags)),
bufferPool: newPool(32 * 1024),
}
@ -128,11 +126,9 @@ func TestRequest_Connect_RuleFail(t *testing.T) {
// Make server
s := &Server{
config: &Config{
Rules: PermitNone(),
Resolver: DNSResolver{},
Logger: NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags)),
},
rules: PermitNone(),
resolver: DNSResolver{},
logger: NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags)),
bufferPool: newPool(32 * 1024),
}

136
socks5.go

@ -14,95 +14,69 @@ type GPool interface {
Submit(f func()) error
}
// Config is used to setup and configure a Server
type Config struct {
// AuthMethods can be provided to implement custom authentication
// By default, "auth-less" mode is enabled.
// For password-based auth use UserPassAuthenticator.
AuthMethods []Authenticator
// If provided, username/password authentication is enabled,
// by appending a UserPassAuthenticator to AuthMethods. If not provided,
// and AUthMethods is nil, then "auth-less" mode is enabled.
Credentials CredentialStore
// Resolver can be provided to do custom name resolution.
// Defaults to DNSResolver if not provided.
Resolver NameResolver
// Rules is provided to enable custom logic around permitting
// various commands. If not provided, PermitAll is used.
Rules RuleSet
// Rewriter can be used to transparently rewrite addresses.
// This is invoked before the RuleSet is invoked.
// Defaults to NoRewrite.
Rewriter AddressRewriter
// BindIP is used for bind or udp associate
BindIP net.IP
// Logger can be used to provide a custom log target.
// Defaults to ioutil.Discard.
Logger Logger
// Optional function for dialing out
Dial func(ctx context.Context, network, addr string) (net.Conn, error)
}
// Server is reponsible for accepting connections and handling
// the details of the SOCKS5 protocol
type Server struct {
config *Config
authMethods map[uint8]Authenticator
bufferPool *pool
gPool GPool
// AuthMethods can be provided to implement custom authentication
// By default, "auth-less" mode is enabled.
// For password-based auth use UserPassAuthenticator.
authCustomMethods []Authenticator
// If provided, username/password authentication is enabled,
// by appending a UserPassAuthenticator to AuthMethods. If not provided,
// and AUthMethods is nil, then "auth-less" mode is enabled.
credentials CredentialStore
// resolver can be provided to do custom name resolution.
// Defaults to DNSResolver if not provided.
resolver NameResolver
// rules is provided to enable custom logic around permitting
// various commands. If not provided, PermitAll is used.
rules RuleSet
// rewriter can be used to transparently rewrite addresses.
// This is invoked before the RuleSet is invoked.
// Defaults to NoRewrite.
rewriter AddressRewriter
// bindIP is used for bind or udp associate
bindIP net.IP
// logger can be used to provide a custom log target.
// Defaults to ioutil.Discard.
logger Logger
// Optional function for dialing out
dial func(ctx context.Context, network, addr string) (net.Conn, error)
// buffer pool
bufferPool *pool
// goroutine pool
gPool GPool
}
// New creates a new Server and potentially returns an error
func New(conf *Config) (*Server, error) {
// Ensure we have at least one authentication method enabled
if len(conf.AuthMethods) == 0 {
if conf.Credentials != nil {
conf.AuthMethods = []Authenticator{&UserPassAuthenticator{conf.Credentials}}
} else {
conf.AuthMethods = []Authenticator{&NoAuthAuthenticator{}}
}
}
// Ensure we have a DNS resolver
if conf.Resolver == nil {
conf.Resolver = DNSResolver{}
}
// Ensure we have a rule set
if conf.Rules == nil {
conf.Rules = PermitAll()
}
// Ensure we have a log target
if conf.Logger == nil {
conf.Logger = NewLogger(log.New(ioutil.Discard, "socks5: ", log.LstdFlags))
}
if conf.Dial == nil {
conf.Dial = func(ctx context.Context, net_, addr string) (net.Conn, error) {
return net.Dial(net_, addr)
}
}
func New(opts ...Option) *Server {
server := &Server{
config: conf,
bufferPool: newPool(2 * 1024),
authMethods: make(map[uint8]Authenticator),
authCustomMethods: []Authenticator{&NoAuthAuthenticator{}},
bufferPool: newPool(2 * 1024),
resolver: DNSResolver{},
rules: PermitAll(),
logger: NewLogger(log.New(ioutil.Discard, "socks5: ", log.LstdFlags)),
dial: func(ctx context.Context, net_, addr string) (net.Conn, error) {
return net.Dial(net_, addr)
},
}
server.authMethods = make(map[uint8]Authenticator)
for _, a := range conf.AuthMethods {
server.authMethods[a.GetCode()] = a
for _, opt := range opts {
opt(server)
}
return server, nil
// Ensure we have at least one authentication method enabled
if len(server.authCustomMethods) == 0 && server.credentials != nil {
server.authCustomMethods = []Authenticator{&UserPassAuthenticator{server.credentials}}
}
for _, v := range server.authCustomMethods {
server.authMethods[v.GetCode()] = v
}
return server
}
// ListenAndServe is used to create a listener and serve on it
@ -135,7 +109,7 @@ func (s *Server) ServeConn(conn net.Conn) (err error) {
// Read the version byte
version := []byte{0}
if _, err = bufConn.Read(version); err != nil {
s.config.Logger.Errorf("failed to get version byte: %v", err)
s.logger.Errorf("failed to get version byte: %v", err)
return err
}
@ -146,12 +120,12 @@ func (s *Server) ServeConn(conn net.Conn) (err error) {
authContext, err = s.authenticate(conn, bufConn)
if err != nil {
err = fmt.Errorf("failed to authenticate: %v", err)
s.config.Logger.Errorf("%v", err)
s.logger.Errorf("%v", err)
return err
}
} else if version[0] != socks4Version {
err := fmt.Errorf("unsupported SOCKS version: %v", version[0])
s.config.Logger.Errorf("%v", err)
s.logger.Errorf("%v", err)
return err
}
@ -173,7 +147,7 @@ func (s *Server) ServeConn(conn net.Conn) (err error) {
// Process the client request
if err := s.handleRequest(conn, request); err != nil {
err = fmt.Errorf("failed to handle request, %v", err)
s.config.Logger.Errorf("%v", err)
s.logger.Errorf("%v", err)
return err
}

@ -41,14 +41,10 @@ func TestSOCKS5_Connect(t *testing.T) {
cator := UserPassAuthenticator{
Credentials: StaticCredentials{"foo": "bar"},
}
conf := &Config{
AuthMethods: []Authenticator{cator},
Logger: NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags)),
}
serv, err := New(conf)
if err != nil {
t.Fatalf("err: %v", err)
}
serv := New(
WithAuthMethods([]Authenticator{cator}),
WithLogger(NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags))),
)
// Start listening
go func() {
@ -150,15 +146,10 @@ func TestSOCKS5_Associate(t *testing.T) {
// Create a socks server
cator := UserPassAuthenticator{Credentials: StaticCredentials{"foo": "bar"}}
conf := &Config{
AuthMethods: []Authenticator{cator},
Logger: NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags)),
}
serv, err := New(conf)
if err != nil {
t.Fatalf("err: %v", err)
}
serv := New(
WithAuthMethods([]Authenticator{cator}),
WithLogger(NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags))),
)
// Start listening
go func() {
if err := serv.ListenAndServe("tcp", "127.0.0.1:12355"); err != nil {
@ -262,14 +253,10 @@ func Test_SocksWithProxy(t *testing.T) {
// Create a socks server
cator := UserPassAuthenticator{Credentials: StaticCredentials{"foo": "bar"}}
conf := &Config{
AuthMethods: []Authenticator{cator},
Logger: NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags)),
}
serv, err := New(conf)
if err != nil {
t.Fatalf("err: %v", err)
}
serv := New(
WithAuthMethods([]Authenticator{cator}),
WithLogger(NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags))),
)
// Start listening
go func() {