add valid check add userIP
This commit is contained in:
parent
881ed95bb4
commit
acba51a242
12
auth.go
12
auth.go
@ -32,7 +32,7 @@ type AuthContext struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Authenticator interface {
|
type Authenticator interface {
|
||||||
Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error)
|
Authenticate(reader io.Reader, writer io.Writer, userIP string) (*AuthContext, error)
|
||||||
GetCode() uint8
|
GetCode() uint8
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -43,7 +43,7 @@ func (a NoAuthAuthenticator) GetCode() uint8 {
|
|||||||
return MethodNoAuth
|
return MethodNoAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a NoAuthAuthenticator) Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) {
|
func (a NoAuthAuthenticator) Authenticate(reader io.Reader, writer io.Writer, userIP string) (*AuthContext, error) {
|
||||||
_, err := writer.Write([]byte{VersionSocks5, MethodNoAuth})
|
_, err := writer.Write([]byte{VersionSocks5, MethodNoAuth})
|
||||||
return &AuthContext{MethodNoAuth, nil}, err
|
return &AuthContext{MethodNoAuth, nil}, err
|
||||||
}
|
}
|
||||||
@ -58,7 +58,7 @@ func (a UserPassAuthenticator) GetCode() uint8 {
|
|||||||
return MethodUserPassAuth
|
return MethodUserPassAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) {
|
func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer, userIP string) (*AuthContext, error) {
|
||||||
// Tell the client to use user/pass auth
|
// Tell the client to use user/pass auth
|
||||||
if _, err := writer.Write([]byte{VersionSocks5, MethodUserPassAuth}); err != nil {
|
if _, err := writer.Write([]byte{VersionSocks5, MethodUserPassAuth}); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -95,7 +95,7 @@ func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Verify the password
|
// Verify the password
|
||||||
if a.Credentials.Valid(string(user), string(pass)) {
|
if a.Credentials.Valid(string(user), string(pass), userIP) {
|
||||||
if _, err := writer.Write([]byte{UserPassAuthVersion, AuthSuccess}); err != nil {
|
if _, err := writer.Write([]byte{UserPassAuthVersion, AuthSuccess}); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -111,7 +111,7 @@ func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// authenticate is used to handle connection authentication
|
// authenticate is used to handle connection authentication
|
||||||
func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) (*AuthContext, error) {
|
func (s *Server) authenticate(conn io.Writer, bufConn io.Reader, userIP string) (*AuthContext, error) {
|
||||||
// Get the methods
|
// Get the methods
|
||||||
methods, err := readMethods(bufConn)
|
methods, err := readMethods(bufConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -122,7 +122,7 @@ func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) (*AuthContext,
|
|||||||
for _, method := range methods {
|
for _, method := range methods {
|
||||||
cator, found := s.authMethods[method]
|
cator, found := s.authMethods[method]
|
||||||
if found {
|
if found {
|
||||||
return cator.Authenticate(bufConn, conn)
|
return cator.Authenticate(bufConn, conn, userIP)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -11,7 +11,7 @@ func TestNoAuth(t *testing.T) {
|
|||||||
var resp bytes.Buffer
|
var resp bytes.Buffer
|
||||||
|
|
||||||
s := New()
|
s := New()
|
||||||
ctx, err := s.authenticate(&resp, req)
|
ctx, err := s.authenticate(&resp, req, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
@ -40,7 +40,7 @@ func TestPasswordAuth_Valid(t *testing.T) {
|
|||||||
|
|
||||||
s := New(WithAuthMethods([]Authenticator{cator}))
|
s := New(WithAuthMethods([]Authenticator{cator}))
|
||||||
|
|
||||||
ctx, err := s.authenticate(&resp, req)
|
ctx, err := s.authenticate(&resp, req, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
@ -76,7 +76,7 @@ func TestPasswordAuth_Invalid(t *testing.T) {
|
|||||||
cator := UserPassAuthenticator{Credentials: cred}
|
cator := UserPassAuthenticator{Credentials: cred}
|
||||||
s := New(WithAuthMethods([]Authenticator{cator}))
|
s := New(WithAuthMethods([]Authenticator{cator}))
|
||||||
|
|
||||||
ctx, err := s.authenticate(&resp, req)
|
ctx, err := s.authenticate(&resp, req, "")
|
||||||
if err != UserAuthFailed {
|
if err != UserAuthFailed {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
@ -103,7 +103,7 @@ func TestNoSupportedAuth(t *testing.T) {
|
|||||||
|
|
||||||
s := New(WithAuthMethods([]Authenticator{cator}))
|
s := New(WithAuthMethods([]Authenticator{cator}))
|
||||||
|
|
||||||
ctx, err := s.authenticate(&resp, req)
|
ctx, err := s.authenticate(&resp, req, "")
|
||||||
if err != NoSupportedAuth {
|
if err != NoSupportedAuth {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -1,14 +1,15 @@
|
|||||||
package socks5
|
package socks5
|
||||||
|
|
||||||
// CredentialStore is used to support user/pass authentication
|
// CredentialStore is used to support user/pass authentication optional user ip
|
||||||
|
// if you want to limit user ip ,you can refuse it.
|
||||||
type CredentialStore interface {
|
type CredentialStore interface {
|
||||||
Valid(user, password string) bool
|
Valid(user, password, userIP string) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// StaticCredentials enables using a map directly as a credential store
|
// StaticCredentials enables using a map directly as a credential store
|
||||||
type StaticCredentials map[string]string
|
type StaticCredentials map[string]string
|
||||||
|
|
||||||
func (s StaticCredentials) Valid(user, password string) bool {
|
func (s StaticCredentials) Valid(user, password, userIP string) bool {
|
||||||
pass, ok := s[user]
|
pass, ok := s[user]
|
||||||
if !ok {
|
if !ok {
|
||||||
return false
|
return false
|
||||||
|
@ -10,15 +10,15 @@ func TestStaticCredentials(t *testing.T) {
|
|||||||
"baz": "",
|
"baz": "",
|
||||||
}
|
}
|
||||||
|
|
||||||
if !creds.Valid("foo", "bar") {
|
if !creds.Valid("foo", "bar", "") {
|
||||||
t.Fatalf("expect valid")
|
t.Fatalf("expect valid")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !creds.Valid("baz", "") {
|
if !creds.Valid("baz", "", "") {
|
||||||
t.Fatalf("expect valid")
|
t.Fatalf("expect valid")
|
||||||
}
|
}
|
||||||
|
|
||||||
if creds.Valid("foo", "") {
|
if creds.Valid("foo", "", "") {
|
||||||
t.Fatalf("expect invalid")
|
t.Fatalf("expect invalid")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -117,7 +117,7 @@ func (s *Server) ServeConn(conn net.Conn) (err error) {
|
|||||||
// Ensure we are compatible
|
// Ensure we are compatible
|
||||||
if version[0] == VersionSocks5 {
|
if version[0] == VersionSocks5 {
|
||||||
// Authenticate the connection
|
// Authenticate the connection
|
||||||
authContext, err = s.authenticate(conn, bufConn)
|
authContext, err = s.authenticate(conn, bufConn, conn.RemoteAddr().String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("failed to authenticate: %v", err)
|
err = fmt.Errorf("failed to authenticate: %v", err)
|
||||||
s.logger.Errorf("%v", err)
|
s.logger.Errorf("%v", err)
|
||||||
|
@ -266,7 +266,7 @@ func Test_SocksWithProxy(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
dial, err := proxy.SOCKS5("tcp", "127.0.0.1:12395", &proxy.Auth{"foo", "bar"}, proxy.Direct)
|
dial, err := proxy.SOCKS5("tcp", "127.0.0.1:12395", &proxy.Auth{User: "foo", Password: "bar"}, proxy.Direct)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user