Cached relay map
This commit is contained in:
parent
17a530eb81
commit
c58518b71a
2
api.go
2
api.go
|
@ -1,5 +1,7 @@
|
|||
package mullsox
|
||||
|
||||
const useragent = "mullsox/0.0.1"
|
||||
|
||||
const (
|
||||
baseDomain = "mullvad.net"
|
||||
baseEndpoint = "am.i." + baseDomain
|
||||
|
|
58
check.go
58
check.go
|
@ -4,10 +4,10 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
type MyIPDetails struct {
|
||||
|
@ -16,14 +16,14 @@ type MyIPDetails struct {
|
|||
}
|
||||
|
||||
func CheckIP4(ctx context.Context, h *http.Client) (details *IPDetails, err error) {
|
||||
return checkIP(ctx, h, false)
|
||||
return checkIP(ctx, false)
|
||||
}
|
||||
|
||||
func CheckIP6(ctx context.Context, h *http.Client) (details *IPDetails, err error) {
|
||||
return checkIP(ctx, h, true)
|
||||
return checkIP(ctx, true)
|
||||
}
|
||||
|
||||
func CheckIP(ctx context.Context, h *http.Client) (*MyIPDetails, error) {
|
||||
func CheckIP(ctx context.Context) (*MyIPDetails, error) {
|
||||
type result struct {
|
||||
details *IPDetails
|
||||
ipv6 bool
|
||||
|
@ -35,7 +35,7 @@ func CheckIP(ctx context.Context, h *http.Client) (*MyIPDetails, error) {
|
|||
check := func(resChan chan result, ipv6 bool) error {
|
||||
var err error
|
||||
var r = result{ipv6: ipv6}
|
||||
r.details, err = checkIP(ctx, h, r.ipv6)
|
||||
r.details, err = checkIP(ctx, r.ipv6)
|
||||
if err != nil {
|
||||
if r.ipv6 {
|
||||
err = fmt.Errorf("error checking ipv6: %w", err)
|
||||
|
@ -93,32 +93,36 @@ func CheckIP(ctx context.Context, h *http.Client) (*MyIPDetails, error) {
|
|||
return myip, err
|
||||
}
|
||||
|
||||
func checkIP(ctx context.Context, h *http.Client, ipv6 bool) (details *IPDetails, err error) {
|
||||
var (
|
||||
resp *http.Response
|
||||
cytes []byte
|
||||
target string
|
||||
)
|
||||
func checkIP(ctx context.Context, ipv6 bool) (details *IPDetails, err error) {
|
||||
var target string
|
||||
switch ipv6 {
|
||||
case true:
|
||||
target = EndpointCheck6
|
||||
default:
|
||||
target = EndpointCheck4
|
||||
}
|
||||
req, _ := http.NewRequestWithContext(ctx, "GET", target, nil)
|
||||
resp, err = h.Do(req)
|
||||
req := fasthttp.AcquireRequest()
|
||||
res := fasthttp.AcquireResponse()
|
||||
defer func() {
|
||||
fasthttp.ReleaseRequest(req)
|
||||
fasthttp.ReleaseResponse(res)
|
||||
}()
|
||||
req.SetRequestURI(target)
|
||||
req.Header.SetMethod("GET")
|
||||
req.Header.SetUserAgent(useragent)
|
||||
client := fasthttp.Client{}
|
||||
client.DialDualStack = true
|
||||
|
||||
err = client.Do(req, res)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
err = fmt.Errorf("bad status code from %s : %s", target, resp.Status)
|
||||
if res.StatusCode() != http.StatusOK {
|
||||
err = fmt.Errorf("got status code %d", res.StatusCode())
|
||||
return
|
||||
}
|
||||
cytes, err = io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = json.Unmarshal(cytes, &details)
|
||||
|
||||
err = json.Unmarshal(res.Body(), &details)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -126,8 +130,8 @@ func checkIP(ctx context.Context, h *http.Client, ipv6 bool) (details *IPDetails
|
|||
// Returns the mullvad server you are connected to if any, and any error that occured
|
||||
//
|
||||
//goland:noinspection GoNilness
|
||||
func AmIMullvad(ctx context.Context, client *http.Client) (MullvadServer, error) {
|
||||
me, err := CheckIP(ctx, client)
|
||||
func (r *Checker) AmIMullvad(ctx context.Context) (MullvadServer, error) {
|
||||
me, err := CheckIP(ctx)
|
||||
if me == nil || me.V4 == nil && me.V6 == nil {
|
||||
return MullvadServer{}, err
|
||||
}
|
||||
|
@ -138,7 +142,7 @@ func AmIMullvad(ctx context.Context, client *http.Client) (MullvadServer, error)
|
|||
return MullvadServer{}, err
|
||||
}
|
||||
|
||||
relays, err := GetMullvadServers()
|
||||
err = r.Update()
|
||||
if err != nil {
|
||||
return MullvadServer{}, err
|
||||
}
|
||||
|
@ -146,14 +150,14 @@ func AmIMullvad(ctx context.Context, client *http.Client) (MullvadServer, error)
|
|||
isMullvad := false
|
||||
if me.V4 != nil && me.V4.MullvadExitIP {
|
||||
isMullvad = true
|
||||
if relays.Has(me.V4.MullvadExitIPHostname) {
|
||||
return relays.Get(me.V4.MullvadExitIPHostname), nil
|
||||
if r.Has(me.V4.MullvadExitIPHostname) {
|
||||
return r.Get(me.V4.MullvadExitIPHostname), nil
|
||||
}
|
||||
}
|
||||
if me.V6 != nil && me.V6.MullvadExitIP {
|
||||
isMullvad = true
|
||||
if relays.Has(me.V6.MullvadExitIPHostname) {
|
||||
return relays.Get(me.V6.MullvadExitIPHostname), nil
|
||||
if r.Has(me.V6.MullvadExitIPHostname) {
|
||||
return r.Get(me.V6.MullvadExitIPHostname), nil
|
||||
}
|
||||
}
|
||||
if isMullvad {
|
||||
|
|
|
@ -39,7 +39,7 @@ func TestCheckIP6(t *testing.T) {
|
|||
|
||||
func TestCheckIPConcurrent(t *testing.T) {
|
||||
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(15*time.Second))
|
||||
me, err := CheckIP(ctx, http.DefaultClient)
|
||||
me, err := CheckIP(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("%s", err.Error())
|
||||
}
|
||||
|
@ -70,8 +70,9 @@ func TestCheckIPConcurrent(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestAmIMullvad(t *testing.T) {
|
||||
servers := NewRelays()
|
||||
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(15*time.Second))
|
||||
am, err := AmIMullvad(ctx, http.DefaultClient)
|
||||
am, err := servers.AmIMullvad(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("%s", err.Error())
|
||||
}
|
||||
|
|
66
relays.go
66
relays.go
|
@ -14,21 +14,23 @@ func (mvs MullvadServer) String() string {
|
|||
return mvs.Hostname
|
||||
}
|
||||
|
||||
type Relays struct {
|
||||
type Checker struct {
|
||||
m map[string]MullvadServer
|
||||
size int
|
||||
url string
|
||||
*sync.RWMutex
|
||||
}
|
||||
|
||||
func NewRelays() *Relays {
|
||||
r := &Relays{
|
||||
func NewRelays() *Checker {
|
||||
r := &Checker{
|
||||
m: make(map[string]MullvadServer),
|
||||
RWMutex: &sync.RWMutex{},
|
||||
url: EndpointRelays,
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *Relays) Slice() []MullvadServer {
|
||||
func (r *Checker) Slice() []MullvadServer {
|
||||
r.RLock()
|
||||
defer r.RUnlock()
|
||||
var servers []MullvadServer
|
||||
|
@ -38,46 +40,78 @@ func (r *Relays) Slice() []MullvadServer {
|
|||
return servers
|
||||
}
|
||||
|
||||
func (r *Relays) Has(hostname string) bool {
|
||||
func (r *Checker) Has(hostname string) bool {
|
||||
r.RLock()
|
||||
_, ok := r.m[hostname]
|
||||
r.RUnlock()
|
||||
return ok
|
||||
}
|
||||
|
||||
func (r *Relays) Add(server MullvadServer) {
|
||||
func (r *Checker) Add(server MullvadServer) {
|
||||
r.Lock()
|
||||
r.m[server.Hostname] = server
|
||||
r.Unlock()
|
||||
}
|
||||
|
||||
func (r *Relays) Get(hostname string) MullvadServer {
|
||||
func (r *Checker) Get(hostname string) MullvadServer {
|
||||
r.RLock()
|
||||
defer r.RUnlock()
|
||||
return r.m[hostname]
|
||||
}
|
||||
|
||||
func GetMullvadServers() (*Relays, error) {
|
||||
var servers = NewRelays()
|
||||
var serverSlice []MullvadServer
|
||||
func (r *Checker) clear() {
|
||||
for k := range r.m {
|
||||
delete(r.m, k)
|
||||
}
|
||||
}
|
||||
|
||||
func getContentSize(url string) int {
|
||||
req := fasthttp.AcquireRequest()
|
||||
res := fasthttp.AcquireResponse()
|
||||
defer func() {
|
||||
fasthttp.ReleaseRequest(req)
|
||||
fasthttp.ReleaseResponse(res)
|
||||
}()
|
||||
req.Header.SetUserAgent("mulls0x/v0.0.1")
|
||||
req.Header.SetUserAgent(useragent)
|
||||
req.Header.SetMethod(http.MethodHead)
|
||||
req.SetRequestURI(url)
|
||||
if err := fasthttp.Do(req, res); err != nil {
|
||||
return -1
|
||||
}
|
||||
return res.Header.ContentLength()
|
||||
}
|
||||
|
||||
func (r *Checker) Update() error {
|
||||
var serverSlice []MullvadServer
|
||||
if r.size > 0 {
|
||||
current := getContentSize(r.url)
|
||||
if current == r.size {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
req := fasthttp.AcquireRequest()
|
||||
res := fasthttp.AcquireResponse()
|
||||
defer func() {
|
||||
fasthttp.ReleaseRequest(req)
|
||||
fasthttp.ReleaseResponse(res)
|
||||
}()
|
||||
req.Header.SetUserAgent(useragent)
|
||||
req.Header.SetContentType("application/json")
|
||||
req.Header.SetMethod(http.MethodGet)
|
||||
req.SetRequestURI(EndpointRelays)
|
||||
req.SetRequestURI(r.url)
|
||||
if err := fasthttp.Do(req, res); err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
if err := json.Unmarshal(res.Body(), &serverSlice); err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
r.Lock()
|
||||
r.clear()
|
||||
for _, server := range serverSlice {
|
||||
servers.Add(server)
|
||||
r.m[server.Hostname] = server
|
||||
}
|
||||
return servers, nil
|
||||
r.size = res.Header.ContentLength()
|
||||
r.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1,12 +1,47 @@
|
|||
package mullsox
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
)
|
||||
|
||||
func TestGetMullvadServers(t *testing.T) {
|
||||
servers, err := GetMullvadServers()
|
||||
if err != nil {
|
||||
t.Fatalf("%s", err.Error())
|
||||
servers := NewRelays()
|
||||
|
||||
update := func() {
|
||||
err := servers.Update()
|
||||
if err != nil {
|
||||
t.Fatalf("%s", err.Error())
|
||||
}
|
||||
t.Logf("got %d servers", len(servers.Slice()))
|
||||
}
|
||||
|
||||
t.Logf("got %d servers", len(servers.Slice()))
|
||||
t.Run("GetMullvadServers", func(t *testing.T) {
|
||||
update()
|
||||
t.Logf(spew.Sdump(servers.Slice()))
|
||||
})
|
||||
var last int
|
||||
var lastSlice []MullvadServer
|
||||
t.Run("GetMullvadServersCached", func(t *testing.T) {
|
||||
update()
|
||||
update()
|
||||
update()
|
||||
update()
|
||||
update()
|
||||
update()
|
||||
update()
|
||||
last = servers.size
|
||||
lastSlice = servers.Slice()
|
||||
})
|
||||
t.Run("GetMullvadServersChanged", func(t *testing.T) {
|
||||
servers.url = "https://api.mullvad.net/www/relays/openvpn/"
|
||||
update()
|
||||
if last == servers.size {
|
||||
t.Fatalf("expected %d to not equal %d", last, servers.size)
|
||||
}
|
||||
if len(servers.Slice()) == len(lastSlice) {
|
||||
t.Fatalf("expected %d to not equal %d", len(lastSlice), len(servers.Slice()))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue