add user handle
revive inspect
This commit is contained in:
parent
8b383556a2
commit
f77e659826
17
auth.go
17
auth.go
@ -5,6 +5,7 @@ import (
|
||||
"io"
|
||||
)
|
||||
|
||||
// auth defined
|
||||
const (
|
||||
MethodNoAuth = uint8(0)
|
||||
MethodGSSAPI = uint8(1)
|
||||
@ -15,12 +16,13 @@ const (
|
||||
AuthFailure = uint8(1)
|
||||
)
|
||||
|
||||
// auth error defined
|
||||
var (
|
||||
UserAuthFailed = fmt.Errorf("User authentication failed")
|
||||
NoSupportedAuth = fmt.Errorf("No supported authentication mechanism")
|
||||
ErrUserAuthFailed = fmt.Errorf("User authentication failed")
|
||||
ErrNoSupportedAuth = fmt.Errorf("No supported authentication mechanism")
|
||||
)
|
||||
|
||||
// A Request encapsulates authentication state provided
|
||||
// AuthContext A Request encapsulates authentication state provided
|
||||
// during negotiation
|
||||
type AuthContext struct {
|
||||
// Provided auth method
|
||||
@ -31,6 +33,7 @@ type AuthContext struct {
|
||||
Payload map[string]string
|
||||
}
|
||||
|
||||
// Authenticator provide auth
|
||||
type Authenticator interface {
|
||||
Authenticate(reader io.Reader, writer io.Writer, userAddr string) (*AuthContext, error)
|
||||
GetCode() uint8
|
||||
@ -39,10 +42,12 @@ type Authenticator interface {
|
||||
// NoAuthAuthenticator is used to handle the "No Authentication" mode
|
||||
type NoAuthAuthenticator struct{}
|
||||
|
||||
// GetCode implement interface Authenticator
|
||||
func (a NoAuthAuthenticator) GetCode() uint8 {
|
||||
return MethodNoAuth
|
||||
}
|
||||
|
||||
// Authenticate implement interface Authenticator
|
||||
func (a NoAuthAuthenticator) Authenticate(reader io.Reader, writer io.Writer, userAddr string) (*AuthContext, error) {
|
||||
_, err := writer.Write([]byte{VersionSocks5, MethodNoAuth})
|
||||
return &AuthContext{MethodNoAuth, nil}, err
|
||||
@ -54,10 +59,12 @@ type UserPassAuthenticator struct {
|
||||
Credentials CredentialStore
|
||||
}
|
||||
|
||||
// GetCode implement interface Authenticator
|
||||
func (a UserPassAuthenticator) GetCode() uint8 {
|
||||
return MethodUserPassAuth
|
||||
}
|
||||
|
||||
// Authenticate implement interface Authenticator
|
||||
func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer, userAddr string) (*AuthContext, error) {
|
||||
// Tell the client to use user/pass auth
|
||||
if _, err := writer.Write([]byte{VersionSocks5, MethodUserPassAuth}); err != nil {
|
||||
@ -103,7 +110,7 @@ func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer,
|
||||
if _, err := writer.Write([]byte{UserPassAuthVersion, AuthFailure}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, UserAuthFailed
|
||||
return nil, ErrUserAuthFailed
|
||||
}
|
||||
|
||||
// Done
|
||||
@ -134,7 +141,7 @@ func (s *Server) authenticate(conn io.Writer, bufConn io.Reader, userAddr string
|
||||
// authentication mechanism
|
||||
func noAcceptableAuth(conn io.Writer) error {
|
||||
conn.Write([]byte{VersionSocks5, MethodNoAcceptable})
|
||||
return NoSupportedAuth
|
||||
return ErrNoSupportedAuth
|
||||
}
|
||||
|
||||
// readMethods is used to read the number of methods
|
||||
|
@ -77,7 +77,7 @@ func TestPasswordAuth_Invalid(t *testing.T) {
|
||||
s := New(WithAuthMethods([]Authenticator{cator}))
|
||||
|
||||
ctx, err := s.authenticate(&resp, req, "")
|
||||
if err != UserAuthFailed {
|
||||
if err != ErrUserAuthFailed {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
@ -104,7 +104,7 @@ func TestNoSupportedAuth(t *testing.T) {
|
||||
s := New(WithAuthMethods([]Authenticator{cator}))
|
||||
|
||||
ctx, err := s.authenticate(&resp, req, "")
|
||||
if err != NoSupportedAuth {
|
||||
if err != ErrNoSupportedAuth {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
|
@ -9,6 +9,7 @@ type CredentialStore interface {
|
||||
// StaticCredentials enables using a map directly as a credential store
|
||||
type StaticCredentials map[string]string
|
||||
|
||||
// Valid implement interface CredentialStore
|
||||
func (s StaticCredentials) Valid(user, password, userAddr string) bool {
|
||||
pass, ok := s[user]
|
||||
if !ok {
|
||||
|
@ -1,12 +1,14 @@
|
||||
package socks5
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// socks const defined
|
||||
const (
|
||||
// protocol version
|
||||
VersionSocks4 = uint8(4)
|
||||
@ -83,6 +85,7 @@ type Header struct {
|
||||
addrType uint8
|
||||
}
|
||||
|
||||
// Parse to header
|
||||
func Parse(r io.Reader) (hd Header, err error) {
|
||||
// Read the version and command
|
||||
tmp := make([]byte, headVERLen+headCMDLen)
|
||||
@ -142,13 +145,15 @@ func Parse(r io.Reader) (hd Header, err error) {
|
||||
hd.Address.IP = addr[:net.IPv6len]
|
||||
hd.Address.Port = buildPort(addr[net.IPv6len], addr[net.IPv6len+1])
|
||||
default:
|
||||
return hd, unrecognizedAddrType
|
||||
return hd, errUnrecognizedAddrType
|
||||
}
|
||||
}
|
||||
return hd, nil
|
||||
}
|
||||
|
||||
// Bytes returns a slice of header
|
||||
func (h Header) Bytes() (b []byte) {
|
||||
bytes.Buffer{}.Bytes()
|
||||
b = append(b, h.Version)
|
||||
b = append(b, h.Command)
|
||||
hiPort, loPort := breakPort(h.Address.Port)
|
||||
|
@ -4,19 +4,22 @@ import (
|
||||
"log"
|
||||
)
|
||||
|
||||
// Logger is used to provide debug logger
|
||||
type Logger interface {
|
||||
Errorf(format string, arg ...interface{})
|
||||
}
|
||||
|
||||
// 标准输出
|
||||
// Std std logger
|
||||
type Std struct {
|
||||
*log.Logger
|
||||
}
|
||||
|
||||
// NewLogger new std logger with log.logger
|
||||
func NewLogger(l *log.Logger) *Std {
|
||||
return &Std{l}
|
||||
}
|
||||
|
||||
// Errorf implement interface Logger
|
||||
func (sf Std) Errorf(format string, args ...interface{}) {
|
||||
sf.Logger.Printf("[E]: "+format, args...)
|
||||
}
|
||||
|
40
option.go
40
option.go
@ -2,12 +2,14 @@ package socks5
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
// Option user's option
|
||||
type Option func(s *Server)
|
||||
|
||||
// AuthMethods can be provided to implement custom authentication
|
||||
// WithAuthMethods can be provided to implement custom authentication
|
||||
// By default, "auth-less" mode is enabled.
|
||||
// For password-based auth use UserPassAuthenticator.
|
||||
func WithAuthMethods(authMethods []Authenticator) Option {
|
||||
@ -19,7 +21,7 @@ func WithAuthMethods(authMethods []Authenticator) Option {
|
||||
}
|
||||
}
|
||||
|
||||
// If provided, username/password authentication is enabled,
|
||||
// WithCredential If provided, username/password authentication is enabled,
|
||||
// by appending a UserPassAuthenticator to AuthMethods. If not provided,
|
||||
// and AUthMethods is nil, then "auth-less" mode is enabled.
|
||||
func WithCredential(cs CredentialStore) Option {
|
||||
@ -30,7 +32,7 @@ func WithCredential(cs CredentialStore) Option {
|
||||
}
|
||||
}
|
||||
|
||||
// resolver can be provided to do custom name resolution.
|
||||
// WithResolver can be provided to do custom name resolution.
|
||||
// Defaults to DNSResolver if not provided.
|
||||
func WithResolver(res NameResolver) Option {
|
||||
return func(s *Server) {
|
||||
@ -40,7 +42,7 @@ func WithResolver(res NameResolver) Option {
|
||||
}
|
||||
}
|
||||
|
||||
// rules is provided to enable custom logic around permitting
|
||||
// WithRule is provided to enable custom logic around permitting
|
||||
// various commands. If not provided, PermitAll is used.
|
||||
func WithRule(rule RuleSet) Option {
|
||||
return func(s *Server) {
|
||||
@ -50,7 +52,7 @@ func WithRule(rule RuleSet) Option {
|
||||
}
|
||||
}
|
||||
|
||||
// rewriter can be used to transparently rewrite addresses.
|
||||
// WithRewriter can be used to transparently rewrite addresses.
|
||||
// This is invoked before the RuleSet is invoked.
|
||||
// Defaults to NoRewrite.
|
||||
func WithRewriter(rew AddressRewriter) Option {
|
||||
@ -61,7 +63,7 @@ func WithRewriter(rew AddressRewriter) Option {
|
||||
}
|
||||
}
|
||||
|
||||
// bindIP is used for bind or udp associate
|
||||
// WithBindIP is used for bind or udp associate
|
||||
func WithBindIP(ip net.IP) Option {
|
||||
return func(s *Server) {
|
||||
if len(ip) != 0 {
|
||||
@ -71,7 +73,7 @@ func WithBindIP(ip net.IP) Option {
|
||||
}
|
||||
}
|
||||
|
||||
// logger can be used to provide a custom log target.
|
||||
// WithLogger can be used to provide a custom log target.
|
||||
// Defaults to ioutil.Discard.
|
||||
func WithLogger(l Logger) Option {
|
||||
return func(s *Server) {
|
||||
@ -81,7 +83,7 @@ func WithLogger(l Logger) Option {
|
||||
}
|
||||
}
|
||||
|
||||
// Optional function for dialing out
|
||||
// WithDial Optional function for dialing out
|
||||
func WithDial(dial func(ctx context.Context, network, addr string) (net.Conn, error)) Option {
|
||||
return func(s *Server) {
|
||||
if dial != nil {
|
||||
@ -90,6 +92,7 @@ func WithDial(dial func(ctx context.Context, network, addr string) (net.Conn, er
|
||||
}
|
||||
}
|
||||
|
||||
// WithGPool can be provided to do custom goroutine pool.
|
||||
func WithGPool(pool GPool) Option {
|
||||
return func(s *Server) {
|
||||
if pool != nil {
|
||||
@ -97,3 +100,24 @@ func WithGPool(pool GPool) Option {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithConnectHandle is used to handle a user's connect command
|
||||
func WithConnectHandle(h func(ctx context.Context, writer io.Writer, req *Request) error) Option {
|
||||
return func(s *Server) {
|
||||
s.userConnectHandle = h
|
||||
}
|
||||
}
|
||||
|
||||
// WithBindHandle is used to handle a user's bind command
|
||||
func WithBindHandle(h func(ctx context.Context, writer io.Writer, req *Request) error) Option {
|
||||
return func(s *Server) {
|
||||
s.userBindHandle = h
|
||||
}
|
||||
}
|
||||
|
||||
// WithAssociateHandle is used to handle a user's associate command
|
||||
func WithAssociateHandle(h func(ctx context.Context, writer io.Writer, req *Request) error) Option {
|
||||
return func(s *Server) {
|
||||
s.userAssociateHandle = h
|
||||
}
|
||||
}
|
||||
|
100
request.go
100
request.go
@ -10,7 +10,7 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
unrecognizedAddrType = fmt.Errorf("Unrecognized address type")
|
||||
errUnrecognizedAddrType = fmt.Errorf("Unrecognized address type")
|
||||
)
|
||||
|
||||
// AddressRewriter is used to rewrite a destination transparently
|
||||
@ -25,11 +25,12 @@ type Request struct {
|
||||
AuthContext *AuthContext
|
||||
// AddrSpec of the the network that sent the request
|
||||
RemoteAddr *AddrSpec
|
||||
// AddrSpec of the desired destination
|
||||
DestAddr *AddrSpec
|
||||
// AddrSpec of the actual destination (might be affected by rewrite)
|
||||
realDestAddr *AddrSpec
|
||||
bufConn io.Reader
|
||||
DestAddr *AddrSpec
|
||||
// Reader connect of request
|
||||
Reader io.Reader
|
||||
// AddrSpec of the desired destination
|
||||
RawDestAddr *AddrSpec
|
||||
}
|
||||
|
||||
type conn interface {
|
||||
@ -56,8 +57,8 @@ func NewRequest(bufConn io.Reader) (*Request, error) {
|
||||
}
|
||||
return &Request{
|
||||
Header: hd,
|
||||
DestAddr: &hd.Address,
|
||||
bufConn: bufConn,
|
||||
RawDestAddr: &hd.Address,
|
||||
Reader: bufConn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -66,45 +67,54 @@ func (s *Server) handleRequest(write io.Writer, req *Request) error {
|
||||
ctx := context.Background()
|
||||
|
||||
// Resolve the address if we have a FQDN
|
||||
dest := req.DestAddr
|
||||
dest := req.RawDestAddr
|
||||
if dest.FQDN != "" {
|
||||
ctx_, addr, err := s.resolver.Resolve(ctx, dest.FQDN)
|
||||
_ctx, addr, err := s.resolver.Resolve(ctx, dest.FQDN)
|
||||
if err != nil {
|
||||
if err := sendReply(write, req.Header, hostUnreachable); err != nil {
|
||||
if err := SendReply(write, req.Header, hostUnreachable); err != nil {
|
||||
return fmt.Errorf("failed to send reply, %v", err)
|
||||
}
|
||||
return fmt.Errorf("failed to resolve destination[%v], %v", dest.FQDN, err)
|
||||
}
|
||||
ctx = ctx_
|
||||
ctx = _ctx
|
||||
dest.IP = addr
|
||||
}
|
||||
|
||||
// Apply any address rewrites
|
||||
req.realDestAddr = req.DestAddr
|
||||
req.DestAddr = req.RawDestAddr
|
||||
if s.rewriter != nil {
|
||||
ctx, req.realDestAddr = s.rewriter.Rewrite(ctx, req)
|
||||
ctx, req.DestAddr = s.rewriter.Rewrite(ctx, req)
|
||||
}
|
||||
|
||||
// Check if this is allowed
|
||||
if ctx_, ok := s.rules.Allow(ctx, req); !ok {
|
||||
if err := sendReply(write, req.Header, ruleFailure); err != nil {
|
||||
_ctx, ok := s.rules.Allow(ctx, req)
|
||||
if !ok {
|
||||
if err := SendReply(write, req.Header, ruleFailure); err != nil {
|
||||
return fmt.Errorf("failed to send reply, %v", err)
|
||||
}
|
||||
return fmt.Errorf("bind to %v blocked by rules", req.DestAddr)
|
||||
} else {
|
||||
ctx = ctx_
|
||||
return fmt.Errorf("bind to %v blocked by rules", req.RawDestAddr)
|
||||
}
|
||||
ctx = _ctx
|
||||
|
||||
// Switch on the command
|
||||
switch req.Command {
|
||||
case CommandConnect:
|
||||
if s.userConnectHandle != nil {
|
||||
return s.userConnectHandle(ctx, write, req)
|
||||
}
|
||||
return s.handleConnect(ctx, write, req)
|
||||
case CommandBind:
|
||||
if s.userBindHandle != nil {
|
||||
return s.userBindHandle(ctx, write, req)
|
||||
}
|
||||
return s.handleBind(ctx, write, req)
|
||||
case CommandAssociate:
|
||||
if s.userAssociateHandle != nil {
|
||||
return s.userAssociateHandle(ctx, write, req)
|
||||
}
|
||||
return s.handleAssociate(ctx, write, req)
|
||||
default:
|
||||
if err := sendReply(write, req.Header, commandNotSupported); err != nil {
|
||||
if err := SendReply(write, req.Header, commandNotSupported); err != nil {
|
||||
return fmt.Errorf("failed to send reply, %v", err)
|
||||
}
|
||||
return fmt.Errorf("unsupported command[%v]", req.Command)
|
||||
@ -120,7 +130,7 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, req *Reque
|
||||
return net.Dial(net_, addr)
|
||||
}
|
||||
}
|
||||
target, err := dial(ctx, "tcp", req.realDestAddr.Address())
|
||||
target, err := dial(ctx, "tcp", req.DestAddr.Address())
|
||||
if err != nil {
|
||||
msg := err.Error()
|
||||
resp := hostUnreachable
|
||||
@ -129,23 +139,23 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, req *Reque
|
||||
} else if strings.Contains(msg, "network is unreachable") {
|
||||
resp = networkUnreachable
|
||||
}
|
||||
if err := sendReply(writer, req.Header, resp); err != nil {
|
||||
if err := SendReply(writer, req.Header, resp); err != nil {
|
||||
return fmt.Errorf("failed to send reply, %v", err)
|
||||
}
|
||||
return fmt.Errorf("connect to %v failed, %v", req.DestAddr, err)
|
||||
return fmt.Errorf("connect to %v failed, %v", req.RawDestAddr, err)
|
||||
}
|
||||
defer target.Close()
|
||||
|
||||
// Send success
|
||||
if err := sendReply(writer, req.Header, successReply, target.LocalAddr()); err != nil {
|
||||
if err := SendReply(writer, req.Header, successReply, target.LocalAddr()); err != nil {
|
||||
return fmt.Errorf("failed to send reply, %v", err)
|
||||
}
|
||||
|
||||
// Start proxying
|
||||
errCh := make(chan error, 2)
|
||||
|
||||
s.submit(func() { errCh <- s.proxy(target, req.bufConn) })
|
||||
s.submit(func() { errCh <- s.proxy(writer, target) })
|
||||
s.submit(func() { errCh <- s.Proxy(target, req.Reader) })
|
||||
s.submit(func() { errCh <- s.Proxy(writer, target) })
|
||||
|
||||
// Wait
|
||||
for i := 0; i < 2; i++ {
|
||||
@ -161,7 +171,7 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, req *Reque
|
||||
// handleBind is used to handle a connect command
|
||||
func (s *Server) handleBind(ctx context.Context, writer io.Writer, req *Request) error {
|
||||
// TODO: Support bind
|
||||
if err := sendReply(writer, req.Header, commandNotSupported); err != nil {
|
||||
if err := SendReply(writer, req.Header, commandNotSupported); err != nil {
|
||||
return fmt.Errorf("failed to send reply: %v", err)
|
||||
}
|
||||
return nil
|
||||
@ -176,7 +186,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
|
||||
return net.Dial(net_, addr)
|
||||
}
|
||||
}
|
||||
target, err := dial(ctx, "udp", req.realDestAddr.Address())
|
||||
target, err := dial(ctx, "udp", req.DestAddr.Address())
|
||||
if err != nil {
|
||||
msg := err.Error()
|
||||
resp := hostUnreachable
|
||||
@ -185,16 +195,16 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
|
||||
} else if strings.Contains(msg, "network is unreachable") {
|
||||
resp = networkUnreachable
|
||||
}
|
||||
if err := sendReply(writer, req.Header, resp); err != nil {
|
||||
if err := SendReply(writer, req.Header, resp); err != nil {
|
||||
return fmt.Errorf("failed to send reply, %v", err)
|
||||
}
|
||||
return fmt.Errorf("connect to %v failed, %v", req.DestAddr, err)
|
||||
return fmt.Errorf("connect to %v failed, %v", req.RawDestAddr, err)
|
||||
}
|
||||
defer target.Close()
|
||||
|
||||
targetUdp, ok := target.(*net.UDPConn)
|
||||
targetUDP, ok := target.(*net.UDPConn)
|
||||
if !ok {
|
||||
if err := sendReply(writer, req.Header, serverFailure); err != nil {
|
||||
if err := SendReply(writer, req.Header, serverFailure); err != nil {
|
||||
return fmt.Errorf("failed to send reply, %v", err)
|
||||
}
|
||||
return fmt.Errorf("dial udp invalid")
|
||||
@ -202,16 +212,16 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
|
||||
|
||||
bindLn, err := net.ListenUDP("udp", nil)
|
||||
if err != nil {
|
||||
if err := sendReply(writer, req.Header, serverFailure); err != nil {
|
||||
if err := SendReply(writer, req.Header, serverFailure); err != nil {
|
||||
return fmt.Errorf("failed to send reply, %v", err)
|
||||
}
|
||||
return fmt.Errorf("listen udp failed, %v", err)
|
||||
}
|
||||
defer bindLn.Close()
|
||||
|
||||
s.logger.Errorf("target addr %v, listen addr: %s", targetUdp.RemoteAddr(), bindLn.LocalAddr())
|
||||
s.logger.Errorf("target addr %v, listen addr: %s", targetUDP.RemoteAddr(), bindLn.LocalAddr())
|
||||
// send BND.ADDR and BND.PORT, client must
|
||||
if err = sendReply(writer, req.Header, successReply, bindLn.LocalAddr()); err != nil {
|
||||
if err = SendReply(writer, req.Header, successReply, bindLn.LocalAddr()); err != nil {
|
||||
return fmt.Errorf("failed to send reply, %v", err)
|
||||
}
|
||||
|
||||
@ -228,7 +238,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
|
||||
conns := sync.Map{}
|
||||
bufPool := s.bufferPool.Get()
|
||||
defer func() {
|
||||
targetUdp.Close()
|
||||
targetUDP.Close()
|
||||
bindLn.Close()
|
||||
s.bufferPool.Put(bufPool)
|
||||
}()
|
||||
@ -278,8 +288,8 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
|
||||
}
|
||||
|
||||
// 把消息写给remote sever
|
||||
if _, err := targetUdp.Write(buf[headLen:n]); err != nil {
|
||||
s.logger.Errorf("write data to remote %s failed, %v", targetUdp.RemoteAddr(), err)
|
||||
if _, err := targetUDP.Write(buf[headLen:n]); err != nil {
|
||||
s.logger.Errorf("write data to remote %s failed, %v", targetUDP.RemoteAddr(), err)
|
||||
return
|
||||
}
|
||||
|
||||
@ -288,16 +298,16 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
|
||||
// read from remote server and write to client
|
||||
bufPool := s.bufferPool.Get()
|
||||
defer func() {
|
||||
targetUdp.Close()
|
||||
targetUDP.Close()
|
||||
bindLn.Close()
|
||||
s.bufferPool.Put(bufPool)
|
||||
}()
|
||||
|
||||
for {
|
||||
buf := bufPool[:cap(bufPool)]
|
||||
n, remote, err := targetUdp.ReadFrom(buf)
|
||||
n, remote, err := targetUDP.ReadFrom(buf)
|
||||
if err != nil {
|
||||
s.logger.Errorf("read data from remote %s failed, %v", targetUdp.RemoteAddr(), err)
|
||||
s.logger.Errorf("read data from remote %s failed, %v", targetUDP.RemoteAddr(), err)
|
||||
return
|
||||
}
|
||||
|
||||
@ -334,15 +344,15 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
|
||||
s.bufferPool.Put(buf)
|
||||
}()
|
||||
for {
|
||||
_, err := req.bufConn.Read(buf[:cap(buf)])
|
||||
_, err := req.Reader.Read(buf[:cap(buf)])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sendReply is used to send a reply message
|
||||
func sendReply(w io.Writer, head Header, resp uint8, bindAddr ...net.Addr) error {
|
||||
// SendReply is used to send a reply message
|
||||
func SendReply(w io.Writer, head Header, resp uint8, bindAddr ...net.Addr) error {
|
||||
/*
|
||||
The SOCKS response is formed as follows:
|
||||
+----+-----+-------+------+----------+----------+
|
||||
@ -397,9 +407,9 @@ type closeWriter interface {
|
||||
CloseWrite() error
|
||||
}
|
||||
|
||||
// proxy is used to suffle data from src to destination, and sends errors
|
||||
// Proxy is used to suffle data from src to destination, and sends errors
|
||||
// down a dedicated channel
|
||||
func (s *Server) proxy(dst io.Writer, src io.Reader) error {
|
||||
func (s *Server) Proxy(dst io.Writer, src io.Reader) error {
|
||||
buf := s.bufferPool.Get()
|
||||
defer s.bufferPool.Put(buf)
|
||||
_, err := io.CopyBuffer(dst, src, buf[:cap(buf)])
|
||||
|
@ -13,6 +13,7 @@ type NameResolver interface {
|
||||
// DNSResolver uses the system DNS to resolve host names
|
||||
type DNSResolver struct{}
|
||||
|
||||
// Resolve implement interface NameResolver
|
||||
func (d DNSResolver) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) {
|
||||
addr, err := net.ResolveIPAddr("ip", name)
|
||||
if err != nil {
|
||||
|
@ -27,6 +27,7 @@ type PermitCommand struct {
|
||||
EnableAssociate bool
|
||||
}
|
||||
|
||||
// Allow implement interface RuleSet
|
||||
func (p *PermitCommand) Allow(ctx context.Context, req *Request) (context.Context, bool) {
|
||||
switch req.Command {
|
||||
case CommandConnect:
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
@ -47,6 +48,10 @@ type Server struct {
|
||||
bufferPool *pool
|
||||
// goroutine pool
|
||||
gPool GPool
|
||||
// user's handle
|
||||
userConnectHandle func(ctx context.Context, writer io.Writer, req *Request) error
|
||||
userBindHandle func(ctx context.Context, writer io.Writer, req *Request) error
|
||||
userAssociateHandle func(ctx context.Context, writer io.Writer, req *Request) error
|
||||
}
|
||||
|
||||
// New creates a new Server and potentially returns an error
|
||||
@ -131,8 +136,8 @@ func (s *Server) ServeConn(conn net.Conn) (err error) {
|
||||
|
||||
request, err := NewRequest(bufConn)
|
||||
if err != nil {
|
||||
if err == unrecognizedAddrType {
|
||||
if err := sendReply(conn, Header{Version: version[0]}, addrTypeNotSupported); err != nil {
|
||||
if err == errUnrecognizedAddrType {
|
||||
if err := SendReply(conn, Header{Version: version[0]}, addrTypeNotSupported); err != nil {
|
||||
return fmt.Errorf("failed to send reply, %v", err)
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user