fix handle requst name

This commit is contained in:
mo 2020-04-24 15:52:47 +08:00
parent d444ab9f9d
commit 80604b63bb
4 changed files with 37 additions and 33 deletions

@ -102,21 +102,21 @@ func WithGPool(pool GPool) Option {
}
// WithConnectHandle is used to handle a user's connect command
func WithConnectHandle(h func(ctx context.Context, writer io.Writer, req *Request) error) Option {
func WithConnectHandle(h func(ctx context.Context, writer io.Writer, request *Request) error) Option {
return func(s *Server) {
s.userConnectHandle = h
}
}
// WithBindHandle is used to handle a user's bind command
func WithBindHandle(h func(ctx context.Context, writer io.Writer, req *Request) error) Option {
func WithBindHandle(h func(ctx context.Context, writer io.Writer, request *Request) error) Option {
return func(s *Server) {
s.userBindHandle = h
}
}
// WithAssociateHandle is used to handle a user's associate command
func WithAssociateHandle(h func(ctx context.Context, writer io.Writer, req *Request) error) Option {
func WithAssociateHandle(h func(ctx context.Context, writer io.Writer, request *Request) error) Option {
return func(s *Server) {
s.userAssociateHandle = h
}

@ -123,3 +123,7 @@ func (sf *Packet) Header() []byte {
bs = append(bs, hi, lo)
return bs
}
func (sf *Packet) Bytes() []byte {
return append(sf.Header(), sf.Data...)
}

@ -124,7 +124,7 @@ func (s *Server) handleRequest(write io.Writer, req *Request) error {
}
// handleConnect is used to handle a connect command
func (s *Server) handleConnect(ctx context.Context, writer io.Writer, req *Request) error {
func (s *Server) handleConnect(ctx context.Context, writer io.Writer, request *Request) error {
// Attempt to connect
dial := s.dial
if dial == nil {
@ -132,7 +132,7 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, req *Reque
return net.Dial(net_, addr)
}
}
target, err := dial(ctx, "tcp", req.DestAddr.String())
target, err := dial(ctx, "tcp", request.DestAddr.String())
if err != nil {
msg := err.Error()
resp := RepHostUnreachable
@ -141,24 +141,22 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, req *Reque
} else if strings.Contains(msg, "network is unreachable") {
resp = RepNetworkUnreachable
}
if err := SendReply(writer, req.Header, resp); err != nil {
if err := SendReply(writer, request.Header, resp); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
return fmt.Errorf("connect to %v failed, %v", req.RawDestAddr, err)
return fmt.Errorf("connect to %v failed, %v", request.RawDestAddr, err)
}
defer target.Close()
// Send success
if err := SendReply(writer, req.Header, RepSuccess, target.LocalAddr()); err != nil {
if err := SendReply(writer, request.Header, RepSuccess, target.LocalAddr()); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
// Start proxying
errCh := make(chan error, 2)
s.submit(func() { errCh <- s.Proxy(target, req.Reader) })
s.submit(func() { errCh <- s.Proxy(target, request.Reader) })
s.submit(func() { errCh <- s.Proxy(writer, target) })
// Wait
for i := 0; i < 2; i++ {
e := <-errCh
@ -171,16 +169,16 @@ func (s *Server) handleConnect(ctx context.Context, writer io.Writer, req *Reque
}
// handleBind is used to handle a connect command
func (s *Server) handleBind(_ context.Context, writer io.Writer, req *Request) error {
func (s *Server) handleBind(_ context.Context, writer io.Writer, request *Request) error {
// TODO: Support bind
if err := SendReply(writer, req.Header, RepCommandNotSupported); err != nil {
if err := SendReply(writer, request.Header, RepCommandNotSupported); err != nil {
return fmt.Errorf("failed to send reply: %v", err)
}
return nil
}
// handleAssociate is used to handle a connect command
func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Request) error {
func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, request *Request) error {
// Attempt to connect
dial := s.dial
if dial == nil {
@ -188,7 +186,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
return net.Dial(net_, addr)
}
}
target, err := dial(ctx, "udp", req.DestAddr.String())
target, err := dial(ctx, "udp", request.DestAddr.String())
if err != nil {
msg := err.Error()
resp := RepHostUnreachable
@ -197,16 +195,16 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
} else if strings.Contains(msg, "network is unreachable") {
resp = RepNetworkUnreachable
}
if err := SendReply(writer, req.Header, resp); err != nil {
if err := SendReply(writer, request.Header, resp); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
return fmt.Errorf("connect to %v failed, %v", req.RawDestAddr, err)
return fmt.Errorf("connect to %v failed, %v", request.RawDestAddr, err)
}
defer target.Close()
targetUDP, ok := target.(*net.UDPConn)
if !ok {
if err := SendReply(writer, req.Header, RepServerFailure); err != nil {
if err := SendReply(writer, request.Header, RepServerFailure); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
return fmt.Errorf("dial udp invalid")
@ -214,7 +212,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
bindLn, err := net.ListenUDP("udp", nil)
if err != nil {
if err := SendReply(writer, req.Header, RepServerFailure); err != nil {
if err := SendReply(writer, request.Header, RepServerFailure); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
return fmt.Errorf("listen udp failed, %v", err)
@ -223,7 +221,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
s.logger.Errorf("target addr %v, listen addr: %s", targetUDP.RemoteAddr(), bindLn.LocalAddr())
// send BND.ADDR and BND.PORT, client must
if err = SendReply(writer, req.Header, RepSuccess, bindLn.LocalAddr()); err != nil {
if err = SendReply(writer, request.Header, RepSuccess, bindLn.LocalAddr()); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
@ -245,8 +243,7 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
s.bufferPool.Put(bufPool)
}()
for {
buf := bufPool[:cap(bufPool)]
n, srcAddr, err := bindLn.ReadFrom(buf)
n, srcAddr, err := bindLn.ReadFrom(bufPool[:cap(bufPool)])
if err != nil {
if strings.Contains(err.Error(), "use of closed network connection") {
s.logger.Errorf("read data from bind listen address %s failed, %v", bindLn.LocalAddr(), err)
@ -256,14 +253,9 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
}
pk := NewEmptyPacket()
if err := pk.Parse(buf[:n]); err != nil {
if err := pk.Parse(bufPool[:n]); err != nil {
continue
}
// 把消息写给remote sever
if _, err := targetUDP.Write(pk.Data); err != nil {
s.logger.Errorf("write data to remote %s failed, %v", targetUDP.RemoteAddr(), err)
return
}
if _, ok := conns.LoadOrStore(srcAddr.String(), struct{}{}); !ok {
s.submit(func() {
@ -300,6 +292,12 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
}
})
}
// 把消息写给remote sever
if _, err := targetUDP.Write(pk.Data); err != nil {
s.logger.Errorf("write data to remote %s failed, %v", targetUDP.RemoteAddr(), err)
return
}
}
})
@ -308,9 +306,11 @@ func (s *Server) handleAssociate(ctx context.Context, writer io.Writer, req *Req
s.bufferPool.Put(buf)
}()
for {
_, err := req.Reader.Read(buf[:cap(buf)])
_, err := request.Reader.Read(buf[:cap(buf)])
if err != nil {
return err
if strings.Contains(err.Error(), "use of closed network connection") {
return err
}
}
}
}

@ -49,9 +49,9 @@ type Server struct {
// goroutine pool
gPool GPool
// user's handle
userConnectHandle func(ctx context.Context, writer io.Writer, req *Request) error
userBindHandle func(ctx context.Context, writer io.Writer, req *Request) error
userAssociateHandle func(ctx context.Context, writer io.Writer, req *Request) error
userConnectHandle func(ctx context.Context, writer io.Writer, request *Request) error
userBindHandle func(ctx context.Context, writer io.Writer, request *Request) error
userAssociateHandle func(ctx context.Context, writer io.Writer, request *Request) error
}
// New creates a new Server and potentially returns an error