Testing the CONNECT handling
This commit is contained in:
parent
5437f80e57
commit
9aca0ed614
26
request.go
26
request.go
@ -7,6 +7,7 @@ import (
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -42,6 +43,11 @@ type addrSpec struct {
|
||||
port int
|
||||
}
|
||||
|
||||
type conn interface {
|
||||
Write([]byte) (int, error)
|
||||
RemoteAddr() net.Addr
|
||||
}
|
||||
|
||||
func (a *addrSpec) String() string {
|
||||
if a.fqdn != "" {
|
||||
return fmt.Sprintf("%s (%s):%d", a.fqdn, a.ip, a.port)
|
||||
@ -50,7 +56,7 @@ func (a *addrSpec) String() string {
|
||||
}
|
||||
|
||||
// handleRequest is used for request processing after authentication
|
||||
func (s *Server) handleRequest(conn net.Conn, bufConn io.Reader) error {
|
||||
func (s *Server) handleRequest(conn conn, bufConn io.Reader) error {
|
||||
// Read the version byte
|
||||
header := []byte{0, 0, 0}
|
||||
if _, err := io.ReadAtLeast(bufConn, header, 3); err != nil {
|
||||
@ -102,7 +108,7 @@ func (s *Server) handleRequest(conn net.Conn, bufConn io.Reader) error {
|
||||
}
|
||||
|
||||
// handleConnect is used to handle a connect command
|
||||
func (s *Server) handleConnect(conn net.Conn, bufConn io.Reader, dest *addrSpec) error {
|
||||
func (s *Server) handleConnect(conn conn, bufConn io.Reader, dest *addrSpec) error {
|
||||
// Check if this is allowed
|
||||
client := conn.RemoteAddr().(*net.TCPAddr)
|
||||
if !s.config.Rules.AllowConnect(dest.ip, dest.port, client.IP, client.Port) {
|
||||
@ -148,7 +154,7 @@ func (s *Server) handleConnect(conn net.Conn, bufConn io.Reader, dest *addrSpec)
|
||||
}
|
||||
|
||||
// handleBind is used to handle a connect command
|
||||
func (s *Server) handleBind(conn net.Conn, bufConn io.Reader, dest *addrSpec) error {
|
||||
func (s *Server) handleBind(conn conn, bufConn io.Reader, dest *addrSpec) error {
|
||||
// Check if this is allowed
|
||||
client := conn.RemoteAddr().(*net.TCPAddr)
|
||||
if !s.config.Rules.AllowBind(dest.ip, dest.port, client.IP, client.Port) {
|
||||
@ -166,7 +172,7 @@ func (s *Server) handleBind(conn net.Conn, bufConn io.Reader, dest *addrSpec) er
|
||||
}
|
||||
|
||||
// handleAssociate is used to handle a connect command
|
||||
func (s *Server) handleAssociate(conn net.Conn, bufConn io.Reader, dest *addrSpec) error {
|
||||
func (s *Server) handleAssociate(conn conn, bufConn io.Reader, dest *addrSpec) error {
|
||||
// Check if this is allowed
|
||||
client := conn.RemoteAddr().(*net.TCPAddr)
|
||||
if !s.config.Rules.AllowAssociate(dest.ip, dest.port, client.IP, client.Port) {
|
||||
@ -277,9 +283,15 @@ func sendReply(w io.Writer, resp uint8, addr *addrSpec) error {
|
||||
|
||||
// proxy is used to suffle data from src to destination, and sends errors
|
||||
// down a dedicated channel
|
||||
func proxy(name string, dst io.WriteCloser, src io.Reader, errCh chan error) {
|
||||
defer dst.Close()
|
||||
func proxy(name string, dst io.Writer, src io.Reader, errCh chan error) {
|
||||
// Copy
|
||||
n, err := io.Copy(dst, src)
|
||||
errCh <- err
|
||||
|
||||
// Log, and sleep. This is jank but allows the otherside
|
||||
// to finish a pending copy
|
||||
log.Printf("[DEBUG] Copied %d bytes for %s", n, name)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Send any errors
|
||||
errCh <- err
|
||||
}
|
||||
|
86
request_test.go
Normal file
86
request_test.go
Normal file
@ -0,0 +1,86 @@
|
||||
package socks5
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type MockConn struct {
|
||||
buf bytes.Buffer
|
||||
}
|
||||
|
||||
func (m *MockConn) Write(b []byte) (int, error) {
|
||||
return m.buf.Write(b)
|
||||
}
|
||||
|
||||
func (m *MockConn) RemoteAddr() net.Addr {
|
||||
return &net.TCPAddr{IP: []byte{127, 0, 0, 1}, Port: 65432}
|
||||
}
|
||||
|
||||
func TestRequest_Connect(t *testing.T) {
|
||||
// Create a local listener
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
go func() {
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
buf := make([]byte, 4)
|
||||
if _, err := io.ReadAtLeast(conn, buf, 4); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(buf, []byte("ping")) {
|
||||
t.Fatalf("bad: %v", buf)
|
||||
}
|
||||
conn.Write([]byte("pong"))
|
||||
}()
|
||||
lAddr := l.Addr().(*net.TCPAddr)
|
||||
|
||||
// Make server
|
||||
s := &Server{config: &Config{
|
||||
Rules: PermitAll(),
|
||||
Resolver: DNSResolver{},
|
||||
}}
|
||||
|
||||
// Create the connect request
|
||||
req := bytes.NewBuffer(nil)
|
||||
req.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1})
|
||||
|
||||
port := []byte{0, 0}
|
||||
binary.BigEndian.PutUint16(port, uint16(lAddr.Port))
|
||||
req.Write(port)
|
||||
|
||||
// Send a ping
|
||||
req.Write([]byte("ping"))
|
||||
|
||||
// Handle the request
|
||||
resp := &MockConn{}
|
||||
if err := s.handleRequest(resp, req); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Verify response
|
||||
out := resp.buf.Bytes()
|
||||
expected := []byte{
|
||||
5,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
127, 0, 0, 1,
|
||||
port[0],
|
||||
port[1],
|
||||
'p', 'o', 'n', 'g',
|
||||
}
|
||||
if !bytes.Equal(out, expected) {
|
||||
t.Fatalf("bad: %v", out)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user