Compare commits

...

33 Commits

Author SHA1 Message Date
kayos@tcp.direct adc175e09b
Fix[testing]: another workaround for slow CI runners 2024-05-02 15:34:17 -07:00
kayos@tcp.direct 68e6aed64d
Fix: dep change 2024-05-02 15:29:36 -07:00
kayos@tcp.direct 21d0867118
Chore: deps 2024-05-02 15:29:10 -07:00
kayos@tcp.direct 143220da8c
Merge branch 'ssh' into development 2024-05-02 15:27:54 -07:00
kayos@tcp.direct 9983d6ba5a
Merge branch 'main' into development 2024-05-02 15:26:12 -07:00
kayos@tcp.direct 141ddb75f3
Feat: support SSH proxies 2024-05-02 15:22:56 -07:00
kayos@tcp.direct b84164bbad
Feat: parse and use given protocol URIs 2024-05-02 15:22:15 -07:00
kayos@tcp.direct 9d88a045f4
Fix: dialer timeout 2023-10-28 00:24:06 -07:00
kayos@tcp.direct e49afae3a2
Chore: update kayos/common 2023-10-27 23:50:52 -07:00
kayos@tcp.direct c21cba86d4
Fix: probably fix recycling being off and a slew of other bugs (make proxymap a ptr) 2023-10-27 23:48:58 -07:00
kayos@tcp.direct a4012cd5a9
Fix: fix CloseAllConns 2023-10-27 23:44:44 -07:00
kayos@tcp.direct 7b91d42cbf
Merge branch 'main' into development 2023-10-21 13:47:40 -07:00
kayos@tcp.direct df7b567666
gomod: retract premature tag 2023-09-07 00:11:08 -07:00
kayos@tcp.direct 8592c4f63f
Testing: add coverage for GetTotalValidated 2023-09-07 00:10:59 -07:00
kayos@tcp.direct aded278c6b
Fix: race condition in stats due to half baked atomic loads 2023-09-07 00:09:16 -07:00
kayos@tcp.direct 3c6bba9946
Merge branch 'erm' into development 2023-09-06 23:57:31 -07:00
kayos a1c31c9f14
Merge branch 'main' into development 2023-08-11 22:48:39 -07:00
kayos@tcp.direct db3eaf355d
Chore[CI]: fix branch in workflow 2023-08-11 22:46:42 -07:00
kayos@tcp.direct 0c61142d51
Chore: gomod 2023-08-11 22:45:46 -07:00
kayos@tcp.direct d8c0a3a5ae
CI: Add PR summarizer 2023-08-11 22:28:46 -07:00
kayos@tcp.direct 4f3c842f47
Fix: remove gotrace 2023-08-07 15:51:18 -07:00
kayos@tcp.direct 40ed2878dd
Fix go vet: returning fatal from goroutine 2023-08-07 15:46:50 -07:00
kayos@tcp.direct 7ae3ebd8f6
Update CI 2023-08-07 15:43:53 -07:00
kayos@tcp.direct 05b9e36ac6
Testing: Make integration test more realistic 2023-08-07 15:38:47 -07:00
kayos@tcp.direct bbbc9b74c0
Fix: bad var name 2023-08-07 15:22:22 -07:00
kayos@tcp.direct 39bc5ac542
Chore: tidy up 2023-08-07 15:21:42 -07:00
kayos@tcp.direct 84977926e1
Merge branch 'master' into development 2023-08-07 15:18:23 -07:00
kayos@tcp.direct 2e85c817c7
Fix race cond. in socks lib + add tests + update uagents 2023-08-07 15:17:55 -07:00
kayos@tcp.direct 503223183a
Fix: if proxy disqualified then it's not still good 2023-04-06 17:17:25 -07:00
kayos@tcp.direct 121543bbf8
Fix: mutex locking snafu 2023-03-06 01:39:37 -08:00
kayos@tcp.direct 54e641ff21
Fix: part of my commit message ended up in the commit :^) 2023-03-06 01:34:33 -08:00
kayos@tcp.direct 69187d7f8e
Refactor: reduce complexity 2023-03-06 01:32:51 -08:00
kayos@tcp.direct d8f03bc8c6
Fix: Didn't unlock ._. 2023-02-25 03:00:43 -08:00
12 changed files with 729 additions and 47 deletions

View File

