add user handle
revive inspect
This commit is contained in:
parent
8b383556a2
commit
f77e659826
17
auth.go
17
auth.go
@ -5,6 +5,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// auth defined
|
||||||
const (
|
const (
|
||||||
MethodNoAuth = uint8(0)
|
MethodNoAuth = uint8(0)
|
||||||
MethodGSSAPI = uint8(1)
|
MethodGSSAPI = uint8(1)
|
||||||
@ -15,12 +16,13 @@ const (
|
|||||||
AuthFailure = uint8(1)
|
AuthFailure = uint8(1)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// auth error defined
|
||||||
var (
|
var (
|
||||||
UserAuthFailed = fmt.Errorf("User authentication failed")
|
ErrUserAuthFailed = fmt.Errorf("User authentication failed")
|
||||||
NoSupportedAuth = fmt.Errorf("No supported authentication mechanism")
|
ErrNoSupportedAuth = fmt.Errorf("No supported authentication mechanism")
|
||||||
)
|
)
|
||||||
|
|
||||||
// A Request encapsulates authentication state provided
|
// AuthContext A Request encapsulates authentication state provided
|
||||||
// during negotiation
|
// during negotiation
|
||||||
type AuthContext struct {
|
type AuthContext struct {
|
||||||
// Provided auth method
|
// Provided auth method
|
||||||
@ -31,6 +33,7 @@ type AuthContext struct {
|
|||||||
Payload map[string]string
|
Payload map[string]string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Authenticator provide auth
|
||||||
type Authenticator interface {
|
type Authenticator interface {
|
||||||
Authenticate(reader io.Reader, writer io.Writer, userAddr string) (*AuthContext, error)
|
Authenticate(reader io.Reader, writer io.Writer, userAddr string) (*AuthContext, error)
|
||||||
GetCode() uint8
|
GetCode() uint8
|
||||||
@ -39,10 +42,12 @@ type Authenticator interface {
|
|||||||
// NoAuthAuthenticator is used to handle the "No Authentication" mode
|
// NoAuthAuthenticator is used to handle the "No Authentication" mode
|
||||||
type NoAuthAuthenticator struct{}
|
type NoAuthAuthenticator struct{}
|
||||||
|
|
||||||
|
// GetCode implement interface Authenticator
|
||||||
func (a NoAuthAuthenticator) GetCode() uint8 {
|
func (a NoAuthAuthenticator) GetCode() uint8 {
|
||||||
return MethodNoAuth
|
return MethodNoAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Authenticate implement interface Authenticator
|
||||||
func (a NoAuthAuthenticator) Authenticate(reader io.Reader, writer io.Writer, userAddr string) (*AuthContext, error) {
|
func (a NoAuthAuthenticator) Authenticate(reader io.Reader, writer io.Writer, userAddr string) (*AuthContext, error) {
|
||||||
_, err := writer.Write([]byte{VersionSocks5, MethodNoAuth})
|
_, err := writer.Write([]byte{VersionSocks5, MethodNoAuth})
|
||||||
return &AuthContext{MethodNoAuth, nil}, err
|
return &AuthContext{MethodNoAuth, nil}, err
|
||||||
@ -54,10 +59,12 @@ type UserPassAuthenticator struct {
|
|||||||
Credentials CredentialStore
|
Credentials CredentialStore
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetCode implement interface Authenticator
|
||||||
func (a UserPassAuthenticator) GetCode() uint8 {
|
func (a UserPassAuthenticator) GetCode() uint8 {
|
||||||
return MethodUserPassAuth
|
return MethodUserPassAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Authenticate implement interface Authenticator
|
||||||
func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer, userAddr string) (*AuthContext, error) {
|
func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer, userAddr string) (*AuthContext, error) {
|
||||||
// Tell the client to use user/pass auth
|
// Tell the client to use user/pass auth
|
||||||
if _, err := writer.Write([]byte{VersionSocks5, MethodUserPassAuth}); err != nil {
|
if _, err := writer.Write([]byte{VersionSocks5, MethodUserPassAuth}); err != nil {
|
||||||
@ -103,7 +110,7 @@ func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer,
|
|||||||
if _, err := writer.Write([]byte{UserPassAuthVersion, AuthFailure}); err != nil {
|
if _, err := writer.Write([]byte{UserPassAuthVersion, AuthFailure}); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return nil, UserAuthFailed
|
return nil, ErrUserAuthFailed
|
||||||
}
|
}
|
||||||
|
|
||||||
// Done
|
// Done
|
||||||
@ -134,7 +141,7 @@ func (s *Server) authenticate(conn io.Writer, bufConn io.Reader, userAddr string
|
|||||||
// authentication mechanism
|
// authentication mechanism
|
||||||
func noAcceptableAuth(conn io.Writer) error {
|
func noAcceptableAuth(conn io.Writer) error {
|
||||||
conn.Write([]byte{VersionSocks5, MethodNoAcceptable})
|
conn.Write([]byte{VersionSocks5, MethodNoAcceptable})
|
||||||
return NoSupportedAuth
|
return ErrNoSupportedAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
// readMethods is used to read the number of methods
|
// readMethods is used to read the number of methods
|
||||||
|
@ -77,7 +77,7 @@ func TestPasswordAuth_Invalid(t *testing.T) {
|
|||||||
s := New(WithAuthMethods([]Authenticator{cator}))
|
s := New(WithAuthMethods([]Authenticator{cator}))
|
||||||
|
|
||||||
ctx, err := s.authenticate(&resp, req, "")
|
ctx, err := s.authenticate(&resp, req, "")
|
||||||
if err != UserAuthFailed {
|
if err != ErrUserAuthFailed {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -104,7 +104,7 @@ func TestNoSupportedAuth(t *testing.T) {
|
|||||||
s := New(WithAuthMethods([]Authenticator{cator}))
|
s := New(WithAuthMethods([]Authenticator{cator}))
|
||||||
|
|
||||||
ctx, err := s.authenticate(&resp, req, "")
|
ctx, err := s.authenticate(&resp, req, "")
|
||||||
if err != NoSupportedAuth {
|
if err != ErrNoSupportedAuth {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -9,6 +9,7 @@ type CredentialStore interface {
|
|||||||
// StaticCredentials enables using a map directly as a credential store
|
// StaticCredentials enables using a map directly as a credential store
|
||||||
type StaticCredentials map[string]string
|
type StaticCredentials map[string]string
|
||||||
|
|
||||||
|
// Valid implement interface CredentialStore
|
||||||
func (s StaticCredentials) Valid(user, password, userAddr string) bool {
|
func (s StaticCredentials) Valid(user, password, userAddr string) bool {
|
||||||
pass, ok := s[user]
|
pass, ok := s[user]
|
||||||
if !ok {
|
if !ok {
|
||||||
|
@ -1,12 +1,14 @@
|
|||||||
package socks5
|
package socks5
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// socks const defined
|
||||||
const (
|
const (
|
||||||
// protocol version
|
// protocol version
|
||||||
VersionSocks4 = uint8(4)
|
VersionSocks4 = uint8(4)
|
||||||
@ -83,6 +85,7 @@ type Header struct {
|
|||||||
addrType uint8
|
addrType uint8
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Parse to header
|
||||||
func Parse(r io.Reader) (hd Header, err error) {
|
func Parse(r io.Reader) (hd Header, err error) {
|
||||||
// Read the version and command
|
// Read the version and command
|
||||||
tmp := make([]byte, headVERLen+headCMDLen)
|
tmp := make([]byte, headVERLen+headCMDLen)
|
||||||
@ -142,13 +145,15 @@ func Parse(r io.Reader) (hd Header, err error) {
|
|||||||
hd.Address.IP = addr[:net.IPv6len]
|
hd.Address.IP = addr[:net.IPv6len]
|
||||||
hd.Address.Port = buildPort(addr[net.IPv6len], addr[net.IPv6len+1])
|
hd.Address.Port = buildPort(addr[net.IPv6len], addr[net.IPv6len+1])
|
||||||
default:
|
default:
|
||||||
return hd, unrecognizedAddrType
|
return hd, errUnrecognizedAddrType
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return hd, nil
|
return hd, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Bytes returns a slice of header
|
||||||
func (h Header) Bytes() (b []byte) {
|
func (h Header) Bytes() (b []byte) {
|
||||||
|
bytes.Buffer{}.Bytes()
|
||||||
b = append(b, h.Version)
|
b = append(b, h.Version)
|
||||||
b = append(b, h.Command)
|
b = append(b, h.Command)
|
||||||
hiPort, loPort := breakPort(h.Address.Port)
|
hiPort, loPort := breakPort(h.Address.Port)
|
||||||
|
@ -4,19 +4,22 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Logger is used to provide debug logger
|
||||||
type Logger interface {
|
type Logger interface {
|
||||||
Errorf(format string, arg ...interface{})
|
Errorf(format string, arg ...interface{})
|
||||||
}
|
}
|
||||||
|
|
||||||
// 标准输出
|
// Std std logger
|
||||||
type Std struct {
|
type Std struct {
|
||||||
*log.Logger
|
*log.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewLogger new std logger with log.logger
|
||||||
func NewLogger(l *log.Logger) *Std {
|
func NewLogger(l *log.Logger) *Std {
|
||||||
return &Std{l}
|
return &Std{l}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Errorf implement interface Logger
|
||||||
func (sf Std) Errorf(format string, args ...interface{}) {
|
func (sf Std) Errorf(format string, args ...interface{}) {
|
||||||
sf.Logger.Printf("[E]: "+format, args...)
|
sf.Logger.Printf("[E]: "+format, args...)
|
||||||
}
|
}
|
||||||
|
40
option.go
40
option.go
@ -2,12 +2,14 @@ package socks5
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Option user's option
|
||||||
type Option func(s *Server)
|
type Option func(s *Server)
|
||||||
|
|
||||||
// AuthMethods can be provided to implement custom authentication
|
// WithAuthMethods can be provided to implement custom authentication
|
||||||
// By default, "auth-less" mode is enabled.
|
// By default, "auth-less" mode is enabled.
|
||||||
// For password-based auth use UserPassAuthenticator.
|
// For password-based auth use UserPassAuthenticator.
|
||||||
func WithAuthMethods(authMethods []Authenticator) Option {
|
func WithAuthMethods(authMethods []Authenticator) Option {
|
||||||
@ -19,7 +21,7 @@ func WithAuthMethods(authMethods []Authenticator) Option {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If provided, username/password authentication is enabled,
|
// WithCredential If provided, username/password authentication is enabled,
|
||||||
// by appending a UserPassAuthenticator to AuthMethods. If not provided,
|
// by appending a UserPassAuthenticator to AuthMethods. If not provided,
|
||||||
// and AUthMethods is nil, then "auth-less" mode is enabled.
|
// and AUthMethods is nil, then "auth-less" mode is enabled.
|
||||||
func WithCredential(cs CredentialStore) Option {
|
func WithCredential(cs CredentialStore) Option {
|
||||||
@ -30,7 +32,7 @@ func WithCredential(cs CredentialStore) Option {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// resolver can be provided to do custom name resolution.
|
// WithResolver can be provided to do custom name resolution.
|
||||||
// Defaults to DNSResolver if not provided.
|
// Defaults to DNSResolver if not provided.
|
||||||
func WithResolver(res NameResolver) Option {
|
func WithResolver(res NameResolver) Option {
|
||||||
return func(s *Server) {
|
return func(s *Server) {
|
||||||
@ -40,7 +42,7 @@ func WithResolver(res NameResolver) Option {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// rules is provided to enable custom logic around permitting
|
// WithRule is provided to enable custom logic around permitting
|
||||||
// various commands. If not provided, PermitAll is used.
|
// various commands. If not provided, PermitAll is used.
|
||||||
func WithRule(rule RuleSet) Option {
|
func WithRule(rule RuleSet) Option {
|
||||||
return func(s *Server) {
|
return func(s *Server) {
|
||||||
@ -50,7 +52,7 @@ func WithRule(rule RuleSet) Option {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// rewriter can be used to transparently rewrite addresses.
|
// WithRewriter can be used to transparently rewrite addresses.
|
||||||
// This is invoked before the RuleSet is invoked.
|
// This is invoked before the RuleSet is invoked.
|
||||||
// Defaults to NoRewrite.
|
// Defaults to NoRewrite.
|
||||||
func WithRewriter(rew AddressRewriter) Option {
|
func WithRewriter(rew AddressRewriter) Option {
|
||||||
@ -61,7 +63,7 @@ func WithRewriter(rew AddressRewriter) Option {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// bindIP is used for bind or udp associate
|
// WithBindIP is used for bind or udp associate
|
||||||
func WithBindIP(ip net.IP) Option {
|
func WithBindIP(ip net.IP) Option {
|
||||||
return func(s *Server) {
|
return func(s *Server) {
|
||||||
if len(ip) != 0 {
|
if len(ip) != 0 {
|
||||||
@ -71,7 +73,7 @@ func WithBindIP(ip net.IP) Option {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// logger can be used to provide a custom log target.
|
// WithLogger can be used to provide a custom log target.
|
||||||
// Defaults to ioutil.Discard.
|
// Defaults to ioutil.Discard.
|
||||||
func WithLogger(l Logger) Option {
|
func WithLogger(l Logger) Option {
|
||||||
return func(s *Server) {
|
return func(s *Server) {
|
||||||
@ -81,7 +83,7 @@ func WithLogger(l Logger) Option {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Optional function for dialing out
|
// WithDial Optional function for dialing out
|
||||||
func WithDial(dial func(ctx context.Context, network, addr string) (net.Conn, error)) Option {
|
func WithDial(dial func(ctx context.Context, network, addr string) (net.Conn, error)) Option {
|
||||||
return func(s *Server) {
|
return func(s *Server) {
|
||||||
if dial != nil {
|
if dial != nil {
|
||||||
@ -90,6 +92,7 @@ func WithDial(dial func(ctx context.Context, network, addr string) (net.Conn, er
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithGPool can be provided to do custom goroutine pool.
|
||||||
func WithGPool(pool GPool) Option {
|
func WithGPool(pool GPool) Option {
|
||||||
return func(s *Server) {
|
return func(s *Server) {
|
||||||
if pool != nil {
|
if pool != nil {
|
||||||
@ -97,3 +100,24 @@ func WithGPool(pool GPool) Option {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithConnectHandle is used to handle a user's connect command
|
||||||
|
func WithConnectHandle(h func(ctx context.Context, writer io.Writer, req *Request) error) Option {
|
||||||
|
return func(s *Server) {
|
||||||
|
s.userConnectHandle = h
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithBindHandle is used to handle a user's bind command
|
||||||
|
func WithBindHandle(h func(ctx context.Context, writer io.Writer, req *Request) error) Option {
|
||||||
|
return func(s *Server) {
|
||||||
|
s.userBindHandle = h
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithAssociateHandle is used to handle a user's associate command
|
||||||
|
func WithAssociateHandle(h func(ctx context.Context, writer io.Writer, req *Request) error) Option {
|
||||||
|
return func(s *Server) {
|
||||||
|
s.userAssociateHandle = h
|
||||||
|
}
|
||||||
|
}
|
||||||
|
102
request.go
102
request.go
@ -10,7 +10,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
unrecognizedAddrType = fmt.Errorf("Unrecognized address type")
|
errUnrecognizedAddrType = fmt.Errorf("Unrecognized address type")
|
||||||
)
|
)
|
||||||
|
|
||||||
// AddressRewriter is used to rewrite a destination transparently
|
// AddressRewriter is used to rewrite a destination transparently
|
||||||
@ -25,11 +25,12 @@ type Request struct {
|
|||||||
AuthContext *AuthContext
|
AuthContext *AuthContext
|
||||||
// AddrSpec of the the network that sent the request
|
// AddrSpec of the the network that sent the request
|
||||||
RemoteAddr *AddrSpec
|
RemoteAddr *AddrSpec
|
||||||
// AddrSpec of the desired destination
|
|
||||||
DestAddr *AddrSpec
|
|
||||||
// AddrSpec of the actual destination (might be affected by rewrite)
|
// AddrSpec of the actual destination (might be affected by rewrite)
|
||||||
realDestAddr *AddrSpec
|
DestAddr *AddrSpec
|
||||||
bufConn io.Reader
|
// Reader connect of request
|
||||||
|
Reader io.Reader
|
||||||
|
// AddrSpec of the desired destination
|
||||||
|
RawDestAddr *AddrSpec
|
||||||
}
|
}
|
||||||
|
|
||||||
type conn interface {
|
type conn interface {
|
||||||
@ -55,9 +56,9 @@ func NewRequest(bufConn io.Reader) (*Request, error) {
|
|||||||
return nil, fmt.Errorf("unrecognized command[%d]", hd.Command)
|
return nil, fmt.Errorf("unrecognized command[%d]", hd.Command)
|
||||||
}
|
}
|
||||||
return &Request{
|
return &Request{
|
||||||
Header: hd,
|
Header: hd,
|
||||||
DestAddr: &hd.Address,
|
RawDestAddr: &hd.Address,
|
||||||
bufConn: bufConn,
|
Reader: bufConn,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -66,45 +67,54 @@ func (s *Server) handleRequest(write io.Writer, req *Request) error {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
// Resolve the address if we have a FQDN
|
// Resolve the address if we have a FQDN
|
||||||
dest := req.DestAddr
|
dest := req.RawDestAddr
|
||||||
if dest.FQDN != "" {
|
if dest.FQDN != "" {
|
||||||
ctx_, addr, err := s.resolver.Resolve(ctx, dest.FQDN)
|
_ctx, addr, err := s.resolver.Resolve(ctx, dest.FQDN)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err := sendReply(write, req.Header, hostUnreachable); err != nil {
|
if err := SendReply(write, req.Header, hostUnreachable); err != nil {
|
||||||
return fmt.Errorf("failed to send reply, %v", err)
|
return fmt.Errorf("failed to send reply, %v", err)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("failed to resolve destination[%v], %v", dest.FQDN, err)
|
return fmt.Errorf("failed to resolve destination[%v], %v", dest.FQDN, err)
|
||||||
}
|
}
|
||||||
ctx = ctx_
|
ctx = _ctx
|
||||||
dest.IP = addr
|
dest.IP = addr
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply any address rewrites
|
// Apply any address rewrites
|
||||||
req.realDestAddr = req.DestAddr
|
req.DestAddr = req.RawDestAddr
|
||||||
if s.rewriter != nil {
|
if s.rewriter != nil {
|
||||||
ctx, req.realDestAddr = s.rewriter.Rewrite(ctx, req)
|
ctx, req.DestAddr = s.rewriter.Rewrite(ctx, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if this is allowed
|
// Check if this is allowed
|
||||||
if ctx_, ok := s.rules.Allow(ctx, req); !ok {
|
_ctx, ok := s.rules.Allow(ctx, req)
|
||||||
if err := sendReply(write, req.Header, ruleFailure); err != nil {
|
if !ok {
|
||||||
|
if err := SendReply(write, req.Header, ruleFailure); err != nil {
|
||||||
return fmt.Errorf("failed to send reply, %v", err)
|
return fmt.Errorf("failed to send reply, %v", err)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("bind to %v blocked by rules", req.DestAddr)
|
return fmt.Errorf("bind to %v blocked by rules", req.RawDestAddr)
|
||||||
} else {
|
|
||||||
ctx = ctx_
|
|
||||||
}
|
}
|
||||||
|
ctx = _ctx
|
||||||
|
|
||||||
// Switch on the command
|
// Switch on the command
|
||||||
switch req.Command {
|
switch req.Command {
|
||||||
case CommandConnect:
|
case CommandConnect:
|
||||||
|
if s.userConnectHandle != nil {
|
||||||
|
return s.userConnectHandle(ctx, write, req)
|
||||||
|
}
|
||||||
return s.handleConnect(ctx, write, req)
|
return s.handleConnect(ctx, write, req)
|
||||||
case CommandBind:
|
case CommandBind:
|
||||||
|
if s.userBindHandle != nil {
|
||||||
|
return s.userBindHandle(ctx, write, req)
|
||||||
|
}
|
||||||
return s.handleBind(ctx, write, req)
|
return s.handleBind(ctx, write, req)
|
||||||
case CommandAssociate:
|
case CommandAssociate:
|
||||||
|
if s.userAssociateHandle != nil {
|
||||||
|
return s.userAssociateHandle(ctx, write, req)
|
||||||
|
}
|
||||||
return s.handleAssociate(ctx, write, req)
|
return s.handleAssociate(ctx, write, req)
|
||||||
default:
|
default:
|
||||||
if err := sendReply(write, req.Header, commandNotSupported); err != nil {
|
if err := SendReply(write, req.Header, commandNotSupported); err != nil {
|
||||||
return fmt.Errorf("failed to send reply, %v", err)
|
return fmt.Errorf("failed to send reply, %v", err)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unsupported command[%v]", req.Command)
|
return fmt.Errorf("unsupported command[%v]", req.Command)
|
||||||
@ -120,7 +130,7 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, req *Reque
|
|||||||
return net.Dial(net_, addr)
|
return net.Dial(net_, addr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
target, err := dial(ctx, "tcp", req.realDestAddr.Address())
|
target, err := dial(ctx, "tcp", req.DestAddr.Address())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
msg := err.Error()
|
msg := err.Error()
|
||||||
resp := hostUnreachable
|
resp := hostUnreachable
|
||||||
@ -129,23 +139,23 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, req *Reque
|
|||||||
} else if strings.Contains(msg, "network is unreachable") {
|
} else if strings.Contains(msg, "network is unreachable") {
|
||||||
resp = networkUnreachable
|
resp = networkUnreachable
|
||||||
}
|
}
|
||||||
if err := sendReply(writer, req.Header, resp); err != nil {
|
if err := SendReply(writer, req.Header, resp); err != nil {
|
||||||
return fmt.Errorf("failed to send reply, %v", err)
|
return fmt.Errorf("failed to send reply, %v", err)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("connect to %v failed, %v", req.DestAddr, err)
|
return fmt.Errorf("connect to %v failed, %v", req.RawDestAddr, err)
|
||||||
}
|
}
|
||||||
defer target.Close()
|
defer target.Close()
|
||||||
|
|
||||||
// Send success
|
// Send success
|
||||||
if err := sendReply(writer, req.Header, successReply, target.LocalAddr()); err != nil {
|
if err := SendReply(writer, req.Header, successReply, target.LocalAddr()); err != nil {
|
||||||
return fmt.Errorf("failed to send reply, %v", err)
|
return fmt.Errorf("failed to send reply, %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start proxying
|
// Start proxying
|
||||||
errCh := make(chan error, 2)
|
errCh := make(chan error, 2)
|
||||||
|
|
||||||
s.submit(func() { errCh <- s.proxy(target, req.bufConn) })
|
s.submit(func() { errCh <- s.Proxy(target, req.Reader) })
|
||||||
s.submit(func() { errCh <- s.proxy(writer, target) })
|
s.submit(func() { errCh <- s.Proxy(writer, target) })
|
||||||
|
|
||||||
// Wait
|
// Wait
|
||||||
for i := 0; i < 2; i++ {
|
for i := 0; i < 2; i++ {
|
||||||
@ -161,7 +171,7 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, req *Reque
|
|||||||
// handleBind is used to handle a connect command
|
// handleBind is used to handle a connect command
|
||||||
func (s *Server) handleBind(ctx context.Context, writer io.Writer, req *Request) error {
|
func (s *Server) handleBind(ctx context.Context, writer io.Writer, req *Request) error {
|
||||||
// TODO: Support bind
|
// TODO: Support bind
|
||||||
if err := sendReply(writer, req.Header, commandNotSupported); err != nil {
|
if err := SendReply(writer, req.Header, commandNotSupported); err != nil {
|
||||||
return fmt.Errorf("failed to send reply: %v", err)
|
return fmt.Errorf("failed to send reply: %v", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@ -176,7 +186,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
|
|||||||
return net.Dial(net_, addr)
|
return net.Dial(net_, addr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
target, err := dial(ctx, "udp", req.realDestAddr.Address())
|
target, err := dial(ctx, "udp", req.DestAddr.Address())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
msg := err.Error()
|
msg := err.Error()
|
||||||
resp := hostUnreachable
|
resp := hostUnreachable
|
||||||
@ -185,16 +195,16 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
|
|||||||
} else if strings.Contains(msg, "network is unreachable") {
|
} else if strings.Contains(msg, "network is unreachable") {
|
||||||
resp = networkUnreachable
|
resp = networkUnreachable
|
||||||
}
|
}
|
||||||
if err := sendReply(writer, req.Header, resp); err != nil {
|
if err := SendReply(writer, req.Header, resp); err != nil {
|
||||||
return fmt.Errorf("failed to send reply, %v", err)
|
return fmt.Errorf("failed to send reply, %v", err)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("connect to %v failed, %v", req.DestAddr, err)
|
return fmt.Errorf("connect to %v failed, %v", req.RawDestAddr, err)
|
||||||
}
|
}
|
||||||
defer target.Close()
|
defer target.Close()
|
||||||
|
|
||||||
targetUdp, ok := target.(*net.UDPConn)
|
targetUDP, ok := target.(*net.UDPConn)
|
||||||
if !ok {
|
if !ok {
|
||||||
if err := sendReply(writer, req.Header, serverFailure); err != nil {
|
if err := SendReply(writer, req.Header, serverFailure); err != nil {
|
||||||
return fmt.Errorf("failed to send reply, %v", err)
|
return fmt.Errorf("failed to send reply, %v", err)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("dial udp invalid")
|
return fmt.Errorf("dial udp invalid")
|
||||||
@ -202,16 +212,16 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
|
|||||||
|
|
||||||
bindLn, err := net.ListenUDP("udp", nil)
|
bindLn, err := net.ListenUDP("udp", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err := sendReply(writer, req.Header, serverFailure); err != nil {
|
if err := SendReply(writer, req.Header, serverFailure); err != nil {
|
||||||
return fmt.Errorf("failed to send reply, %v", err)
|
return fmt.Errorf("failed to send reply, %v", err)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("listen udp failed, %v", err)
|
return fmt.Errorf("listen udp failed, %v", err)
|
||||||
}
|
}
|
||||||
defer bindLn.Close()
|
defer bindLn.Close()
|
||||||
|
|
||||||
s.logger.Errorf("target addr %v, listen addr: %s", targetUdp.RemoteAddr(), bindLn.LocalAddr())
|
s.logger.Errorf("target addr %v, listen addr: %s", targetUDP.RemoteAddr(), bindLn.LocalAddr())
|
||||||
// send BND.ADDR and BND.PORT, client must
|
// send BND.ADDR and BND.PORT, client must
|
||||||
if err = sendReply(writer, req.Header, successReply, bindLn.LocalAddr()); err != nil {
|
if err = SendReply(writer, req.Header, successReply, bindLn.LocalAddr()); err != nil {
|
||||||
return fmt.Errorf("failed to send reply, %v", err)
|
return fmt.Errorf("failed to send reply, %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -228,7 +238,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
|
|||||||
conns := sync.Map{}
|
conns := sync.Map{}
|
||||||
bufPool := s.bufferPool.Get()
|
bufPool := s.bufferPool.Get()
|
||||||
defer func() {
|
defer func() {
|
||||||
targetUdp.Close()
|
targetUDP.Close()
|
||||||
bindLn.Close()
|
bindLn.Close()
|
||||||
s.bufferPool.Put(bufPool)
|
s.bufferPool.Put(bufPool)
|
||||||
}()
|
}()
|
||||||
@ -278,8 +288,8 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 把消息写给remote sever
|
// 把消息写给remote sever
|
||||||
if _, err := targetUdp.Write(buf[headLen:n]); err != nil {
|
if _, err := targetUDP.Write(buf[headLen:n]); err != nil {
|
||||||
s.logger.Errorf("write data to remote %s failed, %v", targetUdp.RemoteAddr(), err)
|
s.logger.Errorf("write data to remote %s failed, %v", targetUDP.RemoteAddr(), err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -288,16 +298,16 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
|
|||||||
// read from remote server and write to client
|
// read from remote server and write to client
|
||||||
bufPool := s.bufferPool.Get()
|
bufPool := s.bufferPool.Get()
|
||||||
defer func() {
|
defer func() {
|
||||||
targetUdp.Close()
|
targetUDP.Close()
|
||||||
bindLn.Close()
|
bindLn.Close()
|
||||||
s.bufferPool.Put(bufPool)
|
s.bufferPool.Put(bufPool)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
buf := bufPool[:cap(bufPool)]
|
buf := bufPool[:cap(bufPool)]
|
||||||
n, remote, err := targetUdp.ReadFrom(buf)
|
n, remote, err := targetUDP.ReadFrom(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Errorf("read data from remote %s failed, %v", targetUdp.RemoteAddr(), err)
|
s.logger.Errorf("read data from remote %s failed, %v", targetUDP.RemoteAddr(), err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -334,15 +344,15 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
|
|||||||
s.bufferPool.Put(buf)
|
s.bufferPool.Put(buf)
|
||||||
}()
|
}()
|
||||||
for {
|
for {
|
||||||
_, err := req.bufConn.Read(buf[:cap(buf)])
|
_, err := req.Reader.Read(buf[:cap(buf)])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// sendReply is used to send a reply message
|
// SendReply is used to send a reply message
|
||||||
func sendReply(w io.Writer, head Header, resp uint8, bindAddr ...net.Addr) error {
|
func SendReply(w io.Writer, head Header, resp uint8, bindAddr ...net.Addr) error {
|
||||||
/*
|
/*
|
||||||
The SOCKS response is formed as follows:
|
The SOCKS response is formed as follows:
|
||||||
+----+-----+-------+------+----------+----------+
|
+----+-----+-------+------+----------+----------+
|
||||||
@ -397,9 +407,9 @@ type closeWriter interface {
|
|||||||
CloseWrite() error
|
CloseWrite() error
|
||||||
}
|
}
|
||||||
|
|
||||||
// proxy is used to suffle data from src to destination, and sends errors
|
// Proxy is used to suffle data from src to destination, and sends errors
|
||||||
// down a dedicated channel
|
// down a dedicated channel
|
||||||
func (s *Server) proxy(dst io.Writer, src io.Reader) error {
|
func (s *Server) Proxy(dst io.Writer, src io.Reader) error {
|
||||||
buf := s.bufferPool.Get()
|
buf := s.bufferPool.Get()
|
||||||
defer s.bufferPool.Put(buf)
|
defer s.bufferPool.Put(buf)
|
||||||
_, err := io.CopyBuffer(dst, src, buf[:cap(buf)])
|
_, err := io.CopyBuffer(dst, src, buf[:cap(buf)])
|
||||||
|
@ -13,6 +13,7 @@ type NameResolver interface {
|
|||||||
// DNSResolver uses the system DNS to resolve host names
|
// DNSResolver uses the system DNS to resolve host names
|
||||||
type DNSResolver struct{}
|
type DNSResolver struct{}
|
||||||
|
|
||||||
|
// Resolve implement interface NameResolver
|
||||||
func (d DNSResolver) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) {
|
func (d DNSResolver) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) {
|
||||||
addr, err := net.ResolveIPAddr("ip", name)
|
addr, err := net.ResolveIPAddr("ip", name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -27,6 +27,7 @@ type PermitCommand struct {
|
|||||||
EnableAssociate bool
|
EnableAssociate bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Allow implement interface RuleSet
|
||||||
func (p *PermitCommand) Allow(ctx context.Context, req *Request) (context.Context, bool) {
|
func (p *PermitCommand) Allow(ctx context.Context, req *Request) (context.Context, bool) {
|
||||||
switch req.Command {
|
switch req.Command {
|
||||||
case CommandConnect:
|
case CommandConnect:
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"bufio"
|
"bufio"
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
@ -47,6 +48,10 @@ type Server struct {
|
|||||||
bufferPool *pool
|
bufferPool *pool
|
||||||
// goroutine pool
|
// goroutine pool
|
||||||
gPool GPool
|
gPool GPool
|
||||||
|
// user's handle
|
||||||
|
userConnectHandle func(ctx context.Context, writer io.Writer, req *Request) error
|
||||||
|
userBindHandle func(ctx context.Context, writer io.Writer, req *Request) error
|
||||||
|
userAssociateHandle func(ctx context.Context, writer io.Writer, req *Request) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new Server and potentially returns an error
|
// New creates a new Server and potentially returns an error
|
||||||
@ -131,8 +136,8 @@ func (s *Server) ServeConn(conn net.Conn) (err error) {
|
|||||||
|
|
||||||
request, err := NewRequest(bufConn)
|
request, err := NewRequest(bufConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == unrecognizedAddrType {
|
if err == errUnrecognizedAddrType {
|
||||||
if err := sendReply(conn, Header{Version: version[0]}, addrTypeNotSupported); err != nil {
|
if err := SendReply(conn, Header{Version: version[0]}, addrTypeNotSupported); err != nil {
|
||||||
return fmt.Errorf("failed to send reply, %v", err)
|
return fmt.Errorf("failed to send reply, %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user