Finally do the thing we set out to do in the beginning >_>

This commit is contained in:
kayos@tcp.direct 2022-11-12 22:40:13 -08:00
parent c58518b71a
commit 7f2e315ca8
Signed by: kayos
GPG Key ID: 4B841471B4BEE979
6 changed files with 149 additions and 85 deletions

View File

@ -4,10 +4,10 @@ import (
"context"
"errors"
"fmt"
"net/http"
"time"
"github.com/hashicorp/go-multierror"
"github.com/valyala/fasthttp"
http "github.com/valyala/fasthttp"
)
type MyIPDetails struct {
@ -15,12 +15,12 @@ type MyIPDetails struct {
V6 *IPDetails `json:"ipv6,omitempty"`
}
func CheckIP4(ctx context.Context, h *http.Client) (details *IPDetails, err error) {
return checkIP(ctx, false)
func CheckIP4() (details *IPDetails, err error) {
return checkIP(false)
}
func CheckIP6(ctx context.Context, h *http.Client) (details *IPDetails, err error) {
return checkIP(ctx, true)
func CheckIP6() (details *IPDetails, err error) {
return checkIP(true)
}
func CheckIP(ctx context.Context) (*MyIPDetails, error) {
@ -35,7 +35,7 @@ func CheckIP(ctx context.Context) (*MyIPDetails, error) {
check := func(resChan chan result, ipv6 bool) error {
var err error
var r = result{ipv6: ipv6}
r.details, err = checkIP(ctx, r.ipv6)
r.details, err = checkIP(r.ipv6)
if err != nil {
if r.ipv6 {
err = fmt.Errorf("error checking ipv6: %w", err)
@ -93,7 +93,7 @@ func CheckIP(ctx context.Context) (*MyIPDetails, error) {
return myip, err
}
func checkIP(ctx context.Context, ipv6 bool) (details *IPDetails, err error) {
func checkIP(ipv6 bool) (details *IPDetails, err error) {
var target string
switch ipv6 {
case true:
@ -101,19 +101,19 @@ func checkIP(ctx context.Context, ipv6 bool) (details *IPDetails, err error) {
default:
target = EndpointCheck4
}
req := fasthttp.AcquireRequest()
res := fasthttp.AcquireResponse()
req := http.AcquireRequest()
res := http.AcquireResponse()
defer func() {
fasthttp.ReleaseRequest(req)
fasthttp.ReleaseResponse(res)
http.ReleaseRequest(req)
http.ReleaseResponse(res)
}()
req.SetRequestURI(target)
req.Header.SetMethod("GET")
req.Header.SetMethod(http.MethodGet)
req.Header.SetUserAgent(useragent)
client := fasthttp.Client{}
client := http.Client{}
client.DialDualStack = true
err = client.Do(req, res)
err = client.DoTimeout(req, res, 15*time.Second)
if err != nil {
return
}
@ -130,9 +130,9 @@ func checkIP(ctx context.Context, ipv6 bool) (details *IPDetails, err error) {
// Returns the mullvad server you are connected to if any, and any error that occured
//
//goland:noinspection GoNilness
func (r *Checker) AmIMullvad(ctx context.Context) (MullvadServer, error) {
func (c *Checker) AmIMullvad(ctx context.Context) (MullvadServer, error) {
me, err := CheckIP(ctx)
if me == nil || me.V4 == nil && me.V6 == nil {
if me == nil || me.V4 == nil && me.V6 == nil || err != nil {
return MullvadServer{}, err
}
if me.V4 != nil && !me.V4.MullvadExitIP {
@ -142,7 +142,7 @@ func (r *Checker) AmIMullvad(ctx context.Context) (MullvadServer, error) {
return MullvadServer{}, err
}
err = r.Update()
err = c.Update()
if err != nil {
return MullvadServer{}, err
}
@ -150,14 +150,14 @@ func (r *Checker) AmIMullvad(ctx context.Context) (MullvadServer, error) {
isMullvad := false
if me.V4 != nil && me.V4.MullvadExitIP {
isMullvad = true
if r.Has(me.V4.MullvadExitIPHostname) {
return r.Get(me.V4.MullvadExitIPHostname), nil
if c.Has(me.V4.MullvadExitIPHostname) {
return c.Get(me.V4.MullvadExitIPHostname), nil
}
}
if me.V6 != nil && me.V6.MullvadExitIP {
isMullvad = true
if r.Has(me.V6.MullvadExitIPHostname) {
return r.Get(me.V6.MullvadExitIPHostname), nil
if c.Has(me.V6.MullvadExitIPHostname) {
return c.Get(me.V6.MullvadExitIPHostname), nil
}
}
if isMullvad {

View File

@ -2,7 +2,6 @@ package mullsox
import (
"context"
"net/http"
"testing"
"time"
@ -10,8 +9,7 @@ import (
)
func TestCheckIP4(t *testing.T) {
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(15*time.Second))
v4, err := CheckIP4(ctx, http.DefaultClient)
v4, err := CheckIP4()
if err != nil {
t.Fatalf("%s", err.Error())
}
@ -20,12 +18,10 @@ func TestCheckIP4(t *testing.T) {
t.Fatalf("%s", err4j.Error())
}
t.Logf(string(v4j))
cancel()
}
func TestCheckIP6(t *testing.T) {
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(15*time.Second))
v6, err := CheckIP6(ctx, http.DefaultClient)
v6, err := CheckIP6()
if err != nil {
t.Fatalf("%s", err.Error())
}
@ -34,7 +30,6 @@ func TestCheckIP6(t *testing.T) {
t.Fatalf("%s", err6j.Error())
}
t.Logf(string(v6j))
cancel()
}
func TestCheckIPConcurrent(t *testing.T) {
@ -70,7 +65,7 @@ func TestCheckIPConcurrent(t *testing.T) {
}
func TestAmIMullvad(t *testing.T) {
servers := NewRelays()
servers := NewChecker()
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(15*time.Second))
am, err := servers.AmIMullvad(ctx)
if err != nil {

View File

@ -1,11 +1,10 @@
package mullsox
import (
"net/http"
"sync"
jsoniter "github.com/json-iterator/go"
"github.com/valyala/fasthttp"
http "github.com/valyala/fasthttp"
)
var json = jsoniter.ConfigCompatibleWithStandardLibrary
@ -15,13 +14,13 @@ func (mvs MullvadServer) String() string {
}
type Checker struct {
m map[string]MullvadServer
size int
url string
m map[string]MullvadServer
cachedSize int
url string
*sync.RWMutex
}
func NewRelays() *Checker {
func NewChecker() *Checker {
r := &Checker{
m: make(map[string]MullvadServer),
RWMutex: &sync.RWMutex{},
@ -30,88 +29,88 @@ func NewRelays() *Checker {
return r
}
func (r *Checker) Slice() []MullvadServer {
r.RLock()
defer r.RUnlock()
func (c *Checker) Slice() []MullvadServer {
c.RLock()
defer c.RUnlock()
var servers []MullvadServer
for _, server := range r.m {
for _, server := range c.m {
servers = append(servers, server)
}
return servers
}
func (r *Checker) Has(hostname string) bool {
r.RLock()
_, ok := r.m[hostname]
r.RUnlock()
func (c *Checker) Has(hostname string) bool {
c.RLock()
_, ok := c.m[hostname]
c.RUnlock()
return ok
}
func (r *Checker) Add(server MullvadServer) {
r.Lock()
r.m[server.Hostname] = server
r.Unlock()
func (c *Checker) Add(server MullvadServer) {
c.Lock()
c.m[server.Hostname] = server
c.Unlock()
}
func (r *Checker) Get(hostname string) MullvadServer {
r.RLock()
defer r.RUnlock()
return r.m[hostname]
func (c *Checker) Get(hostname string) MullvadServer {
c.RLock()
defer c.RUnlock()
return c.m[hostname]
}
func (r *Checker) clear() {
for k := range r.m {
delete(r.m, k)
func (c *Checker) clear() {
for k := range c.m {
delete(c.m, k)
}
}
func getContentSize(url string) int {
req := fasthttp.AcquireRequest()
res := fasthttp.AcquireResponse()
req := http.AcquireRequest()
res := http.AcquireResponse()
defer func() {
fasthttp.ReleaseRequest(req)
fasthttp.ReleaseResponse(res)
http.ReleaseRequest(req)
http.ReleaseResponse(res)
}()
req.Header.SetUserAgent(useragent)
req.Header.SetMethod(http.MethodHead)
req.SetRequestURI(url)
if err := fasthttp.Do(req, res); err != nil {
if err := http.Do(req, res); err != nil {
return -1
}
return res.Header.ContentLength()
}
func (r *Checker) Update() error {
func (c *Checker) Update() error {
var serverSlice []MullvadServer
if r.size > 0 {
current := getContentSize(r.url)
if current == r.size {
if c.cachedSize > 0 {
latestSize := getContentSize(c.url)
if latestSize == c.cachedSize {
return nil
}
}
req := fasthttp.AcquireRequest()
res := fasthttp.AcquireResponse()
req := http.AcquireRequest()
res := http.AcquireResponse()
defer func() {
fasthttp.ReleaseRequest(req)
fasthttp.ReleaseResponse(res)
http.ReleaseRequest(req)
http.ReleaseResponse(res)
}()
req.Header.SetUserAgent(useragent)
req.Header.SetContentType("application/json")
req.Header.SetMethod(http.MethodGet)
req.SetRequestURI(r.url)
if err := fasthttp.Do(req, res); err != nil {
req.SetRequestURI(c.url)
if err := http.Do(req, res); err != nil {
return err
}
if err := json.Unmarshal(res.Body(), &serverSlice); err != nil {
return err
}
r.Lock()
r.clear()
c.Lock()
c.clear()
for _, server := range serverSlice {
r.m[server.Hostname] = server
c.m[server.Hostname] = server
}
r.size = res.Header.ContentLength()
r.Unlock()
c.cachedSize = res.Header.ContentLength()
c.Unlock()
return nil
}

View File

@ -2,12 +2,10 @@ package mullsox
import (
"testing"
"github.com/davecgh/go-spew/spew"
)
func TestGetMullvadServers(t *testing.T) {
servers := NewRelays()
servers := NewChecker()
update := func() {
err := servers.Update()
@ -19,7 +17,7 @@ func TestGetMullvadServers(t *testing.T) {
t.Run("GetMullvadServers", func(t *testing.T) {
update()
t.Logf(spew.Sdump(servers.Slice()))
// t.Logf(spew.Sdump(servers.Slice()))
})
var last int
var lastSlice []MullvadServer
@ -31,14 +29,14 @@ func TestGetMullvadServers(t *testing.T) {
update()
update()
update()
last = servers.size
last = servers.cachedSize
lastSlice = servers.Slice()
})
t.Run("GetMullvadServersChanged", func(t *testing.T) {
servers.url = "https://api.mullvad.net/www/relays/openvpn/"
update()
if last == servers.size {
t.Fatalf("expected %d to not equal %d", last, servers.size)
if last == servers.cachedSize {
t.Fatalf("expected %d to not equal %d", last, servers.cachedSize)
}
if len(servers.Slice()) == len(lastSlice) {
t.Fatalf("expected %d to not equal %d", len(lastSlice), len(servers.Slice()))

57
sox.go
View File

@ -1,7 +1,58 @@
package mullsox
import "net/netip"
import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"time"
)
func GetSOCKS() (netip.AddrPort, error) {
return netip.AddrPort{}, nil
func persistentResolver(hostname string) []netip.Addr {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
var ips []netip.Addr
for n := 0; n < 5; n++ {
var err error
var res []netip.Addr
go func() {
res, err = net.DefaultResolver.LookupNetIP(ctx, "ip", hostname)
if err == nil && len(res) > 0 {
ips = res
cancel()
}
}()
time.Sleep(1 * time.Second)
}
<-ctx.Done()
return ips
}
func (c *Checker) GetSOCKS() (sox []netip.AddrPort, err error) {
if err = c.Update(); err != nil {
return
}
wg := &sync.WaitGroup{}
for _, serv := range c.m {
wg.Add(1)
go func(endpoint *MullvadServer) {
defer wg.Done()
ips := persistentResolver(endpoint.SocksName)
for _, ip := range ips {
port := uint16(endpoint.SocksPort)
if port == 0 {
port = 1080
}
ap := netip.AddrPortFrom(ip, port)
if ap.IsValid() && ap.Port() > 0 {
sox = append(sox, ap)
continue
}
err = fmt.Errorf("invalid address/port combo: %s", ap.String())
}
}(&serv)
}
wg.Wait()
return
}

21
sox_test.go Normal file
View File

@ -0,0 +1,21 @@
package mullsox
import (
"testing"
)
func TestChecker_GetSOCKS(t *testing.T) {
c := NewChecker()
if err := c.Update(); err != nil {
t.Fatalf("%s", err.Error())
}
gotSox, err := c.GetSOCKS()
if err != nil {
t.Error(err)
}
if len(gotSox) == 0 {
t.Error("expected non-zero length")
}
t.Logf("got %d socks", len(gotSox))
t.Logf("%v", gotSox)
}