@ -3,6 +3,7 @@ package prox5
import (
"errors"
"strconv"
"strings"
"time"
"git.tcp.direct/kayos/common/entropy"
@ -15,7 +16,17 @@ type proxyMap struct {
}
func (sm proxyMap) add(sock string) (*Proxy, bool) {
sm.plot.SetIfAbsent(sock, &Proxy{
prot := ProtoNull
if strings.Contains(sock, "://") {
prot, sock = extractProtoFromProxyString(sock)
}
// if the proxy already exists, noop and !ok
if sm.plot.Has(sock) {
return nil, false
}
newProx := &Proxy{
Endpoint: sock,
protocol: newImmutableProto(),
lastValidated: time.UnixMilli(0),
@ -23,7 +34,12 @@ func (sm proxyMap) add(sock string) (*Proxy, bool) {
timesBad: 0,
parent: sm.parent,
lock: stateUnlocked,
})
}
if prot != ProtoNull {
newProx.protocol.set(prot)
}
sm.plot.SetIfAbsent(sock, newProx)
return sm.plot.Get(sock)
}

19
defs.go
View File

@ -23,8 +23,8 @@ type proxyList struct {
func (pl *proxyList) add(p *Proxy) {
pl.Lock()
defer pl.Unlock()
pl.PushBack(p)
pl.Unlock()
}
func (pl *proxyList) pop() *Proxy {
@ -40,14 +40,18 @@ func (pl *proxyList) pop() *Proxy {
// ProxyChannels will likely be unexported in the future.
type ProxyChannels struct {
// SOCKS5 is a constant stream of verified SOCKS5 proxies
// SOCKS5 is a constant stream of verified SOCKS5 proxies.
SOCKS5 proxyList
// SOCKS4 is a constant stream of verified SOCKS4 proxies
// SOCKS4 is a constant stream of verified SOCKS4 proxies.
SOCKS4 proxyList
// SOCKS4a is a constant stream of verified SOCKS5 proxies
// SOCKS4a is a constant stream of verified SOCKS4a proxies.
SOCKS4a proxyList
// HTTP is a constant stream of verified SOCKS5 proxies
// HTTP is a constant stream of verified HTTP proxies.
HTTP proxyList
// HTTPS is a constant stream of verified HTTPS proxies.
HTTPS proxyList
// SSH is a constant stream of verified SSH proxies.
SSH proxyList
}
// Slice returns a slice of all proxyLists in ProxyChannels, note that HTTP is not included.
@ -59,7 +63,7 @@ func (pc ProxyChannels) Slice() []*proxyList {
return lists
}
// ProxyEngine represents a proxy pool
// ProxyEngine represents a proxy pool. This is the main component of the prox5 package.
type ProxyEngine struct {
Valids ProxyChannels
DebugLogger logger.Logger
@ -67,6 +71,9 @@ type ProxyEngine struct {
// stats holds the Statistics for ProxyEngine
stats Statistics
// Status is the current state of the ProxyEngine.
// This is modified with atomics and should only be accessed with atomic.LoadUint32 and atomic.StoreUint32.
// Likely to be unexported in the future and replaced with a ProxyEngine method.
Status uint32
// Pending is a constant stream of proxy strings to be verified

View File

@ -1,11 +1,16 @@
package prox5
import (
"context"
"sync/atomic"
"time"
)
func (p5 *ProxyEngine) getSocksStr(proto ProxyProtocol) string {
func (p5 *ProxyEngine) GetSocksStr(proto ProxyProtocol) string {
return p5.GetSocksStrCtx(context.Background(), proto)
}
func (p5 *ProxyEngine) GetSocksStrCtx(ctx context.Context, proto ProxyProtocol) string {
var sock *Proxy
var list *proxyList
switch proto {
@ -17,8 +22,21 @@ func (p5 *ProxyEngine) getSocksStr(proto ProxyProtocol) string {
list = &p5.Valids.SOCKS5
case ProtoHTTP:
list = &p5.Valids.HTTP
case ProtoHTTPS:
list = &p5.Valids.HTTPS
case ProtoSSH:
list = &p5.Valids.SSH
case ProtoNull:
return ""
default:
panic("unknown protocol")
}
for {
select {
case <-ctx.Done():
return ""
default:
}
if list.Len() == 0 {
p5.recycling()
time.Sleep(250 * time.Millisecond)
@ -29,6 +47,11 @@ func (p5 *ProxyEngine) getSocksStr(proto ProxyProtocol) string {
list.Unlock()
switch {
case sock == nil:
select {
case <-ctx.Done():
return ""
default:
}
p5.recycling()
time.Sleep(250 * time.Millisecond)
continue
@ -44,24 +67,24 @@ func (p5 *ProxyEngine) getSocksStr(proto ProxyProtocol) string {
// Socks5Str gets a SOCKS5 proxy that we have fully verified (dialed and then retrieved our IP address from a what-is-my-ip endpoint.
// Will block if one is not available!
func (p5 *ProxyEngine) Socks5Str() string {
return p5.getSocksStr(ProtoSOCKS5)
return p5.GetSocksStr(ProtoSOCKS5)
}
// Socks4Str gets a SOCKS4 proxy that we have fully verified.
// Will block if one is not available!
func (p5 *ProxyEngine) Socks4Str() string {
return p5.getSocksStr(ProtoSOCKS4)
return p5.GetSocksStr(ProtoSOCKS4)
}
// Socks4aStr gets a SOCKS4 proxy that we have fully verified.
// Will block if one is not available!
func (p5 *ProxyEngine) Socks4aStr() string {
return p5.getSocksStr(ProtoSOCKS4a)
return p5.GetSocksStr(ProtoSOCKS4a)
}
// GetHTTPTunnel checks for an available HTTP CONNECT proxy in our pool.
func (p5 *ProxyEngine) GetHTTPTunnel() string {
return p5.getSocksStr(ProtoHTTP)
return p5.GetSocksStr(ProtoHTTP)
}
// GetAnySOCKS retrieves any version SOCKS proxy as a Proxy type

3
go.mod
View File

@ -6,6 +6,7 @@ require (
git.tcp.direct/kayos/common v0.9.7
git.tcp.direct/kayos/go-socks5 v0.3.0
git.tcp.direct/kayos/socks v0.1.3
github.com/davecgh/go-spew v1.1.1
github.com/gdamore/tcell/v2 v2.7.1
github.com/miekg/dns v1.1.58
github.com/ooni/oohttp v0.6.7
@ -14,6 +15,7 @@ require (
github.com/refraction-networking/utls v1.6.0
github.com/rivo/tview v0.0.0-20230208211350-7dfff1ce7854
github.com/yunginnanet/Rate5 v1.3.0
golang.org/x/crypto v0.21.0
golang.org/x/net v0.22.0
)
@ -27,7 +29,6 @@ require (
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
github.com/quic-go/quic-go v0.37.4 // indirect
github.com/rivo/uniseg v0.4.3 // indirect
golang.org/x/crypto v0.21.0 // indirect
golang.org/x/mod v0.14.0 // indirect
golang.org/x/sys v0.18.0 // indirect
golang.org/x/term v0.18.0 // indirect

View File

@ -1,11 +1,11 @@
package prox5
import (
"net/netip"
"strconv"
"strings"
"github.com/miekg/dns"
"net/netip"
)
func filterv6(in string) (filtered string, ok bool) {
@ -50,7 +50,7 @@ func buildProxyString(username, password, address, port string, v6 bool) (result
builder.MustWriteString(password)
builder.MustWriteString("@")
}
builder.MustWriteString(address)
builder.MustWriteString(strings.ToLower(address))
if v6 {
builder.MustWriteString("]")
}
@ -60,11 +60,17 @@ func buildProxyString(username, password, address, port string, v6 bool) (result
}
func filter(in string) (filtered string, ok bool) { //nolint:cyclop
if !strings.Contains(in, ":") {
return "", false
protoStr, protoNormalized, protoOK := protoStrNormalize(in)
if protoOK {
in = protoNormalized
}
split := strings.Split(in, ":")
if !strings.Contains(in, ":") {
in = in + ":1080"
}
if len(split) < 2 {
return "", false
}
@ -72,7 +78,7 @@ func filter(in string) (filtered string, ok bool) { //nolint:cyclop
case 2:
_, isDomain := dns.IsDomainName(split[0])
if isDomain && isNumber(split[1]) {
return in, true
return protoStr + strings.ToLower(in), true
}
combo, err := netip.ParseAddrPort(in)
if err != nil {
@ -83,40 +89,44 @@ func filter(in string) (filtered string, ok bool) { //nolint:cyclop
if !strings.Contains(in, "@") {
return "", false
}
split := strings.Split(in, "@")
if !strings.Contains(split[0], ":") {
domSplit := strings.Split(in, "@")
if !strings.Contains(domSplit[0], ":") {
return "", false
}
splitAuth := strings.Split(split[0], ":")
splitServ := strings.Split(split[1], ":")
splitAuth := strings.Split(domSplit[0], ":")
splitServ := strings.Split(domSplit[1], ":")
_, isDomain := dns.IsDomainName(splitServ[0])
if isDomain && isNumber(splitServ[1]) {
return buildProxyString(splitAuth[0], splitAuth[1],
return protoStr + buildProxyString(splitAuth[0], splitAuth[1],
splitServ[0], splitServ[1], false), true
}
if _, err := netip.ParseAddrPort(split[1]); err == nil {
return buildProxyString(splitAuth[0], splitAuth[1],
if _, err := netip.ParseAddrPort(domSplit[1]); err == nil {
return protoStr + buildProxyString(splitAuth[0], splitAuth[1],
splitServ[0], splitServ[1], false), true
}
case 4:
_, isDomain := dns.IsDomainName(split[0])
if isDomain && isNumber(split[1]) {
return buildProxyString(split[2], split[3], split[0], split[1], false), true
return protoStr + buildProxyString(split[2], split[3], split[0], split[1], false), true
}
_, isDomain = dns.IsDomainName(split[2])
if isDomain && isNumber(split[3]) {
return buildProxyString(split[0], split[1], split[2], split[3], false), true
return protoStr + buildProxyString(split[0], split[1], split[2], split[3], false), true
}
if _, err := netip.ParseAddrPort(split[2] + ":" + split[3]); err == nil {
return buildProxyString(split[0], split[1], split[2], split[3], false), true
return protoStr + buildProxyString(split[0], split[1], split[2], split[3], false), true
}
if _, err := netip.ParseAddrPort(split[0] + ":" + split[1]); err == nil {
return buildProxyString(split[2], split[3], split[0], split[1], false), true
return protoStr + buildProxyString(split[2], split[3], split[0], split[1], false), true
}
default:
if !strings.Contains(in, "[") || !strings.Contains(in, "]:") {
return "", false
}
}
return filterv6(in)
v6Filt, v6Ok := filterv6(in)
if v6Ok {
v6Filt = protoStr + v6Filt
}
return v6Filt, v6Ok
}

View File

@ -1,6 +1,9 @@
package prox5
import "testing"
import (
"strings"
"testing"
)
func Test_filter(t *testing.T) {
type args struct {
@ -45,6 +48,14 @@ func Test_filter(t *testing.T) {
wantFiltered: "yeet.com:1080",
wantOk: true,
},
{
name: "simpleDomainWithAlpha",
args: args{
in: "YeEt.com:1080",
},
wantFiltered: "yeet.com:1080",
wantOk: true,
},
{
name: "domainWithAuth",
args: args{
@ -53,6 +64,14 @@ func Test_filter(t *testing.T) {
wantFiltered: "user:pass@yeet.com:1080",
wantOk: true,
},
{
name: "domainWithAuthWithAlpha",
args: args{
in: "YeEt.com:1080:uSer:pAss",
},
wantFiltered: "uSer:pAss@yeet.com:1080",
wantOk: true,
},
{
name: "ipv6",
args: args{
@ -78,15 +97,57 @@ func Test_filter(t *testing.T) {
wantOk: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotFiltered, gotOk := filter(tt.args.in)
if gotFiltered != tt.wantFiltered {
t.Errorf("filter() gotFiltered = %v, want %v", gotFiltered, tt.wantFiltered)
var prefixTests = []test{{
name: "invalid prefix",
args: args{
in: "yeet://yeet.com:1080",
},
wantFiltered: "",
wantOk: false,
}}
for protoStr, protoPrefix := range protoStrs {
for _, tt := range tests {
if strings.Contains(tt.name, "invalid") || strings.Contains(tt.name, "prefix") {
continue
}
if gotOk != tt.wantOk {
t.Errorf("filter() gotOk = %v, want %v", gotOk, tt.wantOk)
}
})
tt.args.in = protoPrefix + tt.args.in
tt.name = tt.name + " with " + protoStr + " prefix"
prefixTests = append(prefixTests, test{
name: tt.name,
args: tt.args,
wantFiltered: protoPrefix + tt.wantFiltered,
wantOk: tt.wantOk,
})
}
}
t.Run("tests without protocol prefixes", func(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotFiltered, gotOk := filter(tt.args.in)
if gotFiltered != tt.wantFiltered {
t.Errorf("filter() gotFiltered = %v, want %v", gotFiltered, tt.wantFiltered)
}
if gotOk != tt.wantOk {
t.Errorf("filter() gotOk = %v, want %v", gotOk, tt.wantOk)
}
})
}
})
t.Run("tests with protocol prefixes", func(t *testing.T) {
for _, tt := range prefixTests {
t.Run(tt.name, func(t *testing.T) {
gotFiltered, gotOk := filter(tt.args.in)
if gotFiltered != tt.wantFiltered {
t.Errorf("filter() gotFiltered = %v, want %v", gotFiltered, tt.wantFiltered)
}
if gotOk != tt.wantOk {
t.Errorf("filter() gotOk = %v, want %v", gotOk, tt.wantOk)
}
})
}
})
}

View File

@ -1,6 +1,7 @@
package prox5
import (
"strings"
"sync"
"sync/atomic"
@ -16,12 +17,58 @@ const (
ProtoSOCKS4a
ProtoSOCKS5
ProtoHTTP
ProtoHTTPS
ProtoSSH
)
var protoMap = map[ProxyProtocol]string{
ProtoSOCKS5: "socks5", ProtoNull: "unknown", ProtoSOCKS4: "socks4", ProtoSOCKS4a: "socks4a",
}
var strToProto = map[string]ProxyProtocol{
"socks5": ProtoSOCKS5, "socks4": ProtoSOCKS4, "socks4a": ProtoSOCKS4a, "http": ProtoHTTP, "ssh": ProtoSSH,
}
func protoFromStr(s string) (ProxyProtocol, bool) {
if strings.Contains(s, "://") {
s = strings.Split(s, "://")[0]
}
prot, ok := strToProto[s]
if !ok {
prot = ProtoNull
}
return prot, ok
}
var protoStrs = map[string]string{
"socks5": "socks5://",
"socks4": "socks4://",
"socks4a": "socks4://",
"http": "http://",
"https": "https://",
"ssh": "ssh://",
}
func protoStrNormalize(s string) (protoStr string, cleaned string, ok bool) {
cleaned = s
if !strings.Contains(s, "://") {
return
}
cleaned = strings.Split(cleaned, "://")[1]
s = strings.ToLower(strings.Split(s, "://")[0])
protoStr, ok = protoStrs[s]
return
}
func extractProtoFromProxyString(s string) (prot ProxyProtocol, cleaned string) {
cleaned = s
prot, _ = protoFromStr(s)
if prot != ProtoNull {
cleaned = strings.Split(s, "://")[1]
}
return prot, cleaned
}
func (p ProxyProtocol) String() string {
return protoMap[p]
}

View File

@ -202,6 +202,8 @@ func TestProx5(t *testing.T) {
}
var successCount int64 = 0
var fin = &atomic.Bool{}
fin.Store(false)
makeReq := func() {
select {
@ -211,7 +213,7 @@ func TestProx5(t *testing.T) {
}
resp, err := p5.GetHTTPClient().Get("http://127.0.0.1:8055")
if err != nil && !errors.Is(err, ErrNoProxies) && !errors.Is(err, net.ErrClosed) {
if !fin.Load() && err != nil && !errors.Is(err, ErrNoProxies) && !errors.Is(err, net.ErrClosed) {
t.Error("[FAIL] " + err.Error())
}
if err != nil && errors.Is(err, ErrNoProxies) {
@ -261,6 +263,7 @@ testLoop:
if err := p5.Close(); err != nil {
t.Fatal(err)
}
fin.Store(true)
// let the proxy engine close gracefully
time.Sleep(time.Second * 5)
}

View File

@ -26,6 +26,9 @@ const (
type Proxy struct {
// Endpoint is the address:port of the proxy that we connect to
Endpoint string
sshDialer *SSHDialer
// ProxiedIP is the address that we end up having when making proxied requests through this proxy
// TODO: parse this and store as flat int type
ProxiedIP string

190
ssh.go Normal file
View File

@ -0,0 +1,190 @@
package prox5
import (
"context"
"errors"
"fmt"
"io"
"net"
"net/url"
"os"
"runtime"
"strings"
"sync"
"time"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)
type SSHDialer struct {
host string
clientConfig *ssh.ClientConfig
clientConn *ssh.Client
timeout time.Duration
mu sync.RWMutex
}
func ioClose(closer io.Closer) {
_ = closer.Close() // we don't care about this error. consider it "handled"
}
func getSignersFromSocket(uri string) (signers []ssh.Signer, err error) {
if strings.Contains(uri, "://") {
if uriSplit := strings.Split(uri, "://"); len(uriSplit) == 2 {
uri = uriSplit[1]
}
}
var conn net.Conn
if conn, err = net.Dial("unix", uri); err != nil {
return nil, fmt.Errorf("failed to connect to ssh-agent: %w", err)
}
defer ioClose(conn)
sshAgent := agent.NewClient(conn)
if signers, err = sshAgent.Signers(); err != nil {
return nil, fmt.Errorf("failed to get signers from ssh-agent socket: %w", err)
}
if len(signers) == 0 {
return nil, errors.New("no signers provided by ssh-agent socket")
}
return signers, nil
}
func NewPasswordSSHDialer(endpoint string, user, pass string) *SSHDialer {
clientConfig := &ssh.ClientConfig{
User: user,
Auth: []ssh.AuthMethod{ssh.Password(pass)},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
clientConfig.SetDefaults()
return &SSHDialer{
clientConfig: clientConfig,
host: endpoint,
}
}
func NewAgentSSHDialer(endpoint string, user string, signers ...any) (*SSHDialer, error) {
agentURI := os.Getenv("SSH_AUTH_SOCK")
if len(signers) == 0 && (agentURI == "" || runtime.GOOS == "windows") {
return nil, errors.New("no signers provided and no SSH_AUTH_SOCK available")
}
var sshSigners []ssh.Signer
signersProvided := false
for i, signer := range signers {
switch castedSigner := signer.(type) {
case ssh.Signer:
if i == 0 {
signersProvided = true
}
if !signersProvided {
return nil, errors.New("multiple signers provided but they aren't all ssh.Signer")
}
sshSigners = append(sshSigners, castedSigner)
case url.URL:
if signersProvided {
return nil, errors.New("multiple signers provided but they aren't all ssh.Signer")
}
agentURI = castedSigner.String()
case string:
if signersProvided {
return nil, errors.New("multiple signers provided but they aren't all ssh.Signer")
}
agentURI = castedSigner
}
}
if !signersProvided {
var err error
sshSigners, err = getSignersFromSocket(agentURI)
if err != nil {
return nil, err // these are wrapped
}
}
clientConfig := &ssh.ClientConfig{
User: user,
Auth: []ssh.AuthMethod{ssh.PublicKeys(sshSigners...)},
}
clientConfig.SetDefaults()
clientConfig.HostKeyCallback = ssh.InsecureIgnoreHostKey()
return &SSHDialer{
clientConfig: clientConfig,
host: endpoint,
}, nil
}
func (sshd *SSHDialer) WithHostKeyVerification(callback ssh.HostKeyCallback) *SSHDialer {
sshd.clientConfig.HostKeyCallback = callback
return sshd
}
func (sshd *SSHDialer) WithTimeout(timeout time.Duration) *SSHDialer {
sshd.timeout = timeout
return sshd
}
func (sshd *SSHDialer) Close() error {
sshd.mu.Lock()
defer sshd.mu.Unlock()
if sshd.clientConn == nil {
return nil
}
err := sshd.clientConn.Close()
sshd.clientConn = nil
return err
}
type dialRes struct {
conn net.Conn
err error
}
func (sshd *SSHDialer) dial(resChan chan dialRes, network, addr string) {
sshd.mu.RLock()
if sshd.clientConn == nil {
sshd.mu.RUnlock()
sshd.mu.Lock()
var err error
sshd.clientConn, err = ssh.Dial("tcp", sshd.host, sshd.clientConfig)
if err != nil {
if sshd.clientConn != nil {
ioClose(sshd.clientConn)
}
sshd.clientConn = nil
sshd.mu.Unlock()
resChan <- dialRes{nil, err}
return
}
sshd.mu.Unlock()
sshd.mu.RLock()
}
sshd.mu.RUnlock()
c, e := sshd.clientConn.Dial(network, addr)
resChan <- dialRes{c, e}
}
func (sshd *SSHDialer) DialCtx(ctx context.Context, network, addr string) (net.Conn, error) {
resChan := make(chan dialRes)
go sshd.dial(resChan, network, addr)
select {
case <-ctx.Done():
return nil, fmt.Errorf("context cancelled: %w", ctx.Err())
case res := <-resChan:
return res.conn, res.err
}
}
func (sshd *SSHDialer) Dial(network, addr string) (net.Conn, error) {
if sshd.timeout == 0 {
return sshd.DialCtx(context.Background(), network, addr)
}
ctx, cancel := context.WithTimeout(context.Background(), sshd.timeout)
defer cancel()
return sshd.DialCtx(ctx, network, addr)
}

316
ssh_test.go Normal file
View File

@ -0,0 +1,316 @@
package prox5
import (
"context"
"crypto/rand"
"crypto/rsa"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"sync/atomic"
"testing"
"github.com/davecgh/go-spew/spew"
"golang.org/x/crypto/ssh"
)
const (
testUser = "yeetersonmcgee"
testPass = "yeetinemallday"
)
type testSSHServer struct {
listener net.Listener
config *ssh.ServerConfig
t *testing.T
errChan chan error
closed *atomic.Bool
testHTTP *atomic.Value
}
func newTestSSHServer(t *testing.T) *testSSHServer {
s := &testSSHServer{
t: t,
errChan: make(chan error, 1), // Buffered channel to handle non-blocking error reporting
closed: &atomic.Bool{},
testHTTP: &atomic.Value{},
}
key := genKey()
signer := signerFromKey(key)
s.config = &ssh.ServerConfig{
PasswordCallback: func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
if conn.User() == testUser && string(password) == testPass {
return nil, nil
}
return nil, ssh.ErrNoAuth
},
}
s.config.AddHostKey(signer)
s.closed.Store(false)
return s
}
func genKey() *rsa.PrivateKey {
k, _ := rsa.GenerateKey(rand.Reader, 2048)
return k
}
func signerFromKey(key *rsa.PrivateKey) ssh.Signer {
signer, _ := ssh.NewSignerFromKey(key)
return signer
}
func (s *testSSHServer) start() string {
var err error
if s.listener, err = net.Listen("tcp", "127.0.0.1:0"); err != nil {
s.t.Fatal(err)
}
go s.handler()
return s.listener.Addr().String()
}
func (s *testSSHServer) handler() {
for {
conn, err := s.listener.Accept()
if err != nil && !strings.Contains(err.Error(), "use of closed") {
s.errChan <- err // Send the error to the channel
return
}
go s.handleConnection(conn)
}
}
type tcpIPRequest struct {
HostToConnect string
PortToConnect uint32
OriginatorIPAddress string
OriginatorPort uint32
}
func (s *testSSHServer) testHTTPServer() string {
if s.testHTTP.Load() != nil {
return s.testHTTP.Load().(string)
}
serv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("<b>yeet</b>"))
}))
s.testHTTP.Store(serv.URL)
s.t.Logf("test HTTP server started on %s", serv.URL)
s.t.Cleanup(serv.Close)
return serv.URL
}
func (s *testSSHServer) handleDirectTCPIP(req *ssh.Request, channel io.ReadWriteCloser, reply bool) {
// direct-tcpip request data structure as per RFC 4254, section 7.2
data := &tcpIPRequest{}
if req == nil {
return
}
if err := ssh.Unmarshal(req.Payload, data); err != nil {
s.t.Errorf("Failed to unmarshal direct-tcpip request: %v", err)
channel.Close()
return
}
s.t.Logf("direct-tcpip request: %+v", data)
srvURL := s.testHTTPServer()
s.t.Logf("faking connection to remote host: %s:%d", data.HostToConnect, data.PortToConnect)
prt, _ := strconv.Atoi(strings.Split(strings.TrimPrefix(srvURL, "http://"), ":")[1])
data.HostToConnect = strings.TrimSuffix(strings.Split(srvURL, "://")[1], fmt.Sprintf(":%d", prt))
data.PortToConnect = uint32(prt)
remoteConn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", data.HostToConnect, data.PortToConnect))
if err != nil {
s.t.Logf("Failed to connect to remote host: %s:%d, error: %v", data.HostToConnect, data.PortToConnect, err)
if reply {
_ = req.Reply(false, nil)
}
_ = channel.Close()
return
}
s.t.Logf("connected to remote host: %s", remoteConn.RemoteAddr())
if reply {
s.t.Logf("replying to request")
if err = req.Reply(true, nil); err != nil {
s.t.Errorf("Failed to reply to request: %v", err)
_ = channel.Close()
_ = remoteConn.Close()
return
}
}
go func() {
defer func() {
_ = channel.Close()
_ = remoteConn.Close()
}()
if _, err = io.Copy(channel, remoteConn); err != nil && !strings.Contains(err.Error(), "use of closed") {
s.t.Errorf("failed to copy from remote to channel: %v", err)
}
}()
go func() {
defer func() {
_ = channel.Close()
_ = remoteConn.Close()
}()
if _, err = io.Copy(remoteConn, channel); err != nil && !strings.Contains(err.Error(), "use of closed") {
s.t.Errorf("failed to copy from channel to remote: %v", err)
}
}()
}
func (s *testSSHServer) handleConnection(conn net.Conn) {
if conn == nil {
return
}
ctx, cancel := context.WithCancel(context.Background())
defer func() {
cancel()
if conn != nil {
_ = conn.Close()
}
}()
clientConn, channels, requests, err := ssh.NewServerConn(conn, s.config)
if err != nil && !strings.Contains(err.Error(), "use of closed") {
s.t.Logf("failed to establish server connection: %v", err)
return
}
if clientConn == nil {
return
}
go func() {
for {
select {
case <-ctx.Done():
return
case req := <-requests:
s.t.Log(spew.Sdump(req))
case newChannel := <-channels:
s.t.Logf("new channel: %s", newChannel.ChannelType())
switch newChannel.ChannelType() {
case "direct-tcpip":
tcpIPChan, tcpIPReqs, chanErr := newChannel.Accept()
if chanErr != nil {
s.t.Errorf("failed to accept direct-tcpip channel: %v", chanErr)
return
}
s.t.Logf("accepted direct-tcpip channel")
if len(newChannel.ExtraData()) > 0 {
go s.handleDirectTCPIP(&ssh.Request{
Type: "direct-tcpip",
WantReply: true,
Payload: newChannel.ExtraData(),
}, tcpIPChan, false)
}
go func() {
for {
select {
case <-ctx.Done():
return
case req := <-tcpIPReqs:
go s.handleDirectTCPIP(req, tcpIPChan, false)
}
}
}()
default:
s.t.Errorf("unhandled channel type: %s", newChannel.ChannelType())
if err = newChannel.Reject(ssh.UnknownChannelType, "unhandled channel type"); err != nil {
s.t.Errorf("failed to reject channel: %v", err)
}
_ = clientConn.Close()
}
}
}
}()
s.t.Logf("new connection from %s", clientConn.RemoteAddr())
if err := clientConn.Wait(); err != nil {
s.t.Logf("failed to wait for client connection: %v", err)
}
}
func (s *testSSHServer) stop() {
if err := s.listener.Close(); err != nil {
s.t.Errorf("failed to close listener: %v", err)
}
select {
case err := <-s.errChan:
s.t.Logf("server stopped with error: %v", err)
default:
}
close(s.errChan)
}
func TestSSHDialer(t *testing.T) {
t.Run("TestSuccessfulConnection", func(t *testing.T) {
server := newTestSSHServer(t)
serverAddr := server.start()
defer server.stop()
dialer := NewPasswordSSHDialer(serverAddr, testUser, testPass)
conn, err := dialer.Dial("tcp", "google.com:80")
if err != nil {
t.Fatalf("failed to establish connection: %v", err)
}
var n int
if n, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
t.Fatalf("failed to write to connection: %v", err)
}
t.Logf("[client] wrote %d bytes to connection", n)
buf := make([]byte, 1024)
n, err = conn.Read(buf)
if err != nil {
t.Fatalf("failed to read from connection: %v", err)
}
t.Logf("[client] read %d bytes from connection", n)
t.Log(string(buf[:n]))
if !strings.Contains(string(buf[:n]), "<b>yeet</b>") {
t.Fatal("expected response to contain '<b>yeet</b>'")
}
if err = conn.Close(); err != nil {
t.Fatalf("failed to close connection: %v", err)
}
})
t.Run("TestFailedAuthentication", func(t *testing.T) {
server := newTestSSHServer(t)
serverAddr := server.start()
defer server.stop()
dialer := NewPasswordSSHDialer(serverAddr, testUser, "yeet5")
conn, err := dialer.Dial("tcp", "google.com:80")
if err == nil {
if conn != nil {
if err = conn.Close(); err != nil {
t.Errorf("failed to close connection: %v", err)
}
}
t.Fatalf("expected authentication error, got none")
}
})
}

View File

@ -159,9 +159,14 @@ func (p5 *ProxyEngine) announceValidating(sock *Proxy, presplit string) {
return
}
s := strs.Get()
s.MustWriteString("validating ")
s.MustWriteString(sock.GetProto().String())
s.MustWriteString("://")
if knownProto := sock.GetProto(); knownProto != ProtoNull {
s.MustWriteString("re-validating known proxy: ")
s.MustWriteString(knownProto.String())
s.MustWriteString("://")
} else {
s.MustWriteString("validating unknown proxy: ")
}
s.MustWriteString(presplit)
p5.dbgPrint(s)
@ -175,7 +180,7 @@ func (p5 *ProxyEngine) singleProxyCheck(sock *Proxy, protocol ProxyProtocol) err
endpoint = split[1]
}
// p5.announceValidating(sock, endpoint)
p5.announceValidating(sock, endpoint)
conn, err := net.DialTimeout("tcp", endpoint, p5.GetValidationTimeout())
if err != nil {
@ -230,7 +235,7 @@ func (sock *Proxy) validate() {
// TODO: consider giving the option for verbose logging of this stuff?
switch {
case sock.timesValidated == 0, sock.protocol.Get() == ProtoNull:
case sock.timesValidated == 0 && sock.protocol.Get() == ProtoNull:
// try to use the proxy with all 3 SOCKS versions
for tryProto := range protoMap {
if tryProto == ProtoNull {