add user handle

revive inspect
This commit is contained in:
mo 2020-04-22 10:15:40 +08:00
parent 8b383556a2
commit f77e659826
10 changed files with 122 additions and 65 deletions

17
auth.go
View File

@ -5,6 +5,7 @@ import (
"io"
)
// auth defined
const (
MethodNoAuth = uint8(0)
MethodGSSAPI = uint8(1)
@ -15,12 +16,13 @@ const (
AuthFailure = uint8(1)
)
// auth error defined
var (
UserAuthFailed = fmt.Errorf("User authentication failed")
NoSupportedAuth = fmt.Errorf("No supported authentication mechanism")
ErrUserAuthFailed = fmt.Errorf("User authentication failed")
ErrNoSupportedAuth = fmt.Errorf("No supported authentication mechanism")
)
// A Request encapsulates authentication state provided
// AuthContext A Request encapsulates authentication state provided
// during negotiation
type AuthContext struct {
// Provided auth method
@ -31,6 +33,7 @@ type AuthContext struct {
Payload map[string]string
}
// Authenticator provide auth
type Authenticator interface {
Authenticate(reader io.Reader, writer io.Writer, userAddr string) (*AuthContext, error)
GetCode() uint8
@ -39,10 +42,12 @@ type Authenticator interface {
// NoAuthAuthenticator is used to handle the "No Authentication" mode
type NoAuthAuthenticator struct{}
// GetCode implement interface Authenticator
func (a NoAuthAuthenticator) GetCode() uint8 {
return MethodNoAuth
}
// Authenticate implement interface Authenticator
func (a NoAuthAuthenticator) Authenticate(reader io.Reader, writer io.Writer, userAddr string) (*AuthContext, error) {
_, err := writer.Write([]byte{VersionSocks5, MethodNoAuth})
return &AuthContext{MethodNoAuth, nil}, err
@ -54,10 +59,12 @@ type UserPassAuthenticator struct {
Credentials CredentialStore
}
// GetCode implement interface Authenticator
func (a UserPassAuthenticator) GetCode() uint8 {
return MethodUserPassAuth
}
// Authenticate implement interface Authenticator
func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer, userAddr string) (*AuthContext, error) {
// Tell the client to use user/pass auth
if _, err := writer.Write([]byte{VersionSocks5, MethodUserPassAuth}); err != nil {
@ -103,7 +110,7 @@ func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer,
if _, err := writer.Write([]byte{UserPassAuthVersion, AuthFailure}); err != nil {
return nil, err
}
return nil, UserAuthFailed
return nil, ErrUserAuthFailed
}
// Done
@ -134,7 +141,7 @@ func (s *Server) authenticate(conn io.Writer, bufConn io.Reader, userAddr string
// authentication mechanism
func noAcceptableAuth(conn io.Writer) error {
conn.Write([]byte{VersionSocks5, MethodNoAcceptable})
return NoSupportedAuth
return ErrNoSupportedAuth
}
// readMethods is used to read the number of methods

View File

@ -77,7 +77,7 @@ func TestPasswordAuth_Invalid(t *testing.T) {
s := New(WithAuthMethods([]Authenticator{cator}))
ctx, err := s.authenticate(&resp, req, "")
if err != UserAuthFailed {
if err != ErrUserAuthFailed {
t.Fatalf("err: %v", err)
}
@ -104,7 +104,7 @@ func TestNoSupportedAuth(t *testing.T) {
s := New(WithAuthMethods([]Authenticator{cator}))
ctx, err := s.authenticate(&resp, req, "")
if err != NoSupportedAuth {
if err != ErrNoSupportedAuth {
t.Fatalf("err: %v", err)
}

View File

@ -9,6 +9,7 @@ type CredentialStore interface {
// StaticCredentials enables using a map directly as a credential store
type StaticCredentials map[string]string
// Valid implement interface CredentialStore
func (s StaticCredentials) Valid(user, password, userAddr string) bool {
pass, ok := s[user]
if !ok {

View File

@ -1,12 +1,14 @@
package socks5
import (
"bytes"
"fmt"
"io"
"net"
"strconv"
)
// socks const defined
const (
// protocol version
VersionSocks4 = uint8(4)
@ -83,6 +85,7 @@ type Header struct {
addrType uint8
}
// Parse to header
func Parse(r io.Reader) (hd Header, err error) {
// Read the version and command
tmp := make([]byte, headVERLen+headCMDLen)
@ -142,13 +145,15 @@ func Parse(r io.Reader) (hd Header, err error) {
hd.Address.IP = addr[:net.IPv6len]
hd.Address.Port = buildPort(addr[net.IPv6len], addr[net.IPv6len+1])
default:
return hd, unrecognizedAddrType
return hd, errUnrecognizedAddrType
}
}
return hd, nil
}
// Bytes returns a slice of header
func (h Header) Bytes() (b []byte) {
bytes.Buffer{}.Bytes()
b = append(b, h.Version)
b = append(b, h.Command)
hiPort, loPort := breakPort(h.Address.Port)

View File

@ -4,19 +4,22 @@ import (
"log"
)
// Logger is used to provide debug logger
type Logger interface {
Errorf(format string, arg ...interface{})
}
// 标准输出
// Std std logger
type Std struct {
*log.Logger
}
// NewLogger new std logger with log.logger
func NewLogger(l *log.Logger) *Std {
return &Std{l}
}
// Errorf implement interface Logger
func (sf Std) Errorf(format string, args ...interface{}) {
sf.Logger.Printf("[E]: "+format, args...)
}

View File

@ -2,12 +2,14 @@ package socks5
import (
"context"
"io"
"net"
)
// Option user's option
type Option func(s *Server)
// AuthMethods can be provided to implement custom authentication
// WithAuthMethods can be provided to implement custom authentication
// By default, "auth-less" mode is enabled.
// For password-based auth use UserPassAuthenticator.
func WithAuthMethods(authMethods []Authenticator) Option {
@ -19,7 +21,7 @@ func WithAuthMethods(authMethods []Authenticator) Option {
}
}
// If provided, username/password authentication is enabled,
// WithCredential If provided, username/password authentication is enabled,
// by appending a UserPassAuthenticator to AuthMethods. If not provided,
// and AUthMethods is nil, then "auth-less" mode is enabled.
func WithCredential(cs CredentialStore) Option {
@ -30,7 +32,7 @@ func WithCredential(cs CredentialStore) Option {
}
}
// resolver can be provided to do custom name resolution.
// WithResolver can be provided to do custom name resolution.
// Defaults to DNSResolver if not provided.
func WithResolver(res NameResolver) Option {
return func(s *Server) {
@ -40,7 +42,7 @@ func WithResolver(res NameResolver) Option {
}
}
// rules is provided to enable custom logic around permitting
// WithRule is provided to enable custom logic around permitting
// various commands. If not provided, PermitAll is used.
func WithRule(rule RuleSet) Option {
return func(s *Server) {
@ -50,7 +52,7 @@ func WithRule(rule RuleSet) Option {
}
}
// rewriter can be used to transparently rewrite addresses.
// WithRewriter can be used to transparently rewrite addresses.
// This is invoked before the RuleSet is invoked.
// Defaults to NoRewrite.
func WithRewriter(rew AddressRewriter) Option {
@ -61,7 +63,7 @@ func WithRewriter(rew AddressRewriter) Option {
}
}
// bindIP is used for bind or udp associate
// WithBindIP is used for bind or udp associate
func WithBindIP(ip net.IP) Option {
return func(s *Server) {
if len(ip) != 0 {
@ -71,7 +73,7 @@ func WithBindIP(ip net.IP) Option {
}
}
// logger can be used to provide a custom log target.
// WithLogger can be used to provide a custom log target.
// Defaults to ioutil.Discard.
func WithLogger(l Logger) Option {
return func(s *Server) {
@ -81,7 +83,7 @@ func WithLogger(l Logger) Option {
}
}
// Optional function for dialing out
// WithDial Optional function for dialing out
func WithDial(dial func(ctx context.Context, network, addr string) (net.Conn, error)) Option {
return func(s *Server) {
if dial != nil {
@ -90,6 +92,7 @@ func WithDial(dial func(ctx context.Context, network, addr string) (net.Conn, er
}
}
// WithGPool can be provided to do custom goroutine pool.
func WithGPool(pool GPool) Option {
return func(s *Server) {
if pool != nil {
@ -97,3 +100,24 @@ func WithGPool(pool GPool) Option {
}
}
}
// WithConnectHandle is used to handle a user's connect command
func WithConnectHandle(h func(ctx context.Context, writer io.Writer, req *Request) error) Option {
return func(s *Server) {
s.userConnectHandle = h
}
}
// WithBindHandle is used to handle a user's bind command
func WithBindHandle(h func(ctx context.Context, writer io.Writer, req *Request) error) Option {
return func(s *Server) {
s.userBindHandle = h
}
}
// WithAssociateHandle is used to handle a user's associate command
func WithAssociateHandle(h func(ctx context.Context, writer io.Writer, req *Request) error) Option {
return func(s *Server) {
s.userAssociateHandle = h
}
}

View File

@ -10,7 +10,7 @@ import (
)
var (
unrecognizedAddrType = fmt.Errorf("Unrecognized address type")
errUnrecognizedAddrType = fmt.Errorf("Unrecognized address type")
)
// AddressRewriter is used to rewrite a destination transparently
@ -25,11 +25,12 @@ type Request struct {
AuthContext *AuthContext
// AddrSpec of the the network that sent the request
RemoteAddr *AddrSpec
// AddrSpec of the desired destination
DestAddr *AddrSpec
// AddrSpec of the actual destination (might be affected by rewrite)
realDestAddr *AddrSpec
bufConn io.Reader
DestAddr *AddrSpec
// Reader connect of request
Reader io.Reader
// AddrSpec of the desired destination
RawDestAddr *AddrSpec
}
type conn interface {
@ -55,9 +56,9 @@ func NewRequest(bufConn io.Reader) (*Request, error) {
return nil, fmt.Errorf("unrecognized command[%d]", hd.Command)
}
return &Request{
Header: hd,
DestAddr: &hd.Address,
bufConn: bufConn,
Header: hd,
RawDestAddr: &hd.Address,
Reader: bufConn,
}, nil
}
@ -66,45 +67,54 @@ func (s *Server) handleRequest(write io.Writer, req *Request) error {
ctx := context.Background()
// Resolve the address if we have a FQDN
dest := req.DestAddr
dest := req.RawDestAddr
if dest.FQDN != "" {
ctx_, addr, err := s.resolver.Resolve(ctx, dest.FQDN)
_ctx, addr, err := s.resolver.Resolve(ctx, dest.FQDN)
if err != nil {
if err := sendReply(write, req.Header, hostUnreachable); err != nil {
if err := SendReply(write, req.Header, hostUnreachable); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
return fmt.Errorf("failed to resolve destination[%v], %v", dest.FQDN, err)
}
ctx = ctx_
ctx = _ctx
dest.IP = addr
}
// Apply any address rewrites
req.realDestAddr = req.DestAddr
req.DestAddr = req.RawDestAddr
if s.rewriter != nil {
ctx, req.realDestAddr = s.rewriter.Rewrite(ctx, req)
ctx, req.DestAddr = s.rewriter.Rewrite(ctx, req)
}
// Check if this is allowed
if ctx_, ok := s.rules.Allow(ctx, req); !ok {
if err := sendReply(write, req.Header, ruleFailure); err != nil {
_ctx, ok := s.rules.Allow(ctx, req)
if !ok {
if err := SendReply(write, req.Header, ruleFailure); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
return fmt.Errorf("bind to %v blocked by rules", req.DestAddr)
} else {
ctx = ctx_
return fmt.Errorf("bind to %v blocked by rules", req.RawDestAddr)
}
ctx = _ctx
// Switch on the command
switch req.Command {
case CommandConnect:
if s.userConnectHandle != nil {
return s.userConnectHandle(ctx, write, req)
}
return s.handleConnect(ctx, write, req)
case CommandBind:
if s.userBindHandle != nil {
return s.userBindHandle(ctx, write, req)
}
return s.handleBind(ctx, write, req)
case CommandAssociate:
if s.userAssociateHandle != nil {
return s.userAssociateHandle(ctx, write, req)
}
return s.handleAssociate(ctx, write, req)
default:
if err := sendReply(write, req.Header, commandNotSupported); err != nil {
if err := SendReply(write, req.Header, commandNotSupported); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
return fmt.Errorf("unsupported command[%v]", req.Command)
@ -120,7 +130,7 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, req *Reque
return net.Dial(net_, addr)
}
}
target, err := dial(ctx, "tcp", req.realDestAddr.Address())
target, err := dial(ctx, "tcp", req.DestAddr.Address())
if err != nil {
msg := err.Error()
resp := hostUnreachable
@ -129,23 +139,23 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, req *Reque
} else if strings.Contains(msg, "network is unreachable") {
resp = networkUnreachable
}
if err := sendReply(writer, req.Header, resp); err != nil {
if err := SendReply(writer, req.Header, resp); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
return fmt.Errorf("connect to %v failed, %v", req.DestAddr, err)
return fmt.Errorf("connect to %v failed, %v", req.RawDestAddr, err)
}
defer target.Close()
// Send success
if err := sendReply(writer, req.Header, successReply, target.LocalAddr()); err != nil {
if err := SendReply(writer, req.Header, successReply, target.LocalAddr()); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
// Start proxying
errCh := make(chan error, 2)
s.submit(func() { errCh <- s.proxy(target, req.bufConn) })
s.submit(func() { errCh <- s.proxy(writer, target) })
s.submit(func() { errCh <- s.Proxy(target, req.Reader) })
s.submit(func() { errCh <- s.Proxy(writer, target) })
// Wait
for i := 0; i < 2; i++ {
@ -161,7 +171,7 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, req *Reque
// handleBind is used to handle a connect command
func (s *Server) handleBind(ctx context.Context, writer io.Writer, req *Request) error {
// TODO: Support bind
if err := sendReply(writer, req.Header, commandNotSupported); err != nil {
if err := SendReply(writer, req.Header, commandNotSupported); err != nil {
return fmt.Errorf("failed to send reply: %v", err)
}
return nil
@ -176,7 +186,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
return net.Dial(net_, addr)
}
}
target, err := dial(ctx, "udp", req.realDestAddr.Address())
target, err := dial(ctx, "udp", req.DestAddr.Address())
if err != nil {
msg := err.Error()
resp := hostUnreachable
@ -185,16 +195,16 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
} else if strings.Contains(msg, "network is unreachable") {
resp = networkUnreachable
}
if err := sendReply(writer, req.Header, resp); err != nil {
if err := SendReply(writer, req.Header, resp); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
return fmt.Errorf("connect to %v failed, %v", req.DestAddr, err)
return fmt.Errorf("connect to %v failed, %v", req.RawDestAddr, err)
}
defer target.Close()
targetUdp, ok := target.(*net.UDPConn)
targetUDP, ok := target.(*net.UDPConn)
if !ok {
if err := sendReply(writer, req.Header, serverFailure); err != nil {
if err := SendReply(writer, req.Header, serverFailure); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
return fmt.Errorf("dial udp invalid")
@ -202,16 +212,16 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
bindLn, err := net.ListenUDP("udp", nil)
if err != nil {
if err := sendReply(writer, req.Header, serverFailure); err != nil {
if err := SendReply(writer, req.Header, serverFailure); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
return fmt.Errorf("listen udp failed, %v", err)
}
defer bindLn.Close()
s.logger.Errorf("target addr %v, listen addr: %s", targetUdp.RemoteAddr(), bindLn.LocalAddr())
s.logger.Errorf("target addr %v, listen addr: %s", targetUDP.RemoteAddr(), bindLn.LocalAddr())
// send BND.ADDR and BND.PORT, client must
if err = sendReply(writer, req.Header, successReply, bindLn.LocalAddr()); err != nil {
if err = SendReply(writer, req.Header, successReply, bindLn.LocalAddr()); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
@ -228,7 +238,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
conns := sync.Map{}
bufPool := s.bufferPool.Get()
defer func() {
targetUdp.Close()
targetUDP.Close()
bindLn.Close()
s.bufferPool.Put(bufPool)
}()
@ -278,8 +288,8 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
}
// 把消息写给remote sever
if _, err := targetUdp.Write(buf[headLen:n]); err != nil {
s.logger.Errorf("write data to remote %s failed, %v", targetUdp.RemoteAddr(), err)
if _, err := targetUDP.Write(buf[headLen:n]); err != nil {
s.logger.Errorf("write data to remote %s failed, %v", targetUDP.RemoteAddr(), err)
return
}
@ -288,16 +298,16 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
// read from remote server and write to client
bufPool := s.bufferPool.Get()
defer func() {
targetUdp.Close()
targetUDP.Close()
bindLn.Close()
s.bufferPool.Put(bufPool)
}()
for {
buf := bufPool[:cap(bufPool)]
n, remote, err := targetUdp.ReadFrom(buf)
n, remote, err := targetUDP.ReadFrom(buf)
if err != nil {
s.logger.Errorf("read data from remote %s failed, %v", targetUdp.RemoteAddr(), err)
s.logger.Errorf("read data from remote %s failed, %v", targetUDP.RemoteAddr(), err)
return
}
@ -334,15 +344,15 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
s.bufferPool.Put(buf)
}()
for {
_, err := req.bufConn.Read(buf[:cap(buf)])
_, err := req.Reader.Read(buf[:cap(buf)])
if err != nil {
return err
}
}
}
// sendReply is used to send a reply message
func sendReply(w io.Writer, head Header, resp uint8, bindAddr ...net.Addr) error {
// SendReply is used to send a reply message
func SendReply(w io.Writer, head Header, resp uint8, bindAddr ...net.Addr) error {
/*
The SOCKS response is formed as follows:
+----+-----+-------+------+----------+----------+
@ -397,9 +407,9 @@ type closeWriter interface {
CloseWrite() error
}
// proxy is used to suffle data from src to destination, and sends errors
// Proxy is used to suffle data from src to destination, and sends errors
// down a dedicated channel
func (s *Server) proxy(dst io.Writer, src io.Reader) error {
func (s *Server) Proxy(dst io.Writer, src io.Reader) error {
buf := s.bufferPool.Get()
defer s.bufferPool.Put(buf)
_, err := io.CopyBuffer(dst, src, buf[:cap(buf)])

View File

@ -13,6 +13,7 @@ type NameResolver interface {
// DNSResolver uses the system DNS to resolve host names
type DNSResolver struct{}
// Resolve implement interface NameResolver
func (d DNSResolver) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) {
addr, err := net.ResolveIPAddr("ip", name)
if err != nil {

View File

@ -27,6 +27,7 @@ type PermitCommand struct {
EnableAssociate bool
}
// Allow implement interface RuleSet
func (p *PermitCommand) Allow(ctx context.Context, req *Request) (context.Context, bool) {
switch req.Command {
case CommandConnect:

View File

@ -4,6 +4,7 @@ import (
"bufio"
"context"
"fmt"
"io"
"io/ioutil"
"log"
"net"
@ -47,6 +48,10 @@ type Server struct {
bufferPool *pool
// goroutine pool
gPool GPool
// user's handle
userConnectHandle func(ctx context.Context, writer io.Writer, req *Request) error
userBindHandle func(ctx context.Context, writer io.Writer, req *Request) error
userAssociateHandle func(ctx context.Context, writer io.Writer, req *Request) error
}
// New creates a new Server and potentially returns an error
@ -131,8 +136,8 @@ func (s *Server) ServeConn(conn net.Conn) (err error) {
request, err := NewRequest(bufConn)
if err != nil {
if err == unrecognizedAddrType {
if err := sendReply(conn, Header{Version: version[0]}, addrTypeNotSupported); err != nil {
if err == errUnrecognizedAddrType {
if err := SendReply(conn, Header{Version: version[0]}, addrTypeNotSupported); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
}