From 1daccf05c2212c88b84f273b9382f430465c7926 Mon Sep 17 00:00:00 2001 From: mo Date: Mon, 20 Apr 2020 21:17:38 +0800 Subject: [PATCH] use options --- _example/udp/udp.go | 48 ++++++++++++++++ auth_test.go | 8 +-- option.go | 98 +++++++++++++++++++++++++++++++ request.go | 28 ++++----- request_test.go | 16 ++---- socks5.go | 136 ++++++++++++++++++-------------------------- socks5_test.go | 37 ++++-------- 7 files changed, 237 insertions(+), 134 deletions(-) create mode 100644 _example/udp/udp.go create mode 100644 option.go diff --git a/_example/udp/udp.go b/_example/udp/udp.go new file mode 100644 index 0000000..0423ca7 --- /dev/null +++ b/_example/udp/udp.go @@ -0,0 +1,48 @@ +package main + +import ( + "bytes" + "log" + "net" +) + +func main() { + lAddr := &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 12398, + } + l, err := net.ListenUDP("udp4", lAddr) + if err != nil { + panic(err) + } + defer l.Close() + go func() { + for { + buf := make([]byte, 2048) + n, remote, err := l.ReadFrom(buf) + if err != nil { + return + } + + if !bytes.Equal(buf[:n], []byte("ping")) { + log.Println("bad: %v", buf) + } + l.WriteTo([]byte("pong"), remote) + } + }() + + cli, err := net.DialUDP("udp", nil, lAddr) + if err != nil { + panic(err) + } + n, err := cli.Write([]byte("ping")) + if err != nil { + panic(err) + } + rsp := make([]byte, 1024) + n, _, err = cli.ReadFrom(rsp) + if err != nil { + panic(err) + } + log.Printf("%s", string(rsp[:n])) +} diff --git a/auth_test.go b/auth_test.go index 90d3d02..95b251e 100644 --- a/auth_test.go +++ b/auth_test.go @@ -10,7 +10,7 @@ func TestNoAuth(t *testing.T) { req.Write([]byte{1, NoAuth}) var resp bytes.Buffer - s, _ := New(&Config{}) + s := New() ctx, err := s.authenticate(&resp, req) if err != nil { t.Fatalf("err: %v", err) @@ -38,7 +38,7 @@ func TestPasswordAuth_Valid(t *testing.T) { cator := UserPassAuthenticator{Credentials: cred} - s, _ := New(&Config{AuthMethods: []Authenticator{cator}}) + s := New(WithAuthMethods([]Authenticator{cator})) ctx, err := s.authenticate(&resp, req) if err != nil { @@ -74,7 +74,7 @@ func TestPasswordAuth_Invalid(t *testing.T) { "foo": "bar", } cator := UserPassAuthenticator{Credentials: cred} - s, _ := New(&Config{AuthMethods: []Authenticator{cator}}) + s := New(WithAuthMethods([]Authenticator{cator})) ctx, err := s.authenticate(&resp, req) if err != UserAuthFailed { @@ -101,7 +101,7 @@ func TestNoSupportedAuth(t *testing.T) { } cator := UserPassAuthenticator{Credentials: cred} - s, _ := New(&Config{AuthMethods: []Authenticator{cator}}) + s := New(WithAuthMethods([]Authenticator{cator})) ctx, err := s.authenticate(&resp, req) if err != NoSupportedAuth { diff --git a/option.go b/option.go new file mode 100644 index 0000000..bd22191 --- /dev/null +++ b/option.go @@ -0,0 +1,98 @@ +package socks5 + +import ( + "context" + "net" +) + +type Option func(s *Server) + +// AuthMethods can be provided to implement custom authentication +// By default, "auth-less" mode is enabled. +// For password-based auth use UserPassAuthenticator. +func WithAuthMethods(authMethods []Authenticator) Option { + return func(s *Server) { + if len(authMethods) != 0 { + s.authCustomMethods = make([]Authenticator, 0, len(authMethods)) + s.authCustomMethods = append(s.authCustomMethods, authMethods...) + } + } +} + +// If provided, username/password authentication is enabled, +// by appending a UserPassAuthenticator to AuthMethods. If not provided, +// and AUthMethods is nil, then "auth-less" mode is enabled. +func WithCredential(cs CredentialStore) Option { + return func(s *Server) { + if cs != nil { + s.credentials = cs + } + } +} + +// resolver can be provided to do custom name resolution. +// Defaults to DNSResolver if not provided. +func WithResolver(res NameResolver) Option { + return func(s *Server) { + if res != nil { + s.resolver = res + } + } +} + +// rules is provided to enable custom logic around permitting +// various commands. If not provided, PermitAll is used. +func WithRule(rule RuleSet) Option { + return func(s *Server) { + if rule != nil { + s.rules = rule + } + } +} + +// rewriter can be used to transparently rewrite addresses. +// This is invoked before the RuleSet is invoked. +// Defaults to NoRewrite. +func WithRewriter(rew AddressRewriter) Option { + return func(s *Server) { + if rew != nil { + s.rewriter = rew + } + } +} + +// bindIP is used for bind or udp associate +func WithBindIP(ip net.IP) Option { + return func(s *Server) { + if ip != nil { + + } + } +} + +// logger can be used to provide a custom log target. +// Defaults to ioutil.Discard. +func WithLogger(l Logger) Option { + return func(s *Server) { + if l != nil { + s.logger = l + } + } +} + +// Optional function for dialing out +func WithDial(dial func(ctx context.Context, network, addr string) (net.Conn, error)) Option { + return func(s *Server) { + if dial != nil { + s.dial = dial + } + } +} + +func WithGPool(pool GPool) Option { + return func(s *Server) { + if pool != nil { + s.gPool = pool + } + } +} diff --git a/request.go b/request.go index 8968940..a803157 100644 --- a/request.go +++ b/request.go @@ -68,7 +68,7 @@ func (s *Server) handleRequest(write io.Writer, req *Request) error { // Resolve the address if we have a FQDN dest := req.DestAddr if dest.FQDN != "" { - ctx_, addr, err := s.config.Resolver.Resolve(ctx, dest.FQDN) + ctx_, addr, err := s.resolver.Resolve(ctx, dest.FQDN) if err != nil { if err := sendReply(write, req.Header, hostUnreachable); err != nil { return fmt.Errorf("failed to send reply, %v", err) @@ -81,8 +81,8 @@ func (s *Server) handleRequest(write io.Writer, req *Request) error { // Apply any address rewrites req.realDestAddr = req.DestAddr - if s.config.Rewriter != nil { - ctx, req.realDestAddr = s.config.Rewriter.Rewrite(ctx, req) + if s.rewriter != nil { + ctx, req.realDestAddr = s.rewriter.Rewrite(ctx, req) } // Switch on the command @@ -104,7 +104,7 @@ func (s *Server) handleRequest(write io.Writer, req *Request) error { // handleConnect is used to handle a connect command func (s *Server) handleConnect(ctx context.Context, writer io.Writer, req *Request) error { // Check if this is allowed - if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok { + if ctx_, ok := s.rules.Allow(ctx, req); !ok { if err := sendReply(writer, req.Header, ruleFailure); err != nil { return fmt.Errorf("failed to send reply, %v", err) } @@ -114,7 +114,7 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, req *Reque } // Attempt to connect - dial := s.config.Dial + dial := s.dial if dial == nil { dial = func(ctx context.Context, net_, addr string) (net.Conn, error) { return net.Dial(net_, addr) @@ -161,7 +161,7 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, req *Reque // handleBind is used to handle a connect command func (s *Server) handleBind(ctx context.Context, writer io.Writer, req *Request) error { // Check if this is allowed - if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok { + if ctx_, ok := s.rules.Allow(ctx, req); !ok { if err := sendReply(writer, req.Header, ruleFailure); err != nil { return fmt.Errorf("failed to send reply, %v", err) } @@ -180,7 +180,7 @@ func (s *Server) handleBind(ctx context.Context, writer io.Writer, req *Request) // handleAssociate is used to handle a connect command func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Request) error { // Check if this is allowed - if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok { + if ctx_, ok := s.rules.Allow(ctx, req); !ok { if err := sendReply(writer, req.Header, ruleFailure); err != nil { return fmt.Errorf("failed to send reply, %v", err) } @@ -190,7 +190,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req } // Attempt to connect - dial := s.config.Dial + dial := s.dial if dial == nil { dial = func(ctx context.Context, net_, addr string) (net.Conn, error) { return net.Dial(net_, addr) @@ -229,7 +229,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req } defer bindLn.Close() - s.config.Logger.Errorf("target addr %v, listen addr: %s", targetUdp.RemoteAddr(), bindLn.LocalAddr()) + s.logger.Errorf("target addr %v, listen addr: %s", targetUdp.RemoteAddr(), bindLn.LocalAddr()) // send BND.ADDR and BND.PORT, client must if err = sendReply(writer, req.Header, successReply, bindLn.LocalAddr()); err != nil { return fmt.Errorf("failed to send reply, %v", err) @@ -256,10 +256,10 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req buf := bufPool[:cap(bufPool)] n, srcAddr, err := bindLn.ReadFrom(buf) if err != nil { - s.config.Logger.Errorf("read data from bind listen address %s failed, %v", bindLn.LocalAddr(), err) + s.logger.Errorf("read data from bind listen address %s failed, %v", bindLn.LocalAddr(), err) return } - s.config.Logger.Errorf("data length: %d,%d", n, len(buf)) + if n <= 4+net.IPv4len+2 { // no data continue } @@ -299,7 +299,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req // 把消息写给remote sever if _, err := targetUdp.Write(buf[headLen:n]); err != nil { - s.config.Logger.Errorf("write data to remote %s failed, %v", targetUdp.RemoteAddr(), err) + s.logger.Errorf("write data to remote %s failed, %v", targetUdp.RemoteAddr(), err) return } @@ -317,7 +317,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req buf := bufPool[:cap(bufPool)] n, remote, err := targetUdp.ReadFrom(buf) if err != nil { - s.config.Logger.Errorf("read data from remote %s failed, %v", targetUdp.RemoteAddr(), err) + s.logger.Errorf("read data from remote %s failed, %v", targetUdp.RemoteAddr(), err) return } @@ -339,7 +339,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req proBuf = append(proBuf, buf[:n]...) if _, err := bindLn.WriteTo(proBuf, srcAddr); err != nil { s.bufferPool.Put(tmpBufPool) - s.config.Logger.Errorf("write data to client %s failed, %v", bindLn.LocalAddr(), err) + s.logger.Errorf("write data to client %s failed, %v", bindLn.LocalAddr(), err) return } s.bufferPool.Put(tmpBufPool) diff --git a/request_test.go b/request_test.go index 1ad3433..b8d4901 100644 --- a/request_test.go +++ b/request_test.go @@ -50,11 +50,9 @@ func TestRequest_Connect(t *testing.T) { // Make server s := &Server{ - config: &Config{ - Rules: PermitAll(), - Resolver: DNSResolver{}, - Logger: NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags)), - }, + rules: PermitAll(), + resolver: DNSResolver{}, + logger: NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags)), bufferPool: newPool(32 * 1024), } @@ -128,11 +126,9 @@ func TestRequest_Connect_RuleFail(t *testing.T) { // Make server s := &Server{ - config: &Config{ - Rules: PermitNone(), - Resolver: DNSResolver{}, - Logger: NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags)), - }, + rules: PermitNone(), + resolver: DNSResolver{}, + logger: NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags)), bufferPool: newPool(32 * 1024), } diff --git a/socks5.go b/socks5.go index af0414a..33bd8f6 100644 --- a/socks5.go +++ b/socks5.go @@ -14,95 +14,69 @@ type GPool interface { Submit(f func()) error } -// Config is used to setup and configure a Server -type Config struct { - // AuthMethods can be provided to implement custom authentication - // By default, "auth-less" mode is enabled. - // For password-based auth use UserPassAuthenticator. - AuthMethods []Authenticator - - // If provided, username/password authentication is enabled, - // by appending a UserPassAuthenticator to AuthMethods. If not provided, - // and AUthMethods is nil, then "auth-less" mode is enabled. - Credentials CredentialStore - - // Resolver can be provided to do custom name resolution. - // Defaults to DNSResolver if not provided. - Resolver NameResolver - - // Rules is provided to enable custom logic around permitting - // various commands. If not provided, PermitAll is used. - Rules RuleSet - - // Rewriter can be used to transparently rewrite addresses. - // This is invoked before the RuleSet is invoked. - // Defaults to NoRewrite. - Rewriter AddressRewriter - - // BindIP is used for bind or udp associate - BindIP net.IP - - // Logger can be used to provide a custom log target. - // Defaults to ioutil.Discard. - Logger Logger - - // Optional function for dialing out - Dial func(ctx context.Context, network, addr string) (net.Conn, error) -} - // Server is reponsible for accepting connections and handling // the details of the SOCKS5 protocol type Server struct { - config *Config authMethods map[uint8]Authenticator - bufferPool *pool - gPool GPool + // AuthMethods can be provided to implement custom authentication + // By default, "auth-less" mode is enabled. + // For password-based auth use UserPassAuthenticator. + authCustomMethods []Authenticator + // If provided, username/password authentication is enabled, + // by appending a UserPassAuthenticator to AuthMethods. If not provided, + // and AUthMethods is nil, then "auth-less" mode is enabled. + credentials CredentialStore + // resolver can be provided to do custom name resolution. + // Defaults to DNSResolver if not provided. + resolver NameResolver + // rules is provided to enable custom logic around permitting + // various commands. If not provided, PermitAll is used. + rules RuleSet + // rewriter can be used to transparently rewrite addresses. + // This is invoked before the RuleSet is invoked. + // Defaults to NoRewrite. + rewriter AddressRewriter + // bindIP is used for bind or udp associate + bindIP net.IP + // logger can be used to provide a custom log target. + // Defaults to ioutil.Discard. + logger Logger + // Optional function for dialing out + dial func(ctx context.Context, network, addr string) (net.Conn, error) + // buffer pool + bufferPool *pool + // goroutine pool + gPool GPool } // New creates a new Server and potentially returns an error -func New(conf *Config) (*Server, error) { - // Ensure we have at least one authentication method enabled - if len(conf.AuthMethods) == 0 { - if conf.Credentials != nil { - conf.AuthMethods = []Authenticator{&UserPassAuthenticator{conf.Credentials}} - } else { - conf.AuthMethods = []Authenticator{&NoAuthAuthenticator{}} - } - } - - // Ensure we have a DNS resolver - if conf.Resolver == nil { - conf.Resolver = DNSResolver{} - } - - // Ensure we have a rule set - if conf.Rules == nil { - conf.Rules = PermitAll() - } - - // Ensure we have a log target - if conf.Logger == nil { - conf.Logger = NewLogger(log.New(ioutil.Discard, "socks5: ", log.LstdFlags)) - } - - if conf.Dial == nil { - conf.Dial = func(ctx context.Context, net_, addr string) (net.Conn, error) { - return net.Dial(net_, addr) - } - } - +func New(opts ...Option) *Server { server := &Server{ - config: conf, - bufferPool: newPool(2 * 1024), + authMethods: make(map[uint8]Authenticator), + authCustomMethods: []Authenticator{&NoAuthAuthenticator{}}, + bufferPool: newPool(2 * 1024), + resolver: DNSResolver{}, + rules: PermitAll(), + logger: NewLogger(log.New(ioutil.Discard, "socks5: ", log.LstdFlags)), + dial: func(ctx context.Context, net_, addr string) (net.Conn, error) { + return net.Dial(net_, addr) + }, } - server.authMethods = make(map[uint8]Authenticator) - - for _, a := range conf.AuthMethods { - server.authMethods[a.GetCode()] = a + for _, opt := range opts { + opt(server) } - return server, nil + // Ensure we have at least one authentication method enabled + if len(server.authCustomMethods) == 0 && server.credentials != nil { + server.authCustomMethods = []Authenticator{&UserPassAuthenticator{server.credentials}} + } + + for _, v := range server.authCustomMethods { + server.authMethods[v.GetCode()] = v + } + + return server } // ListenAndServe is used to create a listener and serve on it @@ -135,7 +109,7 @@ func (s *Server) ServeConn(conn net.Conn) (err error) { // Read the version byte version := []byte{0} if _, err = bufConn.Read(version); err != nil { - s.config.Logger.Errorf("failed to get version byte: %v", err) + s.logger.Errorf("failed to get version byte: %v", err) return err } @@ -146,12 +120,12 @@ func (s *Server) ServeConn(conn net.Conn) (err error) { authContext, err = s.authenticate(conn, bufConn) if err != nil { err = fmt.Errorf("failed to authenticate: %v", err) - s.config.Logger.Errorf("%v", err) + s.logger.Errorf("%v", err) return err } } else if version[0] != socks4Version { err := fmt.Errorf("unsupported SOCKS version: %v", version[0]) - s.config.Logger.Errorf("%v", err) + s.logger.Errorf("%v", err) return err } @@ -173,7 +147,7 @@ func (s *Server) ServeConn(conn net.Conn) (err error) { // Process the client request if err := s.handleRequest(conn, request); err != nil { err = fmt.Errorf("failed to handle request, %v", err) - s.config.Logger.Errorf("%v", err) + s.logger.Errorf("%v", err) return err } diff --git a/socks5_test.go b/socks5_test.go index 7b2800d..58e1f5f 100644 --- a/socks5_test.go +++ b/socks5_test.go @@ -41,14 +41,10 @@ func TestSOCKS5_Connect(t *testing.T) { cator := UserPassAuthenticator{ Credentials: StaticCredentials{"foo": "bar"}, } - conf := &Config{ - AuthMethods: []Authenticator{cator}, - Logger: NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags)), - } - serv, err := New(conf) - if err != nil { - t.Fatalf("err: %v", err) - } + serv := New( + WithAuthMethods([]Authenticator{cator}), + WithLogger(NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags))), + ) // Start listening go func() { @@ -150,15 +146,10 @@ func TestSOCKS5_Associate(t *testing.T) { // Create a socks server cator := UserPassAuthenticator{Credentials: StaticCredentials{"foo": "bar"}} - conf := &Config{ - AuthMethods: []Authenticator{cator}, - Logger: NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags)), - } - serv, err := New(conf) - if err != nil { - t.Fatalf("err: %v", err) - } - + serv := New( + WithAuthMethods([]Authenticator{cator}), + WithLogger(NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags))), + ) // Start listening go func() { if err := serv.ListenAndServe("tcp", "127.0.0.1:12355"); err != nil { @@ -262,14 +253,10 @@ func Test_SocksWithProxy(t *testing.T) { // Create a socks server cator := UserPassAuthenticator{Credentials: StaticCredentials{"foo": "bar"}} - conf := &Config{ - AuthMethods: []Authenticator{cator}, - Logger: NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags)), - } - serv, err := New(conf) - if err != nil { - t.Fatalf("err: %v", err) - } + serv := New( + WithAuthMethods([]Authenticator{cator}), + WithLogger(NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags))), + ) // Start listening go func() {