From c58518b71a642b83252590010f987040b3cfec53 Mon Sep 17 00:00:00 2001 From: "kayos@tcp.direct" Date: Fri, 28 Oct 2022 05:54:12 -0700 Subject: [PATCH] Cached relay map --- api.go | 2 ++ check.go | 58 +++++++++++++++++++++++--------------------- check_test.go | 5 ++-- relays.go | 66 ++++++++++++++++++++++++++++++++++++++------------ relays_test.go | 45 ++++++++++++++++++++++++++++++---- sox.go | 7 ++++++ 6 files changed, 133 insertions(+), 50 deletions(-) create mode 100644 sox.go diff --git a/api.go b/api.go index 71409ab..b3d18c6 100644 --- a/api.go +++ b/api.go @@ -1,5 +1,7 @@ package mullsox +const useragent = "mullsox/0.0.1" + const ( baseDomain = "mullvad.net" baseEndpoint = "am.i." + baseDomain diff --git a/check.go b/check.go index 3412d0d..ef728d1 100644 --- a/check.go +++ b/check.go @@ -4,10 +4,10 @@ import ( "context" "errors" "fmt" - "io" "net/http" "github.com/hashicorp/go-multierror" + "github.com/valyala/fasthttp" ) type MyIPDetails struct { @@ -16,14 +16,14 @@ type MyIPDetails struct { } func CheckIP4(ctx context.Context, h *http.Client) (details *IPDetails, err error) { - return checkIP(ctx, h, false) + return checkIP(ctx, false) } func CheckIP6(ctx context.Context, h *http.Client) (details *IPDetails, err error) { - return checkIP(ctx, h, true) + return checkIP(ctx, true) } -func CheckIP(ctx context.Context, h *http.Client) (*MyIPDetails, error) { +func CheckIP(ctx context.Context) (*MyIPDetails, error) { type result struct { details *IPDetails ipv6 bool @@ -35,7 +35,7 @@ func CheckIP(ctx context.Context, h *http.Client) (*MyIPDetails, error) { check := func(resChan chan result, ipv6 bool) error { var err error var r = result{ipv6: ipv6} - r.details, err = checkIP(ctx, h, r.ipv6) + r.details, err = checkIP(ctx, r.ipv6) if err != nil { if r.ipv6 { err = fmt.Errorf("error checking ipv6: %w", err) @@ -93,32 +93,36 @@ func CheckIP(ctx context.Context, h *http.Client) (*MyIPDetails, error) { return myip, err } -func checkIP(ctx context.Context, h *http.Client, ipv6 bool) (details *IPDetails, err error) { - var ( - resp *http.Response - cytes []byte - target string - ) +func checkIP(ctx context.Context, ipv6 bool) (details *IPDetails, err error) { + var target string switch ipv6 { case true: target = EndpointCheck6 default: target = EndpointCheck4 } - req, _ := http.NewRequestWithContext(ctx, "GET", target, nil) - resp, err = h.Do(req) + req := fasthttp.AcquireRequest() + res := fasthttp.AcquireResponse() + defer func() { + fasthttp.ReleaseRequest(req) + fasthttp.ReleaseResponse(res) + }() + req.SetRequestURI(target) + req.Header.SetMethod("GET") + req.Header.SetUserAgent(useragent) + client := fasthttp.Client{} + client.DialDualStack = true + + err = client.Do(req, res) if err != nil { return } - if resp.StatusCode != 200 { - err = fmt.Errorf("bad status code from %s : %s", target, resp.Status) + if res.StatusCode() != http.StatusOK { + err = fmt.Errorf("got status code %d", res.StatusCode()) return } - cytes, err = io.ReadAll(resp.Body) - if err != nil { - return - } - err = json.Unmarshal(cytes, &details) + + err = json.Unmarshal(res.Body(), &details) return } @@ -126,8 +130,8 @@ func checkIP(ctx context.Context, h *http.Client, ipv6 bool) (details *IPDetails // Returns the mullvad server you are connected to if any, and any error that occured // //goland:noinspection GoNilness -func AmIMullvad(ctx context.Context, client *http.Client) (MullvadServer, error) { - me, err := CheckIP(ctx, client) +func (r *Checker) AmIMullvad(ctx context.Context) (MullvadServer, error) { + me, err := CheckIP(ctx) if me == nil || me.V4 == nil && me.V6 == nil { return MullvadServer{}, err } @@ -138,7 +142,7 @@ func AmIMullvad(ctx context.Context, client *http.Client) (MullvadServer, error) return MullvadServer{}, err } - relays, err := GetMullvadServers() + err = r.Update() if err != nil { return MullvadServer{}, err } @@ -146,14 +150,14 @@ func AmIMullvad(ctx context.Context, client *http.Client) (MullvadServer, error) isMullvad := false if me.V4 != nil && me.V4.MullvadExitIP { isMullvad = true - if relays.Has(me.V4.MullvadExitIPHostname) { - return relays.Get(me.V4.MullvadExitIPHostname), nil + if r.Has(me.V4.MullvadExitIPHostname) { + return r.Get(me.V4.MullvadExitIPHostname), nil } } if me.V6 != nil && me.V6.MullvadExitIP { isMullvad = true - if relays.Has(me.V6.MullvadExitIPHostname) { - return relays.Get(me.V6.MullvadExitIPHostname), nil + if r.Has(me.V6.MullvadExitIPHostname) { + return r.Get(me.V6.MullvadExitIPHostname), nil } } if isMullvad { diff --git a/check_test.go b/check_test.go index 4d933de..5498b62 100644 --- a/check_test.go +++ b/check_test.go @@ -39,7 +39,7 @@ func TestCheckIP6(t *testing.T) { func TestCheckIPConcurrent(t *testing.T) { ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(15*time.Second)) - me, err := CheckIP(ctx, http.DefaultClient) + me, err := CheckIP(ctx) if err != nil { t.Fatalf("%s", err.Error()) } @@ -70,8 +70,9 @@ func TestCheckIPConcurrent(t *testing.T) { } func TestAmIMullvad(t *testing.T) { + servers := NewRelays() ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(15*time.Second)) - am, err := AmIMullvad(ctx, http.DefaultClient) + am, err := servers.AmIMullvad(ctx) if err != nil { t.Fatalf("%s", err.Error()) } diff --git a/relays.go b/relays.go index a795735..1c85650 100644 --- a/relays.go +++ b/relays.go @@ -14,21 +14,23 @@ func (mvs MullvadServer) String() string { return mvs.Hostname } -type Relays struct { +type Checker struct { m map[string]MullvadServer size int + url string *sync.RWMutex } -func NewRelays() *Relays { - r := &Relays{ +func NewRelays() *Checker { + r := &Checker{ m: make(map[string]MullvadServer), RWMutex: &sync.RWMutex{}, + url: EndpointRelays, } return r } -func (r *Relays) Slice() []MullvadServer { +func (r *Checker) Slice() []MullvadServer { r.RLock() defer r.RUnlock() var servers []MullvadServer @@ -38,46 +40,78 @@ func (r *Relays) Slice() []MullvadServer { return servers } -func (r *Relays) Has(hostname string) bool { +func (r *Checker) Has(hostname string) bool { r.RLock() _, ok := r.m[hostname] r.RUnlock() return ok } -func (r *Relays) Add(server MullvadServer) { +func (r *Checker) Add(server MullvadServer) { r.Lock() r.m[server.Hostname] = server r.Unlock() } -func (r *Relays) Get(hostname string) MullvadServer { +func (r *Checker) Get(hostname string) MullvadServer { r.RLock() defer r.RUnlock() return r.m[hostname] } -func GetMullvadServers() (*Relays, error) { - var servers = NewRelays() - var serverSlice []MullvadServer +func (r *Checker) clear() { + for k := range r.m { + delete(r.m, k) + } +} + +func getContentSize(url string) int { req := fasthttp.AcquireRequest() res := fasthttp.AcquireResponse() defer func() { fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(res) }() - req.Header.SetUserAgent("mulls0x/v0.0.1") + req.Header.SetUserAgent(useragent) + req.Header.SetMethod(http.MethodHead) + req.SetRequestURI(url) + if err := fasthttp.Do(req, res); err != nil { + return -1 + } + return res.Header.ContentLength() +} + +func (r *Checker) Update() error { + var serverSlice []MullvadServer + if r.size > 0 { + current := getContentSize(r.url) + if current == r.size { + return nil + } + } + + req := fasthttp.AcquireRequest() + res := fasthttp.AcquireResponse() + defer func() { + fasthttp.ReleaseRequest(req) + fasthttp.ReleaseResponse(res) + }() + req.Header.SetUserAgent(useragent) req.Header.SetContentType("application/json") req.Header.SetMethod(http.MethodGet) - req.SetRequestURI(EndpointRelays) + req.SetRequestURI(r.url) if err := fasthttp.Do(req, res); err != nil { - return nil, err + return err } if err := json.Unmarshal(res.Body(), &serverSlice); err != nil { - return nil, err + return err } + r.Lock() + r.clear() for _, server := range serverSlice { - servers.Add(server) + r.m[server.Hostname] = server } - return servers, nil + r.size = res.Header.ContentLength() + r.Unlock() + return nil } diff --git a/relays_test.go b/relays_test.go index f243a63..88fab57 100644 --- a/relays_test.go +++ b/relays_test.go @@ -1,12 +1,47 @@ package mullsox -import "testing" +import ( + "testing" + + "github.com/davecgh/go-spew/spew" +) func TestGetMullvadServers(t *testing.T) { - servers, err := GetMullvadServers() - if err != nil { - t.Fatalf("%s", err.Error()) + servers := NewRelays() + + update := func() { + err := servers.Update() + if err != nil { + t.Fatalf("%s", err.Error()) + } + t.Logf("got %d servers", len(servers.Slice())) } - t.Logf("got %d servers", len(servers.Slice())) + t.Run("GetMullvadServers", func(t *testing.T) { + update() + t.Logf(spew.Sdump(servers.Slice())) + }) + var last int + var lastSlice []MullvadServer + t.Run("GetMullvadServersCached", func(t *testing.T) { + update() + update() + update() + update() + update() + update() + update() + last = servers.size + lastSlice = servers.Slice() + }) + t.Run("GetMullvadServersChanged", func(t *testing.T) { + servers.url = "https://api.mullvad.net/www/relays/openvpn/" + update() + if last == servers.size { + t.Fatalf("expected %d to not equal %d", last, servers.size) + } + if len(servers.Slice()) == len(lastSlice) { + t.Fatalf("expected %d to not equal %d", len(lastSlice), len(servers.Slice())) + } + }) } diff --git a/sox.go b/sox.go new file mode 100644 index 0000000..a78f21b --- /dev/null +++ b/sox.go @@ -0,0 +1,7 @@ +package mullsox + +import "net/netip" + +func GetSOCKS() (netip.AddrPort, error) { + return netip.AddrPort{}, nil +}