add valid check add userIP

This commit is contained in:
mo 2020-04-21 22:28:34 +08:00
parent 881ed95bb4
commit acba51a242
6 changed files with 19 additions and 18 deletions

12
auth.go

@ -32,7 +32,7 @@ type AuthContext struct {
}
type Authenticator interface {
Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error)
Authenticate(reader io.Reader, writer io.Writer, userIP string) (*AuthContext, error)
GetCode() uint8
}
@ -43,7 +43,7 @@ func (a NoAuthAuthenticator) GetCode() uint8 {
return MethodNoAuth
}
func (a NoAuthAuthenticator) Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) {
func (a NoAuthAuthenticator) Authenticate(reader io.Reader, writer io.Writer, userIP string) (*AuthContext, error) {
_, err := writer.Write([]byte{VersionSocks5, MethodNoAuth})
return &AuthContext{MethodNoAuth, nil}, err
}
@ -58,7 +58,7 @@ func (a UserPassAuthenticator) GetCode() uint8 {
return MethodUserPassAuth
}
func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) {
func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer, userIP string) (*AuthContext, error) {
// Tell the client to use user/pass auth
if _, err := writer.Write([]byte{VersionSocks5, MethodUserPassAuth}); err != nil {
return nil, err
@ -95,7 +95,7 @@ func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer)
}
// Verify the password
if a.Credentials.Valid(string(user), string(pass)) {
if a.Credentials.Valid(string(user), string(pass), userIP) {
if _, err := writer.Write([]byte{UserPassAuthVersion, AuthSuccess}); err != nil {
return nil, err
}
@ -111,7 +111,7 @@ func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer)
}
// authenticate is used to handle connection authentication
func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) (*AuthContext, error) {
func (s *Server) authenticate(conn io.Writer, bufConn io.Reader, userIP string) (*AuthContext, error) {
// Get the methods
methods, err := readMethods(bufConn)
if err != nil {
@ -122,7 +122,7 @@ func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) (*AuthContext,
for _, method := range methods {
cator, found := s.authMethods[method]
if found {
return cator.Authenticate(bufConn, conn)
return cator.Authenticate(bufConn, conn, userIP)
}
}

@ -11,7 +11,7 @@ func TestNoAuth(t *testing.T) {
var resp bytes.Buffer
s := New()
ctx, err := s.authenticate(&resp, req)
ctx, err := s.authenticate(&resp, req, "")
if err != nil {
t.Fatalf("err: %v", err)
}
@ -40,7 +40,7 @@ func TestPasswordAuth_Valid(t *testing.T) {
s := New(WithAuthMethods([]Authenticator{cator}))
ctx, err := s.authenticate(&resp, req)
ctx, err := s.authenticate(&resp, req, "")
if err != nil {
t.Fatalf("err: %v", err)
}
@ -76,7 +76,7 @@ func TestPasswordAuth_Invalid(t *testing.T) {
cator := UserPassAuthenticator{Credentials: cred}
s := New(WithAuthMethods([]Authenticator{cator}))
ctx, err := s.authenticate(&resp, req)
ctx, err := s.authenticate(&resp, req, "")
if err != UserAuthFailed {
t.Fatalf("err: %v", err)
}
@ -103,7 +103,7 @@ func TestNoSupportedAuth(t *testing.T) {
s := New(WithAuthMethods([]Authenticator{cator}))
ctx, err := s.authenticate(&resp, req)
ctx, err := s.authenticate(&resp, req, "")
if err != NoSupportedAuth {
t.Fatalf("err: %v", err)
}

@ -1,14 +1,15 @@
package socks5
// CredentialStore is used to support user/pass authentication
// CredentialStore is used to support user/pass authentication optional user ip
// if you want to limit user ip ,you can refuse it.
type CredentialStore interface {
Valid(user, password string) bool
Valid(user, password, userIP string) bool
}
// StaticCredentials enables using a map directly as a credential store
type StaticCredentials map[string]string
func (s StaticCredentials) Valid(user, password string) bool {
func (s StaticCredentials) Valid(user, password, userIP string) bool {
pass, ok := s[user]
if !ok {
return false

@ -10,15 +10,15 @@ func TestStaticCredentials(t *testing.T) {
"baz": "",
}
if !creds.Valid("foo", "bar") {
if !creds.Valid("foo", "bar", "") {
t.Fatalf("expect valid")
}
if !creds.Valid("baz", "") {
if !creds.Valid("baz", "", "") {
t.Fatalf("expect valid")
}
if creds.Valid("foo", "") {
if creds.Valid("foo", "", "") {
t.Fatalf("expect invalid")
}
}

@ -117,7 +117,7 @@ func (s *Server) ServeConn(conn net.Conn) (err error) {
// Ensure we are compatible
if version[0] == VersionSocks5 {
// Authenticate the connection
authContext, err = s.authenticate(conn, bufConn)
authContext, err = s.authenticate(conn, bufConn, conn.RemoteAddr().String())
if err != nil {
err = fmt.Errorf("failed to authenticate: %v", err)
s.logger.Errorf("%v", err)

@ -266,7 +266,7 @@ func Test_SocksWithProxy(t *testing.T) {
}()
time.Sleep(10 * time.Millisecond)
dial, err := proxy.SOCKS5("tcp", "127.0.0.1:12395", &proxy.Auth{"foo", "bar"}, proxy.Direct)
dial, err := proxy.SOCKS5("tcp", "127.0.0.1:12395", &proxy.Auth{User: "foo", Password: "bar"}, proxy.Direct)
if err != nil {
t.Fatalf("err: %v", err)
}