diff --git a/auth.go b/auth.go index fd656a6..bd625c3 100644 --- a/auth.go +++ b/auth.go @@ -5,6 +5,7 @@ import ( "io" ) +// auth defined const ( MethodNoAuth = uint8(0) MethodGSSAPI = uint8(1) @@ -15,12 +16,13 @@ const ( AuthFailure = uint8(1) ) +// auth error defined var ( - UserAuthFailed = fmt.Errorf("User authentication failed") - NoSupportedAuth = fmt.Errorf("No supported authentication mechanism") + ErrUserAuthFailed = fmt.Errorf("User authentication failed") + ErrNoSupportedAuth = fmt.Errorf("No supported authentication mechanism") ) -// A Request encapsulates authentication state provided +// AuthContext A Request encapsulates authentication state provided // during negotiation type AuthContext struct { // Provided auth method @@ -31,6 +33,7 @@ type AuthContext struct { Payload map[string]string } +// Authenticator provide auth type Authenticator interface { Authenticate(reader io.Reader, writer io.Writer, userAddr string) (*AuthContext, error) GetCode() uint8 @@ -39,10 +42,12 @@ type Authenticator interface { // NoAuthAuthenticator is used to handle the "No Authentication" mode type NoAuthAuthenticator struct{} +// GetCode implement interface Authenticator func (a NoAuthAuthenticator) GetCode() uint8 { return MethodNoAuth } +// Authenticate implement interface Authenticator func (a NoAuthAuthenticator) Authenticate(reader io.Reader, writer io.Writer, userAddr string) (*AuthContext, error) { _, err := writer.Write([]byte{VersionSocks5, MethodNoAuth}) return &AuthContext{MethodNoAuth, nil}, err @@ -54,10 +59,12 @@ type UserPassAuthenticator struct { Credentials CredentialStore } +// GetCode implement interface Authenticator func (a UserPassAuthenticator) GetCode() uint8 { return MethodUserPassAuth } +// Authenticate implement interface Authenticator func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer, userAddr string) (*AuthContext, error) { // Tell the client to use user/pass auth if _, err := writer.Write([]byte{VersionSocks5, MethodUserPassAuth}); err != nil { @@ -103,7 +110,7 @@ func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer, if _, err := writer.Write([]byte{UserPassAuthVersion, AuthFailure}); err != nil { return nil, err } - return nil, UserAuthFailed + return nil, ErrUserAuthFailed } // Done @@ -134,7 +141,7 @@ func (s *Server) authenticate(conn io.Writer, bufConn io.Reader, userAddr string // authentication mechanism func noAcceptableAuth(conn io.Writer) error { conn.Write([]byte{VersionSocks5, MethodNoAcceptable}) - return NoSupportedAuth + return ErrNoSupportedAuth } // readMethods is used to read the number of methods diff --git a/auth_test.go b/auth_test.go index 47f80eb..f3591ce 100644 --- a/auth_test.go +++ b/auth_test.go @@ -77,7 +77,7 @@ func TestPasswordAuth_Invalid(t *testing.T) { s := New(WithAuthMethods([]Authenticator{cator})) ctx, err := s.authenticate(&resp, req, "") - if err != UserAuthFailed { + if err != ErrUserAuthFailed { t.Fatalf("err: %v", err) } @@ -104,7 +104,7 @@ func TestNoSupportedAuth(t *testing.T) { s := New(WithAuthMethods([]Authenticator{cator})) ctx, err := s.authenticate(&resp, req, "") - if err != NoSupportedAuth { + if err != ErrNoSupportedAuth { t.Fatalf("err: %v", err) } diff --git a/credentials.go b/credentials.go index 46cb1fc..6a38312 100644 --- a/credentials.go +++ b/credentials.go @@ -9,6 +9,7 @@ type CredentialStore interface { // StaticCredentials enables using a map directly as a credential store type StaticCredentials map[string]string +// Valid implement interface CredentialStore func (s StaticCredentials) Valid(user, password, userAddr string) bool { pass, ok := s[user] if !ok { diff --git a/header.go b/header.go index 71b7e95..1a6c932 100644 --- a/header.go +++ b/header.go @@ -1,12 +1,14 @@ package socks5 import ( + "bytes" "fmt" "io" "net" "strconv" ) +// socks const defined const ( // protocol version VersionSocks4 = uint8(4) @@ -83,6 +85,7 @@ type Header struct { addrType uint8 } +// Parse to header func Parse(r io.Reader) (hd Header, err error) { // Read the version and command tmp := make([]byte, headVERLen+headCMDLen) @@ -142,13 +145,15 @@ func Parse(r io.Reader) (hd Header, err error) { hd.Address.IP = addr[:net.IPv6len] hd.Address.Port = buildPort(addr[net.IPv6len], addr[net.IPv6len+1]) default: - return hd, unrecognizedAddrType + return hd, errUnrecognizedAddrType } } return hd, nil } +// Bytes returns a slice of header func (h Header) Bytes() (b []byte) { + bytes.Buffer{}.Bytes() b = append(b, h.Version) b = append(b, h.Command) hiPort, loPort := breakPort(h.Address.Port) diff --git a/logger.go b/logger.go index 9bda141..1a2bb3f 100644 --- a/logger.go +++ b/logger.go @@ -4,19 +4,22 @@ import ( "log" ) +// Logger is used to provide debug logger type Logger interface { Errorf(format string, arg ...interface{}) } -// 标准输出 +// Std std logger type Std struct { *log.Logger } +// NewLogger new std logger with log.logger func NewLogger(l *log.Logger) *Std { return &Std{l} } +// Errorf implement interface Logger func (sf Std) Errorf(format string, args ...interface{}) { sf.Logger.Printf("[E]: "+format, args...) } diff --git a/option.go b/option.go index 7f667d5..542d497 100644 --- a/option.go +++ b/option.go @@ -2,12 +2,14 @@ package socks5 import ( "context" + "io" "net" ) +// Option user's option type Option func(s *Server) -// AuthMethods can be provided to implement custom authentication +// WithAuthMethods can be provided to implement custom authentication // By default, "auth-less" mode is enabled. // For password-based auth use UserPassAuthenticator. func WithAuthMethods(authMethods []Authenticator) Option { @@ -19,7 +21,7 @@ func WithAuthMethods(authMethods []Authenticator) Option { } } -// If provided, username/password authentication is enabled, +// WithCredential If provided, username/password authentication is enabled, // by appending a UserPassAuthenticator to AuthMethods. If not provided, // and AUthMethods is nil, then "auth-less" mode is enabled. func WithCredential(cs CredentialStore) Option { @@ -30,7 +32,7 @@ func WithCredential(cs CredentialStore) Option { } } -// resolver can be provided to do custom name resolution. +// WithResolver can be provided to do custom name resolution. // Defaults to DNSResolver if not provided. func WithResolver(res NameResolver) Option { return func(s *Server) { @@ -40,7 +42,7 @@ func WithResolver(res NameResolver) Option { } } -// rules is provided to enable custom logic around permitting +// WithRule is provided to enable custom logic around permitting // various commands. If not provided, PermitAll is used. func WithRule(rule RuleSet) Option { return func(s *Server) { @@ -50,7 +52,7 @@ func WithRule(rule RuleSet) Option { } } -// rewriter can be used to transparently rewrite addresses. +// WithRewriter can be used to transparently rewrite addresses. // This is invoked before the RuleSet is invoked. // Defaults to NoRewrite. func WithRewriter(rew AddressRewriter) Option { @@ -61,7 +63,7 @@ func WithRewriter(rew AddressRewriter) Option { } } -// bindIP is used for bind or udp associate +// WithBindIP is used for bind or udp associate func WithBindIP(ip net.IP) Option { return func(s *Server) { if len(ip) != 0 { @@ -71,7 +73,7 @@ func WithBindIP(ip net.IP) Option { } } -// logger can be used to provide a custom log target. +// WithLogger can be used to provide a custom log target. // Defaults to ioutil.Discard. func WithLogger(l Logger) Option { return func(s *Server) { @@ -81,7 +83,7 @@ func WithLogger(l Logger) Option { } } -// Optional function for dialing out +// WithDial Optional function for dialing out func WithDial(dial func(ctx context.Context, network, addr string) (net.Conn, error)) Option { return func(s *Server) { if dial != nil { @@ -90,6 +92,7 @@ func WithDial(dial func(ctx context.Context, network, addr string) (net.Conn, er } } +// WithGPool can be provided to do custom goroutine pool. func WithGPool(pool GPool) Option { return func(s *Server) { if pool != nil { @@ -97,3 +100,24 @@ func WithGPool(pool GPool) Option { } } } + +// WithConnectHandle is used to handle a user's connect command +func WithConnectHandle(h func(ctx context.Context, writer io.Writer, req *Request) error) Option { + return func(s *Server) { + s.userConnectHandle = h + } +} + +// WithBindHandle is used to handle a user's bind command +func WithBindHandle(h func(ctx context.Context, writer io.Writer, req *Request) error) Option { + return func(s *Server) { + s.userBindHandle = h + } +} + +// WithAssociateHandle is used to handle a user's associate command +func WithAssociateHandle(h func(ctx context.Context, writer io.Writer, req *Request) error) Option { + return func(s *Server) { + s.userAssociateHandle = h + } +} diff --git a/request.go b/request.go index 1ec4826..04ab5e7 100644 --- a/request.go +++ b/request.go @@ -10,7 +10,7 @@ import ( ) var ( - unrecognizedAddrType = fmt.Errorf("Unrecognized address type") + errUnrecognizedAddrType = fmt.Errorf("Unrecognized address type") ) // AddressRewriter is used to rewrite a destination transparently @@ -25,11 +25,12 @@ type Request struct { AuthContext *AuthContext // AddrSpec of the the network that sent the request RemoteAddr *AddrSpec - // AddrSpec of the desired destination - DestAddr *AddrSpec // AddrSpec of the actual destination (might be affected by rewrite) - realDestAddr *AddrSpec - bufConn io.Reader + DestAddr *AddrSpec + // Reader connect of request + Reader io.Reader + // AddrSpec of the desired destination + RawDestAddr *AddrSpec } type conn interface { @@ -55,9 +56,9 @@ func NewRequest(bufConn io.Reader) (*Request, error) { return nil, fmt.Errorf("unrecognized command[%d]", hd.Command) } return &Request{ - Header: hd, - DestAddr: &hd.Address, - bufConn: bufConn, + Header: hd, + RawDestAddr: &hd.Address, + Reader: bufConn, }, nil } @@ -66,45 +67,54 @@ func (s *Server) handleRequest(write io.Writer, req *Request) error { ctx := context.Background() // Resolve the address if we have a FQDN - dest := req.DestAddr + dest := req.RawDestAddr if dest.FQDN != "" { - ctx_, addr, err := s.resolver.Resolve(ctx, dest.FQDN) + _ctx, addr, err := s.resolver.Resolve(ctx, dest.FQDN) if err != nil { - if err := sendReply(write, req.Header, hostUnreachable); err != nil { + if err := SendReply(write, req.Header, hostUnreachable); err != nil { return fmt.Errorf("failed to send reply, %v", err) } return fmt.Errorf("failed to resolve destination[%v], %v", dest.FQDN, err) } - ctx = ctx_ + ctx = _ctx dest.IP = addr } // Apply any address rewrites - req.realDestAddr = req.DestAddr + req.DestAddr = req.RawDestAddr if s.rewriter != nil { - ctx, req.realDestAddr = s.rewriter.Rewrite(ctx, req) + ctx, req.DestAddr = s.rewriter.Rewrite(ctx, req) } // Check if this is allowed - if ctx_, ok := s.rules.Allow(ctx, req); !ok { - if err := sendReply(write, req.Header, ruleFailure); err != nil { + _ctx, ok := s.rules.Allow(ctx, req) + if !ok { + if err := SendReply(write, req.Header, ruleFailure); err != nil { return fmt.Errorf("failed to send reply, %v", err) } - return fmt.Errorf("bind to %v blocked by rules", req.DestAddr) - } else { - ctx = ctx_ + return fmt.Errorf("bind to %v blocked by rules", req.RawDestAddr) } + ctx = _ctx // Switch on the command switch req.Command { case CommandConnect: + if s.userConnectHandle != nil { + return s.userConnectHandle(ctx, write, req) + } return s.handleConnect(ctx, write, req) case CommandBind: + if s.userBindHandle != nil { + return s.userBindHandle(ctx, write, req) + } return s.handleBind(ctx, write, req) case CommandAssociate: + if s.userAssociateHandle != nil { + return s.userAssociateHandle(ctx, write, req) + } return s.handleAssociate(ctx, write, req) default: - if err := sendReply(write, req.Header, commandNotSupported); err != nil { + if err := SendReply(write, req.Header, commandNotSupported); err != nil { return fmt.Errorf("failed to send reply, %v", err) } return fmt.Errorf("unsupported command[%v]", req.Command) @@ -120,7 +130,7 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, req *Reque return net.Dial(net_, addr) } } - target, err := dial(ctx, "tcp", req.realDestAddr.Address()) + target, err := dial(ctx, "tcp", req.DestAddr.Address()) if err != nil { msg := err.Error() resp := hostUnreachable @@ -129,23 +139,23 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, req *Reque } else if strings.Contains(msg, "network is unreachable") { resp = networkUnreachable } - if err := sendReply(writer, req.Header, resp); err != nil { + if err := SendReply(writer, req.Header, resp); err != nil { return fmt.Errorf("failed to send reply, %v", err) } - return fmt.Errorf("connect to %v failed, %v", req.DestAddr, err) + return fmt.Errorf("connect to %v failed, %v", req.RawDestAddr, err) } defer target.Close() // Send success - if err := sendReply(writer, req.Header, successReply, target.LocalAddr()); err != nil { + if err := SendReply(writer, req.Header, successReply, target.LocalAddr()); err != nil { return fmt.Errorf("failed to send reply, %v", err) } // Start proxying errCh := make(chan error, 2) - s.submit(func() { errCh <- s.proxy(target, req.bufConn) }) - s.submit(func() { errCh <- s.proxy(writer, target) }) + s.submit(func() { errCh <- s.Proxy(target, req.Reader) }) + s.submit(func() { errCh <- s.Proxy(writer, target) }) // Wait for i := 0; i < 2; i++ { @@ -161,7 +171,7 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, req *Reque // handleBind is used to handle a connect command func (s *Server) handleBind(ctx context.Context, writer io.Writer, req *Request) error { // TODO: Support bind - if err := sendReply(writer, req.Header, commandNotSupported); err != nil { + if err := SendReply(writer, req.Header, commandNotSupported); err != nil { return fmt.Errorf("failed to send reply: %v", err) } return nil @@ -176,7 +186,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req return net.Dial(net_, addr) } } - target, err := dial(ctx, "udp", req.realDestAddr.Address()) + target, err := dial(ctx, "udp", req.DestAddr.Address()) if err != nil { msg := err.Error() resp := hostUnreachable @@ -185,16 +195,16 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req } else if strings.Contains(msg, "network is unreachable") { resp = networkUnreachable } - if err := sendReply(writer, req.Header, resp); err != nil { + if err := SendReply(writer, req.Header, resp); err != nil { return fmt.Errorf("failed to send reply, %v", err) } - return fmt.Errorf("connect to %v failed, %v", req.DestAddr, err) + return fmt.Errorf("connect to %v failed, %v", req.RawDestAddr, err) } defer target.Close() - targetUdp, ok := target.(*net.UDPConn) + targetUDP, ok := target.(*net.UDPConn) if !ok { - if err := sendReply(writer, req.Header, serverFailure); err != nil { + if err := SendReply(writer, req.Header, serverFailure); err != nil { return fmt.Errorf("failed to send reply, %v", err) } return fmt.Errorf("dial udp invalid") @@ -202,16 +212,16 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req bindLn, err := net.ListenUDP("udp", nil) if err != nil { - if err := sendReply(writer, req.Header, serverFailure); err != nil { + if err := SendReply(writer, req.Header, serverFailure); err != nil { return fmt.Errorf("failed to send reply, %v", err) } return fmt.Errorf("listen udp failed, %v", err) } defer bindLn.Close() - s.logger.Errorf("target addr %v, listen addr: %s", targetUdp.RemoteAddr(), bindLn.LocalAddr()) + s.logger.Errorf("target addr %v, listen addr: %s", targetUDP.RemoteAddr(), bindLn.LocalAddr()) // send BND.ADDR and BND.PORT, client must - if err = sendReply(writer, req.Header, successReply, bindLn.LocalAddr()); err != nil { + if err = SendReply(writer, req.Header, successReply, bindLn.LocalAddr()); err != nil { return fmt.Errorf("failed to send reply, %v", err) } @@ -228,7 +238,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req conns := sync.Map{} bufPool := s.bufferPool.Get() defer func() { - targetUdp.Close() + targetUDP.Close() bindLn.Close() s.bufferPool.Put(bufPool) }() @@ -278,8 +288,8 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req } // 把消息写给remote sever - if _, err := targetUdp.Write(buf[headLen:n]); err != nil { - s.logger.Errorf("write data to remote %s failed, %v", targetUdp.RemoteAddr(), err) + if _, err := targetUDP.Write(buf[headLen:n]); err != nil { + s.logger.Errorf("write data to remote %s failed, %v", targetUDP.RemoteAddr(), err) return } @@ -288,16 +298,16 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req // read from remote server and write to client bufPool := s.bufferPool.Get() defer func() { - targetUdp.Close() + targetUDP.Close() bindLn.Close() s.bufferPool.Put(bufPool) }() for { buf := bufPool[:cap(bufPool)] - n, remote, err := targetUdp.ReadFrom(buf) + n, remote, err := targetUDP.ReadFrom(buf) if err != nil { - s.logger.Errorf("read data from remote %s failed, %v", targetUdp.RemoteAddr(), err) + s.logger.Errorf("read data from remote %s failed, %v", targetUDP.RemoteAddr(), err) return } @@ -334,15 +344,15 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req s.bufferPool.Put(buf) }() for { - _, err := req.bufConn.Read(buf[:cap(buf)]) + _, err := req.Reader.Read(buf[:cap(buf)]) if err != nil { return err } } } -// sendReply is used to send a reply message -func sendReply(w io.Writer, head Header, resp uint8, bindAddr ...net.Addr) error { +// SendReply is used to send a reply message +func SendReply(w io.Writer, head Header, resp uint8, bindAddr ...net.Addr) error { /* The SOCKS response is formed as follows: +----+-----+-------+------+----------+----------+ @@ -397,9 +407,9 @@ type closeWriter interface { CloseWrite() error } -// proxy is used to suffle data from src to destination, and sends errors +// Proxy is used to suffle data from src to destination, and sends errors // down a dedicated channel -func (s *Server) proxy(dst io.Writer, src io.Reader) error { +func (s *Server) Proxy(dst io.Writer, src io.Reader) error { buf := s.bufferPool.Get() defer s.bufferPool.Put(buf) _, err := io.CopyBuffer(dst, src, buf[:cap(buf)]) diff --git a/resolver.go b/resolver.go index ef10f3a..cb9d122 100644 --- a/resolver.go +++ b/resolver.go @@ -13,6 +13,7 @@ type NameResolver interface { // DNSResolver uses the system DNS to resolve host names type DNSResolver struct{} +// Resolve implement interface NameResolver func (d DNSResolver) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) { addr, err := net.ResolveIPAddr("ip", name) if err != nil { diff --git a/ruleset.go b/ruleset.go index 48118c2..a953487 100644 --- a/ruleset.go +++ b/ruleset.go @@ -27,6 +27,7 @@ type PermitCommand struct { EnableAssociate bool } +// Allow implement interface RuleSet func (p *PermitCommand) Allow(ctx context.Context, req *Request) (context.Context, bool) { switch req.Command { case CommandConnect: diff --git a/socks5.go b/socks5.go index 155304b..68c52bc 100644 --- a/socks5.go +++ b/socks5.go @@ -4,6 +4,7 @@ import ( "bufio" "context" "fmt" + "io" "io/ioutil" "log" "net" @@ -47,6 +48,10 @@ type Server struct { bufferPool *pool // goroutine pool gPool GPool + // user's handle + userConnectHandle func(ctx context.Context, writer io.Writer, req *Request) error + userBindHandle func(ctx context.Context, writer io.Writer, req *Request) error + userAssociateHandle func(ctx context.Context, writer io.Writer, req *Request) error } // New creates a new Server and potentially returns an error @@ -131,8 +136,8 @@ func (s *Server) ServeConn(conn net.Conn) (err error) { request, err := NewRequest(bufConn) if err != nil { - if err == unrecognizedAddrType { - if err := sendReply(conn, Header{Version: version[0]}, addrTypeNotSupported); err != nil { + if err == errUnrecognizedAddrType { + if err := SendReply(conn, Header{Version: version[0]}, addrTypeNotSupported); err != nil { return fmt.Errorf("failed to send reply, %v", err) } }