Compare commits
33 Commits
main
...
developmen
Author | SHA1 | Date | |
---|---|---|---|
kayos@tcp.direct | adc175e09b | ||
kayos@tcp.direct | 68e6aed64d | ||
kayos@tcp.direct | 21d0867118 | ||
kayos@tcp.direct | 143220da8c | ||
kayos@tcp.direct | 9983d6ba5a | ||
kayos@tcp.direct | 141ddb75f3 | ||
kayos@tcp.direct | b84164bbad | ||
kayos@tcp.direct | 9d88a045f4 | ||
kayos@tcp.direct | e49afae3a2 | ||
kayos@tcp.direct | c21cba86d4 | ||
kayos@tcp.direct | a4012cd5a9 | ||
kayos@tcp.direct | 7b91d42cbf | ||
kayos@tcp.direct | df7b567666 | ||
kayos@tcp.direct | 8592c4f63f | ||
kayos@tcp.direct | aded278c6b | ||
kayos@tcp.direct | 3c6bba9946 | ||
kayos | a1c31c9f14 | ||
kayos@tcp.direct | db3eaf355d | ||
kayos@tcp.direct | 0c61142d51 | ||
kayos@tcp.direct | d8c0a3a5ae | ||
kayos@tcp.direct | 4f3c842f47 | ||
kayos@tcp.direct | 40ed2878dd | ||
kayos@tcp.direct | 7ae3ebd8f6 | ||
kayos@tcp.direct | 05b9e36ac6 | ||
kayos@tcp.direct | bbbc9b74c0 | ||
kayos@tcp.direct | 39bc5ac542 | ||
kayos@tcp.direct | 84977926e1 | ||
kayos@tcp.direct | 2e85c817c7 | ||
kayos@tcp.direct | 503223183a | ||
kayos@tcp.direct | 121543bbf8 | ||
kayos@tcp.direct | 54e641ff21 | ||
kayos@tcp.direct | 69187d7f8e | ||
kayos@tcp.direct | d8f03bc8c6 |
20
daemons.go
20
daemons.go
|
@ -3,6 +3,7 @@ package prox5
|
|||
import (
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.tcp.direct/kayos/common/entropy"
|
||||
|
@ -15,7 +16,17 @@ type proxyMap struct {
|
|||
}
|
||||
|
||||
func (sm proxyMap) add(sock string) (*Proxy, bool) {
|
||||
sm.plot.SetIfAbsent(sock, &Proxy{
|
||||
prot := ProtoNull
|
||||
if strings.Contains(sock, "://") {
|
||||
prot, sock = extractProtoFromProxyString(sock)
|
||||
}
|
||||
|
||||
// if the proxy already exists, noop and !ok
|
||||
if sm.plot.Has(sock) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
newProx := &Proxy{
|
||||
Endpoint: sock,
|
||||
protocol: newImmutableProto(),
|
||||
lastValidated: time.UnixMilli(0),
|
||||
|
@ -23,7 +34,12 @@ func (sm proxyMap) add(sock string) (*Proxy, bool) {
|
|||
timesBad: 0,
|
||||
parent: sm.parent,
|
||||
lock: stateUnlocked,
|
||||
})
|
||||
}
|
||||
if prot != ProtoNull {
|
||||
newProx.protocol.set(prot)
|
||||
}
|
||||
|
||||
sm.plot.SetIfAbsent(sock, newProx)
|
||||
|
||||
return sm.plot.Get(sock)
|
||||
}
|
||||
|
|
19
defs.go
19
defs.go
|
@ -23,8 +23,8 @@ type proxyList struct {
|
|||
|
||||
func (pl *proxyList) add(p *Proxy) {
|
||||
pl.Lock()
|
||||
defer pl.Unlock()
|
||||
pl.PushBack(p)
|
||||
pl.Unlock()
|
||||
}
|
||||
|
||||
func (pl *proxyList) pop() *Proxy {
|
||||
|
@ -40,14 +40,18 @@ func (pl *proxyList) pop() *Proxy {
|
|||
|
||||
// ProxyChannels will likely be unexported in the future.
|
||||
type ProxyChannels struct {
|
||||
// SOCKS5 is a constant stream of verified SOCKS5 proxies
|
||||
// SOCKS5 is a constant stream of verified SOCKS5 proxies.
|
||||
SOCKS5 proxyList
|
||||
// SOCKS4 is a constant stream of verified SOCKS4 proxies
|
||||
// SOCKS4 is a constant stream of verified SOCKS4 proxies.
|
||||
SOCKS4 proxyList
|
||||
// SOCKS4a is a constant stream of verified SOCKS5 proxies
|
||||
// SOCKS4a is a constant stream of verified SOCKS4a proxies.
|
||||
SOCKS4a proxyList
|
||||
// HTTP is a constant stream of verified SOCKS5 proxies
|
||||
// HTTP is a constant stream of verified HTTP proxies.
|
||||
HTTP proxyList
|
||||
// HTTPS is a constant stream of verified HTTPS proxies.
|
||||
HTTPS proxyList
|
||||
// SSH is a constant stream of verified SSH proxies.
|
||||
SSH proxyList
|
||||
}
|
||||
|
||||
// Slice returns a slice of all proxyLists in ProxyChannels, note that HTTP is not included.
|
||||
|
@ -59,7 +63,7 @@ func (pc ProxyChannels) Slice() []*proxyList {
|
|||
return lists
|
||||
}
|
||||
|
||||
// ProxyEngine represents a proxy pool
|
||||
// ProxyEngine represents a proxy pool. This is the main component of the prox5 package.
|
||||
type ProxyEngine struct {
|
||||
Valids ProxyChannels
|
||||
DebugLogger logger.Logger
|
||||
|
@ -67,6 +71,9 @@ type ProxyEngine struct {
|
|||
// stats holds the Statistics for ProxyEngine
|
||||
stats Statistics
|
||||
|
||||
// Status is the current state of the ProxyEngine.
|
||||
// This is modified with atomics and should only be accessed with atomic.LoadUint32 and atomic.StoreUint32.
|
||||
// Likely to be unexported in the future and replaced with a ProxyEngine method.
|
||||
Status uint32
|
||||
|
||||
// Pending is a constant stream of proxy strings to be verified
|
||||
|
|
33
dispense.go
33
dispense.go
|
@ -1,11 +1,16 @@
|
|||
package prox5
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (p5 *ProxyEngine) getSocksStr(proto ProxyProtocol) string {
|
||||
func (p5 *ProxyEngine) GetSocksStr(proto ProxyProtocol) string {
|
||||
return p5.GetSocksStrCtx(context.Background(), proto)
|
||||
}
|
||||
|
||||
func (p5 *ProxyEngine) GetSocksStrCtx(ctx context.Context, proto ProxyProtocol) string {
|
||||
var sock *Proxy
|
||||
var list *proxyList
|
||||
switch proto {
|
||||
|
@ -17,8 +22,21 @@ func (p5 *ProxyEngine) getSocksStr(proto ProxyProtocol) string {
|
|||
list = &p5.Valids.SOCKS5
|
||||
case ProtoHTTP:
|
||||
list = &p5.Valids.HTTP
|
||||
case ProtoHTTPS:
|
||||
list = &p5.Valids.HTTPS
|
||||
case ProtoSSH:
|
||||
list = &p5.Valids.SSH
|
||||
case ProtoNull:
|
||||
return ""
|
||||
default:
|
||||
panic("unknown protocol")
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ""
|
||||
default:
|
||||
}
|
||||
if list.Len() == 0 {
|
||||
p5.recycling()
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
|
@ -29,6 +47,11 @@ func (p5 *ProxyEngine) getSocksStr(proto ProxyProtocol) string {
|
|||
list.Unlock()
|
||||
switch {
|
||||
case sock == nil:
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ""
|
||||
default:
|
||||
}
|
||||
p5.recycling()
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
continue
|
||||
|
@ -44,24 +67,24 @@ func (p5 *ProxyEngine) getSocksStr(proto ProxyProtocol) string {
|
|||
// Socks5Str gets a SOCKS5 proxy that we have fully verified (dialed and then retrieved our IP address from a what-is-my-ip endpoint.
|
||||
// Will block if one is not available!
|
||||
func (p5 *ProxyEngine) Socks5Str() string {
|
||||
return p5.getSocksStr(ProtoSOCKS5)
|
||||
return p5.GetSocksStr(ProtoSOCKS5)
|
||||
}
|
||||
|
||||
// Socks4Str gets a SOCKS4 proxy that we have fully verified.
|
||||
// Will block if one is not available!
|
||||
func (p5 *ProxyEngine) Socks4Str() string {
|
||||
return p5.getSocksStr(ProtoSOCKS4)
|
||||
return p5.GetSocksStr(ProtoSOCKS4)
|
||||
}
|
||||
|
||||
// Socks4aStr gets a SOCKS4 proxy that we have fully verified.
|
||||
// Will block if one is not available!
|
||||
func (p5 *ProxyEngine) Socks4aStr() string {
|
||||
return p5.getSocksStr(ProtoSOCKS4a)
|
||||
return p5.GetSocksStr(ProtoSOCKS4a)
|
||||
}
|
||||
|
||||
// GetHTTPTunnel checks for an available HTTP CONNECT proxy in our pool.
|
||||
func (p5 *ProxyEngine) GetHTTPTunnel() string {
|
||||
return p5.getSocksStr(ProtoHTTP)
|
||||
return p5.GetSocksStr(ProtoHTTP)
|
||||
}
|
||||
|
||||
// GetAnySOCKS retrieves any version SOCKS proxy as a Proxy type
|
||||
|
|
3
go.mod
3
go.mod
|
@ -6,6 +6,7 @@ require (
|
|||
git.tcp.direct/kayos/common v0.9.7
|
||||
git.tcp.direct/kayos/go-socks5 v0.3.0
|
||||
git.tcp.direct/kayos/socks v0.1.3
|
||||
github.com/davecgh/go-spew v1.1.1
|
||||
github.com/gdamore/tcell/v2 v2.7.1
|
||||
github.com/miekg/dns v1.1.58
|
||||
github.com/ooni/oohttp v0.6.7
|
||||
|
@ -14,6 +15,7 @@ require (
|
|||
github.com/refraction-networking/utls v1.6.0
|
||||
github.com/rivo/tview v0.0.0-20230208211350-7dfff1ce7854
|
||||
github.com/yunginnanet/Rate5 v1.3.0
|
||||
golang.org/x/crypto v0.21.0
|
||||
golang.org/x/net v0.22.0
|
||||
)
|
||||
|
||||
|
@ -27,7 +29,6 @@ require (
|
|||
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
|
||||
github.com/quic-go/quic-go v0.37.4 // indirect
|
||||
github.com/rivo/uniseg v0.4.3 // indirect
|
||||
golang.org/x/crypto v0.21.0 // indirect
|
||||
golang.org/x/mod v0.14.0 // indirect
|
||||
golang.org/x/sys v0.18.0 // indirect
|
||||
golang.org/x/term v0.18.0 // indirect
|
||||
|
|
44
parse.go
44
parse.go
|
@ -1,11 +1,11 @@
|
|||
package prox5
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
func filterv6(in string) (filtered string, ok bool) {
|
||||
|
@ -50,7 +50,7 @@ func buildProxyString(username, password, address, port string, v6 bool) (result
|
|||
builder.MustWriteString(password)
|
||||
builder.MustWriteString("@")
|
||||
}
|
||||
builder.MustWriteString(address)
|
||||
builder.MustWriteString(strings.ToLower(address))
|
||||
if v6 {
|
||||
builder.MustWriteString("]")
|
||||
}
|
||||
|
@ -60,11 +60,17 @@ func buildProxyString(username, password, address, port string, v6 bool) (result
|
|||
}
|
||||
|
||||
func filter(in string) (filtered string, ok bool) { //nolint:cyclop
|
||||
if !strings.Contains(in, ":") {
|
||||
return "", false
|
||||
protoStr, protoNormalized, protoOK := protoStrNormalize(in)
|
||||
if protoOK {
|
||||
in = protoNormalized
|
||||
}
|
||||
|
||||
split := strings.Split(in, ":")
|
||||
|
||||
if !strings.Contains(in, ":") {
|
||||
in = in + ":1080"
|
||||
}
|
||||
|
||||
if len(split) < 2 {
|
||||
return "", false
|
||||
}
|
||||
|
@ -72,7 +78,7 @@ func filter(in string) (filtered string, ok bool) { //nolint:cyclop
|
|||
case 2:
|
||||
_, isDomain := dns.IsDomainName(split[0])
|
||||
if isDomain && isNumber(split[1]) {
|
||||
return in, true
|
||||
return protoStr + strings.ToLower(in), true
|
||||
}
|
||||
combo, err := netip.ParseAddrPort(in)
|
||||
if err != nil {
|
||||
|
@ -83,40 +89,44 @@ func filter(in string) (filtered string, ok bool) { //nolint:cyclop
|
|||
if !strings.Contains(in, "@") {
|
||||
return "", false
|
||||
}
|
||||
split := strings.Split(in, "@")
|
||||
if !strings.Contains(split[0], ":") {
|
||||
domSplit := strings.Split(in, "@")
|
||||
if !strings.Contains(domSplit[0], ":") {
|
||||
return "", false
|
||||
}
|
||||
splitAuth := strings.Split(split[0], ":")
|
||||
splitServ := strings.Split(split[1], ":")
|
||||
splitAuth := strings.Split(domSplit[0], ":")
|
||||
splitServ := strings.Split(domSplit[1], ":")
|
||||
_, isDomain := dns.IsDomainName(splitServ[0])
|
||||
if isDomain && isNumber(splitServ[1]) {
|
||||
return buildProxyString(splitAuth[0], splitAuth[1],
|
||||
return protoStr + buildProxyString(splitAuth[0], splitAuth[1],
|
||||
splitServ[0], splitServ[1], false), true
|
||||
}
|
||||
if _, err := netip.ParseAddrPort(split[1]); err == nil {
|
||||
return buildProxyString(splitAuth[0], splitAuth[1],
|
||||
if _, err := netip.ParseAddrPort(domSplit[1]); err == nil {
|
||||
return protoStr + buildProxyString(splitAuth[0], splitAuth[1],
|
||||
splitServ[0], splitServ[1], false), true
|
||||
}
|
||||
case 4:
|
||||
_, isDomain := dns.IsDomainName(split[0])
|
||||
if isDomain && isNumber(split[1]) {
|
||||
return buildProxyString(split[2], split[3], split[0], split[1], false), true
|
||||
return protoStr + buildProxyString(split[2], split[3], split[0], split[1], false), true
|
||||
}
|
||||
_, isDomain = dns.IsDomainName(split[2])
|
||||
if isDomain && isNumber(split[3]) {
|
||||
return buildProxyString(split[0], split[1], split[2], split[3], false), true
|
||||
return protoStr + buildProxyString(split[0], split[1], split[2], split[3], false), true
|
||||
}
|
||||
if _, err := netip.ParseAddrPort(split[2] + ":" + split[3]); err == nil {
|
||||
return buildProxyString(split[0], split[1], split[2], split[3], false), true
|
||||
return protoStr + buildProxyString(split[0], split[1], split[2], split[3], false), true
|
||||
}
|
||||
if _, err := netip.ParseAddrPort(split[0] + ":" + split[1]); err == nil {
|
||||
return buildProxyString(split[2], split[3], split[0], split[1], false), true
|
||||
return protoStr + buildProxyString(split[2], split[3], split[0], split[1], false), true
|
||||
}
|
||||
default:
|
||||
if !strings.Contains(in, "[") || !strings.Contains(in, "]:") {
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
return filterv6(in)
|
||||
v6Filt, v6Ok := filterv6(in)
|
||||
if v6Ok {
|
||||
v6Filt = protoStr + v6Filt
|
||||
}
|
||||
return v6Filt, v6Ok
|
||||
}
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
package prox5
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_filter(t *testing.T) {
|
||||
type args struct {
|
||||
|
@ -45,6 +48,14 @@ func Test_filter(t *testing.T) {
|
|||
wantFiltered: "yeet.com:1080",
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "simpleDomainWithAlpha",
|
||||
args: args{
|
||||
in: "YeEt.com:1080",
|
||||
},
|
||||
wantFiltered: "yeet.com:1080",
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "domainWithAuth",
|
||||
args: args{
|
||||
|
@ -53,6 +64,14 @@ func Test_filter(t *testing.T) {
|
|||
wantFiltered: "user:pass@yeet.com:1080",
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "domainWithAuthWithAlpha",
|
||||
args: args{
|
||||
in: "YeEt.com:1080:uSer:pAss",
|
||||
},
|
||||
wantFiltered: "uSer:pAss@yeet.com:1080",
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "ipv6",
|
||||
args: args{
|
||||
|
@ -78,15 +97,57 @@ func Test_filter(t *testing.T) {
|
|||
wantOk: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotFiltered, gotOk := filter(tt.args.in)
|
||||
if gotFiltered != tt.wantFiltered {
|
||||
t.Errorf("filter() gotFiltered = %v, want %v", gotFiltered, tt.wantFiltered)
|
||||
|
||||
var prefixTests = []test{{
|
||||
name: "invalid prefix",
|
||||
args: args{
|
||||
in: "yeet://yeet.com:1080",
|
||||
},
|
||||
wantFiltered: "",
|
||||
wantOk: false,
|
||||
}}
|
||||
|
||||
for protoStr, protoPrefix := range protoStrs {
|
||||
for _, tt := range tests {
|
||||
if strings.Contains(tt.name, "invalid") || strings.Contains(tt.name, "prefix") {
|
||||
continue
|
||||
}
|
||||
if gotOk != tt.wantOk {
|
||||
t.Errorf("filter() gotOk = %v, want %v", gotOk, tt.wantOk)
|
||||
}
|
||||
})
|
||||
tt.args.in = protoPrefix + tt.args.in
|
||||
tt.name = tt.name + " with " + protoStr + " prefix"
|
||||
prefixTests = append(prefixTests, test{
|
||||
name: tt.name,
|
||||
args: tt.args,
|
||||
wantFiltered: protoPrefix + tt.wantFiltered,
|
||||
wantOk: tt.wantOk,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("tests without protocol prefixes", func(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotFiltered, gotOk := filter(tt.args.in)
|
||||
if gotFiltered != tt.wantFiltered {
|
||||
t.Errorf("filter() gotFiltered = %v, want %v", gotFiltered, tt.wantFiltered)
|
||||
}
|
||||
if gotOk != tt.wantOk {
|
||||
t.Errorf("filter() gotOk = %v, want %v", gotOk, tt.wantOk)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("tests with protocol prefixes", func(t *testing.T) {
|
||||
for _, tt := range prefixTests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotFiltered, gotOk := filter(tt.args.in)
|
||||
if gotFiltered != tt.wantFiltered {
|
||||
t.Errorf("filter() gotFiltered = %v, want %v", gotFiltered, tt.wantFiltered)
|
||||
}
|
||||
if gotOk != tt.wantOk {
|
||||
t.Errorf("filter() gotOk = %v, want %v", gotOk, tt.wantOk)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
47
proto.go
47
proto.go
|
@ -1,6 +1,7 @@
|
|||
package prox5
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
|
@ -16,12 +17,58 @@ const (
|
|||
ProtoSOCKS4a
|
||||
ProtoSOCKS5
|
||||
ProtoHTTP
|
||||
ProtoHTTPS
|
||||
ProtoSSH
|
||||
)
|
||||
|
||||
var protoMap = map[ProxyProtocol]string{
|
||||
ProtoSOCKS5: "socks5", ProtoNull: "unknown", ProtoSOCKS4: "socks4", ProtoSOCKS4a: "socks4a",
|
||||
}
|
||||
|
||||
var strToProto = map[string]ProxyProtocol{
|
||||
"socks5": ProtoSOCKS5, "socks4": ProtoSOCKS4, "socks4a": ProtoSOCKS4a, "http": ProtoHTTP, "ssh": ProtoSSH,
|
||||
}
|
||||
|
||||
func protoFromStr(s string) (ProxyProtocol, bool) {
|
||||
if strings.Contains(s, "://") {
|
||||
s = strings.Split(s, "://")[0]
|
||||
}
|
||||
prot, ok := strToProto[s]
|
||||
if !ok {
|
||||
prot = ProtoNull
|
||||
}
|
||||
return prot, ok
|
||||
}
|
||||
|
||||
var protoStrs = map[string]string{
|
||||
"socks5": "socks5://",
|
||||
"socks4": "socks4://",
|
||||
"socks4a": "socks4://",
|
||||
"http": "http://",
|
||||
"https": "https://",
|
||||
"ssh": "ssh://",
|
||||
}
|
||||
|
||||
func protoStrNormalize(s string) (protoStr string, cleaned string, ok bool) {
|
||||
cleaned = s
|
||||
if !strings.Contains(s, "://") {
|
||||
return
|
||||
}
|
||||
cleaned = strings.Split(cleaned, "://")[1]
|
||||
s = strings.ToLower(strings.Split(s, "://")[0])
|
||||
protoStr, ok = protoStrs[s]
|
||||
return
|
||||
}
|
||||
|
||||
func extractProtoFromProxyString(s string) (prot ProxyProtocol, cleaned string) {
|
||||
cleaned = s
|
||||
prot, _ = protoFromStr(s)
|
||||
if prot != ProtoNull {
|
||||
cleaned = strings.Split(s, "://")[1]
|
||||
}
|
||||
return prot, cleaned
|
||||
}
|
||||
|
||||
func (p ProxyProtocol) String() string {
|
||||
return protoMap[p]
|
||||
}
|
||||
|
|
|
@ -202,6 +202,8 @@ func TestProx5(t *testing.T) {
|
|||
}
|
||||
|
||||
var successCount int64 = 0
|
||||
var fin = &atomic.Bool{}
|
||||
fin.Store(false)
|
||||
|
||||
makeReq := func() {
|
||||
select {
|
||||
|
@ -211,7 +213,7 @@ func TestProx5(t *testing.T) {
|
|||
|
||||
}
|
||||
resp, err := p5.GetHTTPClient().Get("http://127.0.0.1:8055")
|
||||
if err != nil && !errors.Is(err, ErrNoProxies) && !errors.Is(err, net.ErrClosed) {
|
||||
if !fin.Load() && err != nil && !errors.Is(err, ErrNoProxies) && !errors.Is(err, net.ErrClosed) {
|
||||
t.Error("[FAIL] " + err.Error())
|
||||
}
|
||||
if err != nil && errors.Is(err, ErrNoProxies) {
|
||||
|
@ -261,6 +263,7 @@ testLoop:
|
|||
if err := p5.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
fin.Store(true)
|
||||
// let the proxy engine close gracefully
|
||||
time.Sleep(time.Second * 5)
|
||||
}
|
||||
|
|
3
proxy.go
3
proxy.go
|
@ -26,6 +26,9 @@ const (
|
|||
type Proxy struct {
|
||||
// Endpoint is the address:port of the proxy that we connect to
|
||||
Endpoint string
|
||||
|
||||
sshDialer *SSHDialer
|
||||
|
||||
// ProxiedIP is the address that we end up having when making proxied requests through this proxy
|
||||
// TODO: parse this and store as flat int type
|
||||
ProxiedIP string
|
||||
|
|
|
@ -0,0 +1,190 @@
|
|||
package prox5
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/crypto/ssh/agent"
|
||||
)
|
||||
|
||||
type SSHDialer struct {
|
||||
host string
|
||||
clientConfig *ssh.ClientConfig
|
||||
clientConn *ssh.Client
|
||||
timeout time.Duration
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func ioClose(closer io.Closer) {
|
||||
_ = closer.Close() // we don't care about this error. consider it "handled"
|
||||
}
|
||||
|
||||
func getSignersFromSocket(uri string) (signers []ssh.Signer, err error) {
|
||||
if strings.Contains(uri, "://") {
|
||||
if uriSplit := strings.Split(uri, "://"); len(uriSplit) == 2 {
|
||||
uri = uriSplit[1]
|
||||
}
|
||||
}
|
||||
var conn net.Conn
|
||||
if conn, err = net.Dial("unix", uri); err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to ssh-agent: %w", err)
|
||||
}
|
||||
defer ioClose(conn)
|
||||
sshAgent := agent.NewClient(conn)
|
||||
if signers, err = sshAgent.Signers(); err != nil {
|
||||
return nil, fmt.Errorf("failed to get signers from ssh-agent socket: %w", err)
|
||||
}
|
||||
if len(signers) == 0 {
|
||||
return nil, errors.New("no signers provided by ssh-agent socket")
|
||||
}
|
||||
return signers, nil
|
||||
}
|
||||
|
||||
func NewPasswordSSHDialer(endpoint string, user, pass string) *SSHDialer {
|
||||
clientConfig := &ssh.ClientConfig{
|
||||
User: user,
|
||||
Auth: []ssh.AuthMethod{ssh.Password(pass)},
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
}
|
||||
clientConfig.SetDefaults()
|
||||
return &SSHDialer{
|
||||
clientConfig: clientConfig,
|
||||
host: endpoint,
|
||||
}
|
||||
}
|
||||
|
||||
func NewAgentSSHDialer(endpoint string, user string, signers ...any) (*SSHDialer, error) {
|
||||
agentURI := os.Getenv("SSH_AUTH_SOCK")
|
||||
if len(signers) == 0 && (agentURI == "" || runtime.GOOS == "windows") {
|
||||
return nil, errors.New("no signers provided and no SSH_AUTH_SOCK available")
|
||||
}
|
||||
|
||||
var sshSigners []ssh.Signer
|
||||
|
||||
signersProvided := false
|
||||
|
||||
for i, signer := range signers {
|
||||
switch castedSigner := signer.(type) {
|
||||
case ssh.Signer:
|
||||
if i == 0 {
|
||||
signersProvided = true
|
||||
}
|
||||
if !signersProvided {
|
||||
return nil, errors.New("multiple signers provided but they aren't all ssh.Signer")
|
||||
}
|
||||
sshSigners = append(sshSigners, castedSigner)
|
||||
case url.URL:
|
||||
if signersProvided {
|
||||
return nil, errors.New("multiple signers provided but they aren't all ssh.Signer")
|
||||
}
|
||||
agentURI = castedSigner.String()
|
||||
case string:
|
||||
if signersProvided {
|
||||
return nil, errors.New("multiple signers provided but they aren't all ssh.Signer")
|
||||
}
|
||||
agentURI = castedSigner
|
||||
}
|
||||
}
|
||||
|
||||
if !signersProvided {
|
||||
var err error
|
||||
sshSigners, err = getSignersFromSocket(agentURI)
|
||||
if err != nil {
|
||||
return nil, err // these are wrapped
|
||||
}
|
||||
}
|
||||
|
||||
clientConfig := &ssh.ClientConfig{
|
||||
User: user,
|
||||
Auth: []ssh.AuthMethod{ssh.PublicKeys(sshSigners...)},
|
||||
}
|
||||
|
||||
clientConfig.SetDefaults()
|
||||
|
||||
clientConfig.HostKeyCallback = ssh.InsecureIgnoreHostKey()
|
||||
|
||||
return &SSHDialer{
|
||||
clientConfig: clientConfig,
|
||||
host: endpoint,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (sshd *SSHDialer) WithHostKeyVerification(callback ssh.HostKeyCallback) *SSHDialer {
|
||||
sshd.clientConfig.HostKeyCallback = callback
|
||||
return sshd
|
||||
}
|
||||
|
||||
func (sshd *SSHDialer) WithTimeout(timeout time.Duration) *SSHDialer {
|
||||
sshd.timeout = timeout
|
||||
return sshd
|
||||
}
|
||||
|
||||
func (sshd *SSHDialer) Close() error {
|
||||
sshd.mu.Lock()
|
||||
defer sshd.mu.Unlock()
|
||||
if sshd.clientConn == nil {
|
||||
return nil
|
||||
}
|
||||
err := sshd.clientConn.Close()
|
||||
sshd.clientConn = nil
|
||||
return err
|
||||
}
|
||||
|
||||
type dialRes struct {
|
||||
conn net.Conn
|
||||
err error
|
||||
}
|
||||
|
||||
func (sshd *SSHDialer) dial(resChan chan dialRes, network, addr string) {
|
||||
sshd.mu.RLock()
|
||||
if sshd.clientConn == nil {
|
||||
sshd.mu.RUnlock()
|
||||
sshd.mu.Lock()
|
||||
var err error
|
||||
sshd.clientConn, err = ssh.Dial("tcp", sshd.host, sshd.clientConfig)
|
||||
if err != nil {
|
||||
if sshd.clientConn != nil {
|
||||
ioClose(sshd.clientConn)
|
||||
}
|
||||
sshd.clientConn = nil
|
||||
sshd.mu.Unlock()
|
||||
resChan <- dialRes{nil, err}
|
||||
return
|
||||
}
|
||||
sshd.mu.Unlock()
|
||||
sshd.mu.RLock()
|
||||
}
|
||||
sshd.mu.RUnlock()
|
||||
c, e := sshd.clientConn.Dial(network, addr)
|
||||
resChan <- dialRes{c, e}
|
||||
}
|
||||
|
||||
func (sshd *SSHDialer) DialCtx(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
resChan := make(chan dialRes)
|
||||
go sshd.dial(resChan, network, addr)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("context cancelled: %w", ctx.Err())
|
||||
case res := <-resChan:
|
||||
return res.conn, res.err
|
||||
}
|
||||
}
|
||||
|
||||
func (sshd *SSHDialer) Dial(network, addr string) (net.Conn, error) {
|
||||
if sshd.timeout == 0 {
|
||||
return sshd.DialCtx(context.Background(), network, addr)
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), sshd.timeout)
|
||||
defer cancel()
|
||||
return sshd.DialCtx(ctx, network, addr)
|
||||
}
|
|
@ -0,0 +1,316 @@
|
|||
package prox5
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const (
|
||||
testUser = "yeetersonmcgee"
|
||||
testPass = "yeetinemallday"
|
||||
)
|
||||
|
||||
type testSSHServer struct {
|
||||
listener net.Listener
|
||||
config *ssh.ServerConfig
|
||||
t *testing.T
|
||||
errChan chan error
|
||||
closed *atomic.Bool
|
||||
testHTTP *atomic.Value
|
||||
}
|
||||
|
||||
func newTestSSHServer(t *testing.T) *testSSHServer {
|
||||
s := &testSSHServer{
|
||||
t: t,
|
||||
errChan: make(chan error, 1), // Buffered channel to handle non-blocking error reporting
|
||||
closed: &atomic.Bool{},
|
||||
testHTTP: &atomic.Value{},
|
||||
}
|
||||
key := genKey()
|
||||
signer := signerFromKey(key)
|
||||
|
||||
s.config = &ssh.ServerConfig{
|
||||
PasswordCallback: func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
||||
if conn.User() == testUser && string(password) == testPass {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, ssh.ErrNoAuth
|
||||
},
|
||||
}
|
||||
|
||||
s.config.AddHostKey(signer)
|
||||
|
||||
s.closed.Store(false)
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func genKey() *rsa.PrivateKey {
|
||||
k, _ := rsa.GenerateKey(rand.Reader, 2048)
|
||||
return k
|
||||
}
|
||||
|
||||
func signerFromKey(key *rsa.PrivateKey) ssh.Signer {
|
||||
signer, _ := ssh.NewSignerFromKey(key)
|
||||
return signer
|
||||
}
|
||||
|
||||
func (s *testSSHServer) start() string {
|
||||
|
||||
var err error
|
||||
if s.listener, err = net.Listen("tcp", "127.0.0.1:0"); err != nil {
|
||||
s.t.Fatal(err)
|
||||
}
|
||||
|
||||
go s.handler()
|
||||
|
||||
return s.listener.Addr().String()
|
||||
}
|
||||
|
||||
func (s *testSSHServer) handler() {
|
||||
for {
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil && !strings.Contains(err.Error(), "use of closed") {
|
||||
s.errChan <- err // Send the error to the channel
|
||||
return
|
||||
}
|
||||
go s.handleConnection(conn)
|
||||
}
|
||||
}
|
||||
|
||||
type tcpIPRequest struct {
|
||||
HostToConnect string
|
||||
PortToConnect uint32
|
||||
OriginatorIPAddress string
|
||||
OriginatorPort uint32
|
||||
}
|
||||
|
||||
func (s *testSSHServer) testHTTPServer() string {
|
||||
if s.testHTTP.Load() != nil {
|
||||
return s.testHTTP.Load().(string)
|
||||
}
|
||||
serv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = w.Write([]byte("<b>yeet</b>"))
|
||||
}))
|
||||
s.testHTTP.Store(serv.URL)
|
||||
s.t.Logf("test HTTP server started on %s", serv.URL)
|
||||
s.t.Cleanup(serv.Close)
|
||||
return serv.URL
|
||||
}
|
||||
|
||||
func (s *testSSHServer) handleDirectTCPIP(req *ssh.Request, channel io.ReadWriteCloser, reply bool) {
|
||||
// direct-tcpip request data structure as per RFC 4254, section 7.2
|
||||
|
||||
data := &tcpIPRequest{}
|
||||
|
||||
if req == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := ssh.Unmarshal(req.Payload, data); err != nil {
|
||||
s.t.Errorf("Failed to unmarshal direct-tcpip request: %v", err)
|
||||
channel.Close()
|
||||
return
|
||||
}
|
||||
|
||||
s.t.Logf("direct-tcpip request: %+v", data)
|
||||
|
||||
srvURL := s.testHTTPServer()
|
||||
|
||||
s.t.Logf("faking connection to remote host: %s:%d", data.HostToConnect, data.PortToConnect)
|
||||
|
||||
prt, _ := strconv.Atoi(strings.Split(strings.TrimPrefix(srvURL, "http://"), ":")[1])
|
||||
data.HostToConnect = strings.TrimSuffix(strings.Split(srvURL, "://")[1], fmt.Sprintf(":%d", prt))
|
||||
data.PortToConnect = uint32(prt)
|
||||
|
||||
remoteConn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", data.HostToConnect, data.PortToConnect))
|
||||
if err != nil {
|
||||
s.t.Logf("Failed to connect to remote host: %s:%d, error: %v", data.HostToConnect, data.PortToConnect, err)
|
||||
if reply {
|
||||
_ = req.Reply(false, nil)
|
||||
}
|
||||
_ = channel.Close()
|
||||
return
|
||||
}
|
||||
|
||||
s.t.Logf("connected to remote host: %s", remoteConn.RemoteAddr())
|
||||
|
||||
if reply {
|
||||
s.t.Logf("replying to request")
|
||||
if err = req.Reply(true, nil); err != nil {
|
||||
s.t.Errorf("Failed to reply to request: %v", err)
|
||||
_ = channel.Close()
|
||||
_ = remoteConn.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
_ = channel.Close()
|
||||
_ = remoteConn.Close()
|
||||
}()
|
||||
if _, err = io.Copy(channel, remoteConn); err != nil && !strings.Contains(err.Error(), "use of closed") {
|
||||
s.t.Errorf("failed to copy from remote to channel: %v", err)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
defer func() {
|
||||
_ = channel.Close()
|
||||
_ = remoteConn.Close()
|
||||
}()
|
||||
if _, err = io.Copy(remoteConn, channel); err != nil && !strings.Contains(err.Error(), "use of closed") {
|
||||
s.t.Errorf("failed to copy from channel to remote: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *testSSHServer) handleConnection(conn net.Conn) {
|
||||
if conn == nil {
|
||||
return
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer func() {
|
||||
cancel()
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
}
|
||||
}()
|
||||
clientConn, channels, requests, err := ssh.NewServerConn(conn, s.config)
|
||||
if err != nil && !strings.Contains(err.Error(), "use of closed") {
|
||||
s.t.Logf("failed to establish server connection: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if clientConn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case req := <-requests:
|
||||
s.t.Log(spew.Sdump(req))
|
||||
case newChannel := <-channels:
|
||||
s.t.Logf("new channel: %s", newChannel.ChannelType())
|
||||
switch newChannel.ChannelType() {
|
||||
case "direct-tcpip":
|
||||
tcpIPChan, tcpIPReqs, chanErr := newChannel.Accept()
|
||||
if chanErr != nil {
|
||||
s.t.Errorf("failed to accept direct-tcpip channel: %v", chanErr)
|
||||
return
|
||||
}
|
||||
s.t.Logf("accepted direct-tcpip channel")
|
||||
if len(newChannel.ExtraData()) > 0 {
|
||||
go s.handleDirectTCPIP(&ssh.Request{
|
||||
Type: "direct-tcpip",
|
||||
WantReply: true,
|
||||
Payload: newChannel.ExtraData(),
|
||||
}, tcpIPChan, false)
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case req := <-tcpIPReqs:
|
||||
go s.handleDirectTCPIP(req, tcpIPChan, false)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
default:
|
||||
s.t.Errorf("unhandled channel type: %s", newChannel.ChannelType())
|
||||
if err = newChannel.Reject(ssh.UnknownChannelType, "unhandled channel type"); err != nil {
|
||||
s.t.Errorf("failed to reject channel: %v", err)
|
||||
}
|
||||
_ = clientConn.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
s.t.Logf("new connection from %s", clientConn.RemoteAddr())
|
||||
|
||||
if err := clientConn.Wait(); err != nil {
|
||||
s.t.Logf("failed to wait for client connection: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *testSSHServer) stop() {
|
||||
if err := s.listener.Close(); err != nil {
|
||||
s.t.Errorf("failed to close listener: %v", err)
|
||||
}
|
||||
select {
|
||||
case err := <-s.errChan:
|
||||
s.t.Logf("server stopped with error: %v", err)
|
||||
default:
|
||||
}
|
||||
close(s.errChan)
|
||||
}
|
||||
|
||||
func TestSSHDialer(t *testing.T) {
|
||||
t.Run("TestSuccessfulConnection", func(t *testing.T) {
|
||||
server := newTestSSHServer(t)
|
||||
serverAddr := server.start()
|
||||
defer server.stop()
|
||||
|
||||
dialer := NewPasswordSSHDialer(serverAddr, testUser, testPass)
|
||||
conn, err := dialer.Dial("tcp", "google.com:80")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to establish connection: %v", err)
|
||||
}
|
||||
|
||||
var n int
|
||||
if n, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
|
||||
t.Fatalf("failed to write to connection: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("[client] wrote %d bytes to connection", n)
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
n, err = conn.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read from connection: %v", err)
|
||||
}
|
||||
t.Logf("[client] read %d bytes from connection", n)
|
||||
t.Log(string(buf[:n]))
|
||||
if !strings.Contains(string(buf[:n]), "<b>yeet</b>") {
|
||||
t.Fatal("expected response to contain '<b>yeet</b>'")
|
||||
}
|
||||
if err = conn.Close(); err != nil {
|
||||
t.Fatalf("failed to close connection: %v", err)
|
||||
}
|
||||
})
|
||||
t.Run("TestFailedAuthentication", func(t *testing.T) {
|
||||
server := newTestSSHServer(t)
|
||||
serverAddr := server.start()
|
||||
defer server.stop()
|
||||
|
||||
dialer := NewPasswordSSHDialer(serverAddr, testUser, "yeet5")
|
||||
conn, err := dialer.Dial("tcp", "google.com:80")
|
||||
if err == nil {
|
||||
if conn != nil {
|
||||
if err = conn.Close(); err != nil {
|
||||
t.Errorf("failed to close connection: %v", err)
|
||||
}
|
||||
}
|
||||
t.Fatalf("expected authentication error, got none")
|
||||
}
|
||||
})
|
||||
}
|
|
@ -159,9 +159,14 @@ func (p5 *ProxyEngine) announceValidating(sock *Proxy, presplit string) {
|
|||
return
|
||||
}
|
||||
s := strs.Get()
|
||||
s.MustWriteString("validating ")
|
||||
s.MustWriteString(sock.GetProto().String())
|
||||
s.MustWriteString("://")
|
||||
|
||||
if knownProto := sock.GetProto(); knownProto != ProtoNull {
|
||||
s.MustWriteString("re-validating known proxy: ")
|
||||
s.MustWriteString(knownProto.String())
|
||||
s.MustWriteString("://")
|
||||
} else {
|
||||
s.MustWriteString("validating unknown proxy: ")
|
||||
}
|
||||
s.MustWriteString(presplit)
|
||||
p5.dbgPrint(s)
|
||||
|
||||
|
@ -175,7 +180,7 @@ func (p5 *ProxyEngine) singleProxyCheck(sock *Proxy, protocol ProxyProtocol) err
|
|||
endpoint = split[1]
|
||||
}
|
||||
|
||||
// p5.announceValidating(sock, endpoint)
|
||||
p5.announceValidating(sock, endpoint)
|
||||
|
||||
conn, err := net.DialTimeout("tcp", endpoint, p5.GetValidationTimeout())
|
||||
if err != nil {
|
||||
|
@ -230,7 +235,7 @@ func (sock *Proxy) validate() {
|
|||
// TODO: consider giving the option for verbose logging of this stuff?
|
||||
|
||||
switch {
|
||||
case sock.timesValidated == 0, sock.protocol.Get() == ProtoNull:
|
||||
case sock.timesValidated == 0 && sock.protocol.Get() == ProtoNull:
|
||||
// try to use the proxy with all 3 SOCKS versions
|
||||
for tryProto := range protoMap {
|
||||
if tryProto == ProtoNull {
|
||||
|
|
Loading…
Reference in New Issue