diff --git a/auth.go b/auth.go index 86eb36c..d2d8a26 100644 --- a/auth.go +++ b/auth.go @@ -32,7 +32,7 @@ type AuthContext struct { } type Authenticator interface { - Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) + Authenticate(reader io.Reader, writer io.Writer, userIP string) (*AuthContext, error) GetCode() uint8 } @@ -43,7 +43,7 @@ func (a NoAuthAuthenticator) GetCode() uint8 { return MethodNoAuth } -func (a NoAuthAuthenticator) Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) { +func (a NoAuthAuthenticator) Authenticate(reader io.Reader, writer io.Writer, userIP string) (*AuthContext, error) { _, err := writer.Write([]byte{VersionSocks5, MethodNoAuth}) return &AuthContext{MethodNoAuth, nil}, err } @@ -58,7 +58,7 @@ func (a UserPassAuthenticator) GetCode() uint8 { return MethodUserPassAuth } -func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) { +func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer, userIP string) (*AuthContext, error) { // Tell the client to use user/pass auth if _, err := writer.Write([]byte{VersionSocks5, MethodUserPassAuth}); err != nil { return nil, err @@ -95,7 +95,7 @@ func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer) } // Verify the password - if a.Credentials.Valid(string(user), string(pass)) { + if a.Credentials.Valid(string(user), string(pass), userIP) { if _, err := writer.Write([]byte{UserPassAuthVersion, AuthSuccess}); err != nil { return nil, err } @@ -111,7 +111,7 @@ func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer) } // authenticate is used to handle connection authentication -func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) (*AuthContext, error) { +func (s *Server) authenticate(conn io.Writer, bufConn io.Reader, userIP string) (*AuthContext, error) { // Get the methods methods, err := readMethods(bufConn) if err != nil { @@ -122,7 +122,7 @@ func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) (*AuthContext, for _, method := range methods { cator, found := s.authMethods[method] if found { - return cator.Authenticate(bufConn, conn) + return cator.Authenticate(bufConn, conn, userIP) } } diff --git a/auth_test.go b/auth_test.go index b7230ff..47f80eb 100644 --- a/auth_test.go +++ b/auth_test.go @@ -11,7 +11,7 @@ func TestNoAuth(t *testing.T) { var resp bytes.Buffer s := New() - ctx, err := s.authenticate(&resp, req) + ctx, err := s.authenticate(&resp, req, "") if err != nil { t.Fatalf("err: %v", err) } @@ -40,7 +40,7 @@ func TestPasswordAuth_Valid(t *testing.T) { s := New(WithAuthMethods([]Authenticator{cator})) - ctx, err := s.authenticate(&resp, req) + ctx, err := s.authenticate(&resp, req, "") if err != nil { t.Fatalf("err: %v", err) } @@ -76,7 +76,7 @@ func TestPasswordAuth_Invalid(t *testing.T) { cator := UserPassAuthenticator{Credentials: cred} s := New(WithAuthMethods([]Authenticator{cator})) - ctx, err := s.authenticate(&resp, req) + ctx, err := s.authenticate(&resp, req, "") if err != UserAuthFailed { t.Fatalf("err: %v", err) } @@ -103,7 +103,7 @@ func TestNoSupportedAuth(t *testing.T) { s := New(WithAuthMethods([]Authenticator{cator})) - ctx, err := s.authenticate(&resp, req) + ctx, err := s.authenticate(&resp, req, "") if err != NoSupportedAuth { t.Fatalf("err: %v", err) } diff --git a/credentials.go b/credentials.go index 9666427..f49683e 100644 --- a/credentials.go +++ b/credentials.go @@ -1,14 +1,15 @@ package socks5 -// CredentialStore is used to support user/pass authentication +// CredentialStore is used to support user/pass authentication optional user ip +// if you want to limit user ip ,you can refuse it. type CredentialStore interface { - Valid(user, password string) bool + Valid(user, password, userIP string) bool } // StaticCredentials enables using a map directly as a credential store type StaticCredentials map[string]string -func (s StaticCredentials) Valid(user, password string) bool { +func (s StaticCredentials) Valid(user, password, userIP string) bool { pass, ok := s[user] if !ok { return false diff --git a/credentials_test.go b/credentials_test.go index e14256b..7e47760 100644 --- a/credentials_test.go +++ b/credentials_test.go @@ -10,15 +10,15 @@ func TestStaticCredentials(t *testing.T) { "baz": "", } - if !creds.Valid("foo", "bar") { + if !creds.Valid("foo", "bar", "") { t.Fatalf("expect valid") } - if !creds.Valid("baz", "") { + if !creds.Valid("baz", "", "") { t.Fatalf("expect valid") } - if creds.Valid("foo", "") { + if creds.Valid("foo", "", "") { t.Fatalf("expect invalid") } } diff --git a/socks5.go b/socks5.go index 836b252..155304b 100644 --- a/socks5.go +++ b/socks5.go @@ -117,7 +117,7 @@ func (s *Server) ServeConn(conn net.Conn) (err error) { // Ensure we are compatible if version[0] == VersionSocks5 { // Authenticate the connection - authContext, err = s.authenticate(conn, bufConn) + authContext, err = s.authenticate(conn, bufConn, conn.RemoteAddr().String()) if err != nil { err = fmt.Errorf("failed to authenticate: %v", err) s.logger.Errorf("%v", err) diff --git a/socks5_test.go b/socks5_test.go index 4d2da66..c70472a 100644 --- a/socks5_test.go +++ b/socks5_test.go @@ -266,7 +266,7 @@ func Test_SocksWithProxy(t *testing.T) { }() time.Sleep(10 * time.Millisecond) - dial, err := proxy.SOCKS5("tcp", "127.0.0.1:12395", &proxy.Auth{"foo", "bar"}, proxy.Direct) + dial, err := proxy.SOCKS5("tcp", "127.0.0.1:12395", &proxy.Auth{User: "foo", Password: "bar"}, proxy.Direct) if err != nil { t.Fatalf("err: %v", err) }