diff --git a/net.go b/net.go index b9fff2f..aef0d53 100644 --- a/net.go +++ b/net.go @@ -16,9 +16,9 @@ func (b *requestBuilder) add(data ...byte) { _, _ = b.Write(data) } -func (c *config) sendReceive(conn net.Conn, req []byte) (resp []byte, err error) { - if c.Timeout > 0 { - if err := conn.SetWriteDeadline(time.Now().Add(c.Timeout)); err != nil { +func (cfg *config) sendReceive(conn net.Conn, req []byte) (resp []byte, err error) { + if cfg.Timeout > 0 { + if err := conn.SetWriteDeadline(time.Now().Add(cfg.Timeout)); err != nil { return nil, err } } @@ -26,14 +26,14 @@ func (c *config) sendReceive(conn net.Conn, req []byte) (resp []byte, err error) if err != nil { return } - resp, err = c.readAll(conn) + resp, err = cfg.readAll(conn) return } -func (c *config) readAll(conn net.Conn) (resp []byte, err error) { +func (cfg *config) readAll(conn net.Conn) (resp []byte, err error) { resp = make([]byte, 1024) - if c.Timeout > 0 { - if err := conn.SetReadDeadline(time.Now().Add(c.Timeout)); err != nil { + if cfg.Timeout > 0 { + if err := conn.SetReadDeadline(time.Now().Add(cfg.Timeout)); err != nil { return nil, err } } diff --git a/parse.go b/parse.go index d6d3c3f..8c9335b 100644 --- a/parse.go +++ b/parse.go @@ -3,6 +3,7 @@ package socks import ( "errors" "fmt" + "net" "net/url" "time" ) @@ -13,6 +14,7 @@ type ( Host string Auth *auth Timeout time.Duration + conn net.Conn } auth struct { Username string diff --git a/socks.go b/socks.go index d626901..7ef3df3 100644 --- a/socks.go +++ b/socks.go @@ -6,6 +6,7 @@ Package socks implements a SOCKS (SOCKS4, SOCKS4A and SOCKS5) proxy client. A complete example using this package: + package main import ( @@ -43,6 +44,7 @@ package socks // import "h12.io/socks" import ( "fmt" "net" + "time" ) // Constants to choose which version of SOCKS protocol to use. @@ -52,6 +54,19 @@ const ( SOCKS5 ) +// DialWithConn returns the dial function to be used in http.Transport object. +// Argument proxyURI should be in the format: "socks5://user:password@127.0.0.1:1080?timeout=5s". +// The protocol could be socks5, socks4 and socks4a. DialWithConn will use the given connection +// to communicate with the proxy server. +func DialWithConn(proxyURI string, conn net.Conn) func(string, string) (net.Conn, error) { + cfg, err := parse(proxyURI) + if err != nil { + return dialError(err) + } + cfg.conn = conn + return cfg.dialFunc() +} + // Dial returns the dial function to be used in http.Transport object. // Argument proxyURI should be in the format: "socks5://user:password@127.0.0.1:1080?timeout=5s". // The protocol could be socks5, socks4 and socks4a. @@ -70,18 +85,18 @@ func DialSocksProxy(socksType int, proxy string) func(string, string) (net.Conn, return (&config{Proto: socksType, Host: proxy}).dialFunc() } -func (c *config) dialFunc() func(string, string) (net.Conn, error) { - switch c.Proto { +func (cfg *config) dialFunc() func(string, string) (net.Conn, error) { + switch cfg.Proto { case SOCKS5: return func(_, targetAddr string) (conn net.Conn, err error) { - return c.dialSocks5(targetAddr) + return cfg.dialSocks5(targetAddr) } case SOCKS4, SOCKS4A: return func(_, targetAddr string) (conn net.Conn, err error) { - return c.dialSocks4(targetAddr) + return cfg.dialSocks4(targetAddr) } } - return dialError(fmt.Errorf("unknown SOCKS protocol %v", c.Proto)) + return dialError(fmt.Errorf("unknown SOCKS protocol %v", cfg.Proto)) } func dialError(err error) func(string, string) (net.Conn, error) { @@ -89,3 +104,11 @@ func dialError(err error) func(string, string) (net.Conn, error) { return nil, err } } + +func (cfg *config) internalDial() (conn net.Conn, err error) { + if cfg.conn != nil { + err = cfg.conn.SetDeadline(time.Now().Add(cfg.Timeout)) + return cfg.conn, nil + } + return net.DialTimeout("tcp", cfg.Host, cfg.Timeout) +} diff --git a/socks4.go b/socks4.go index b65683d..b819da8 100644 --- a/socks4.go +++ b/socks4.go @@ -8,18 +8,11 @@ import ( func (cfg *config) dialSocks4(targetAddr string) (_ net.Conn, err error) { socksType := cfg.Proto - proxy := cfg.Host - // dial TCP - conn, err := net.DialTimeout("tcp", proxy, cfg.Timeout) + conn, err := cfg.internalDial() if err != nil { return nil, err } - defer func() { - if err != nil { - conn.Close() - } - }() // connection request host, port, err := splitHostPort(targetAddr) diff --git a/socks5.go b/socks5.go index 695f89d..fcf2492 100644 --- a/socks5.go +++ b/socks5.go @@ -6,18 +6,10 @@ import ( ) func (cfg *config) dialSocks5(targetAddr string) (_ net.Conn, err error) { - proxy := cfg.Host - - // dial TCP - conn, err := net.DialTimeout("tcp", proxy, cfg.Timeout) + conn, err := cfg.internalDial() if err != nil { return nil, err } - defer func() { - if err != nil { - conn.Close() - } - }() var req requestBuilder diff --git a/socks5_test.go b/socks5_test.go index a874db6..26757a1 100644 --- a/socks5_test.go +++ b/socks5_test.go @@ -11,7 +11,7 @@ import ( "testing" "time" - socks5 "github.com/h12w/go-socks5" + "github.com/h12w/go-socks5" "github.com/phayes/freeport" ) @@ -91,6 +91,29 @@ func TestSocks5Anonymous(t *testing.T) { } } +func TestSocks5AnonymousWithConn(t *testing.T) { + socksTestPort := newTestSocksServer(false) + conn, err := net.DialTimeout("tcp", fmt.Sprintf("127.0.0.1:%d", socksTestPort), 5*time.Second) + if err != nil { + t.Fatalf("dial socks5 proxy failed: %v", err) + } + dialSocksProxy := DialWithConn(fmt.Sprintf("socks5://127.0.0.1:%d?timeout=5s", socksTestPort), conn) + tr := &http.Transport{Dial: dialSocksProxy} + httpClient := &http.Client{Transport: tr} + resp, err := httpClient.Get(fmt.Sprintf("http://localhost" + httpTestServer.Addr)) + if err != nil { + t.Fatalf("expect response hello but got %s", err) + } + defer resp.Body.Close() + respBody, err := ioutil.ReadAll(resp.Body) + if err != nil { + panic(err) + } + if string(respBody) != "hello" { + t.Fatalf("expect response hello but got %s", respBody) + } +} + func TestSocks5Auth(t *testing.T) { socksTestPort := newTestSocksServer(true) dialSocksProxy := Dial(fmt.Sprintf("socks5://test_user:test_pass@127.0.0.1:%d?timeout=5s", socksTestPort)) @@ -115,5 +138,5 @@ func tcpReady(port int, timeout time.Duration) { if err != nil { panic(err) } - conn.Close() + _ = conn.Close() }