From 5aea254233a8c22bd61f364b2f6d2f1bdba4d818 Mon Sep 17 00:00:00 2001 From: "kayos@tcp.direct" Date: Tue, 28 Sep 2021 01:24:04 -0700 Subject: [PATCH] Fix: Major bug with Peek. New: test cases --- README.md | 2 +- _examples/rated.go | 5 +- models.go | 10 ++-- ratelimiter.go | 40 ++++++++------- ratelimiter_test.go | 115 ++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 148 insertions(+), 24 deletions(-) create mode 100644 ratelimiter_test.go diff --git a/README.md b/README.md index dfca3bc..13615bd 100644 --- a/README.md +++ b/README.md @@ -47,4 +47,4 @@ func (s *Server) handleTCP(c *Client) { ## To-Do More Documentation More To-Dos -Test Cases +~~Test Cases~~ diff --git a/_examples/rated.go b/_examples/rated.go index 3eb03bd..edbacca 100644 --- a/_examples/rated.go +++ b/_examples/rated.go @@ -51,7 +51,7 @@ type Client struct { loggedin bool connected bool - autlog []Login + authlog []Login deadline time.Duration read *bufio.Reader @@ -108,7 +108,6 @@ func init() { argParse() - rd := Rater.DebugChannel() rrd := RegRater.DebugChannel() crd := CmdRater.DebugChannel() @@ -278,6 +277,7 @@ func (c *Client) recv() string { func randUint32() uint32 { b := make([]byte, 4096) if _, err := rand.Read(b); err != nil { + panic(err) } return binary.BigEndian.Uint32(b) @@ -318,6 +318,7 @@ func (s *Server) setID(c *Client, id string) { func (s *Server) replaceSession(c *Client, id string) { s.mu.Lock() + s.AuthLog[id] = append(s.AuthLog[id], Login{ // we're not logged in so UniqueKey is still the IP address IP: c.UniqueKey(), diff --git a/models.go b/models.go index 229c06d..b0fff8b 100644 --- a/models.go +++ b/models.go @@ -1,7 +1,6 @@ package rate5 import ( - "sync" "sync/atomic" "github.com/patrickmn/go-cache" @@ -14,6 +13,11 @@ const ( DefaultBurst = 25 ) +const ( + stateUnlocked uint32 = iota + stateLocked +) + var debugChannel chan string // Identity is an interface that allows any arbitrary type to be used for a unique key in ratelimit checks when implemented. @@ -23,6 +27,7 @@ type Identity interface { type rated struct { seen atomic.Value + locker *uint32 } // Limiter implements an Enforcer to create an arbitrary ratelimiter. @@ -36,9 +41,8 @@ type Limiter struct { delivered through a channel. See: DebugChannel() */ Debug bool - count int + count atomic.Value known map[interface{}]*rated - mu *sync.RWMutex } // Policy defines the mechanics of our ratelimiter. diff --git a/ratelimiter.go b/ratelimiter.go index e2827f0..ce5aeb7 100644 --- a/ratelimiter.go +++ b/ratelimiter.go @@ -2,13 +2,13 @@ package rate5 import ( "fmt" - "sync" "sync/atomic" "time" "github.com/patrickmn/go-cache" ) + // NewDefaultLimiter returns a ratelimiter with default settings without Strict mode. func NewDefaultLimiter() *Limiter { return newLimiter(Policy{ @@ -55,7 +55,6 @@ func newLimiter(policy Policy) *Limiter { q.Ruleset = policy q.Patrons = cache.New(time.Duration(q.Ruleset.Window)*time.Second, 5*time.Second) q.known = make(map[interface{}]*rated) - q.mu = &sync.RWMutex{} return q } @@ -70,6 +69,11 @@ func (q *Limiter) DebugChannel() chan string { } func (s *rated) inc() { + for !atomic.CompareAndSwapUint32(s.locker, stateUnlocked, stateLocked) { + time.Sleep(10 * time.Millisecond) + } + defer atomic.StoreUint32(s.locker, stateUnlocked) + if s.seen.Load() == nil { s.seen.Store(1) return @@ -78,20 +82,15 @@ func (s *rated) inc() { } func (q *Limiter) strictLogic(src string, count int) { - q.mu.Lock() if _, ok := q.known[src]; !ok { - q.known[src]=&rated{ - seen: atomic.Value{}, - } + atomic.StoreUint32(q.known[src].locker, stateUnlocked) + q.known[src]=&rated{seen: atomic.Value{}} } - q.known[src].inc() extwindow := q.Ruleset.Window + q.known[src].seen.Load().(int) - if err := q.Patrons.Replace(src, count, time.Duration(extwindow)*time.Second); err != nil { q.debugPrint("Rate5: " + err.Error()) } - q.mu.Unlock() q.debugPrint("ratelimit (strictly limited): ", count, " ", src) q.increment() } @@ -122,24 +121,29 @@ func (q *Limiter) Check(from Identity) bool { // Peek checks an Identities UniqueKey() output against a list of cached strings to determine ratelimitting status without adding to its request count. func (q *Limiter) Peek(from Identity) bool { - if _, ok := q.Patrons.Get(from.UniqueKey()); ok { - return true + if ct, ok := q.Patrons.Get(from.UniqueKey()); ok { + count := ct.(int) + if count > q.Ruleset.Burst { + return true + } } - return false } func (q *Limiter) increment() { - q.mu.Lock() - defer q.mu.Unlock() - q.count++ + if q.count.Load() == nil { + q.count.Store(1) + return + } + q.count.Store(q.count.Load().(int) + 1) } // GetGrandTotalRated returns the historic total amount of times we have ever reported something as ratelimited. func (q *Limiter) GetGrandTotalRated() int { - q.mu.RLock() - defer q.mu.RUnlock() - return q.count + if q.count.Load() == nil { + return 0 + } + return q.count.Load().(int) } func (q *Limiter) debugPrint(a ...interface{}) { diff --git a/ratelimiter_test.go b/ratelimiter_test.go new file mode 100644 index 0000000..ba649c3 --- /dev/null +++ b/ratelimiter_test.go @@ -0,0 +1,115 @@ +package rate5 + +import ( + "testing" + "time" +) + +var ( + dummyTicker *ticker + testDebug = true +) + +func watchDebug(r *Limiter, t *testing.T) { + t.Logf("debug enabled") + rd := r.DebugChannel() + + pre := "[Rate5] " + var lastcount = 0 + var count = 0 + for { + select { + case msg := <-rd: + t.Logf("%s Limit: %s \n", pre, msg) + default: + count++ + if count-lastcount >= 10 { + lastcount = count + t.Logf("Times limited: %d", r.GetGrandTotalRated()) + } + time.Sleep(time.Duration(10) * time.Millisecond) + } + } +} + +type ticker struct{} + +func init() { + dummyTicker = &ticker{} +} + +func (tick *ticker) UniqueKey() string { + return "tick" +} + +func Test_NewCustomLimiter(t *testing.T) { + limiter := NewCustomLimiter(Policy{ + Window: 5, + Burst: 10, + Strict: false, + }) + + //goland:noinspection GoBoolExpressions + if testDebug { + go watchDebug(limiter, t) + time.Sleep(100 * time.Millisecond) + } + + for n := 0; n < 9; n++ { + limiter.Check(dummyTicker) + } + if limiter.Peek(dummyTicker) { + if ct, ok := limiter.Patrons.Get(dummyTicker.UniqueKey()); ok { + t.Errorf("Should not have been limited. Ratelimiter count: %d", ct) + } else { + t.Errorf("dummyTicker does not exist in ratelimiter at all!") + } + } + if !limiter.Check(dummyTicker) { + if ct, ok := limiter.Patrons.Get(dummyTicker.UniqueKey()); ok { + t.Errorf("Should have been limited. Ratelimiter count: %d", ct) + } else { + t.Errorf("dummyTicker does not exist in ratelimiter at all!") + } + } + + //goland:noinspection GoBoolExpressions + if testDebug { + t.Logf("[Finished NewCustomLimiter] Times ratelimited: %d", limiter.GetGrandTotalRated()) + } +} + +func Test_NewDefaultStrictLimiter(t *testing.T) { + // DefaultBurst = 25 + // DefaultWindow = 5 + limiter := NewDefaultStrictLimiter() + + //goland:noinspection GoBoolExpressions + if testDebug { + go watchDebug(limiter, t) + time.Sleep(100 * time.Millisecond) + } + + for n := 0; n < 24; n++ { + limiter.Check(dummyTicker) + } + if limiter.Peek(dummyTicker) { + if ct, ok := limiter.Patrons.Get(dummyTicker.UniqueKey()); ok { + t.Errorf("Should not have been limited. Ratelimiter count: %d", ct) + } else { + t.Errorf("dummyTicker does not exist in ratelimiter at all!") + } + } + if !limiter.Check(dummyTicker) { + if ct, ok := limiter.Patrons.Get(dummyTicker.UniqueKey()); ok { + t.Errorf("Should have been limited. Ratelimiter count: %d, policy: %d", ct, limiter.Ruleset.Burst) + } else { + t.Errorf("dummyTicker does not exist in ratelimiter at all!") + } + } + + //goland:noinspection GoBoolExpressions + if testDebug { + t.Logf("[Finished NewCustomLimiter] Times ratelimited: %d", limiter.GetGrandTotalRated()) + } +}