Feat: Hardcore limiter and Reset function

kayos@tcp.direct 2022-07-16 07:28:00 -07:00
parent 5ba326db3e
commit 2eb517efa5
Signed by: kayos
GPG Key ID: 4B841471B4BEE979
4 changed files with 241 additions and 143 deletions

debug.go Normal file (+58 lines)
View File

@ -0,0 +1,58 @@
package rate5
import "fmt"
func (q *Limiter) debugPrintf(format string, a ...interface{}) {
q.debugMutex.RLock()
defer q.debugMutex.RUnlock()
if !q.debug {
return
}
msg := fmt.Sprintf(format, a...)
select {
case q.debugChannel <- msg:
default:
println(msg)
}
}
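// setDebugEvict registers an eviction callback on the Patrons cache so that
// expired ratelimit entries are reported through the debug output.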
func (q *Limiter) setDebugEvict() {
q.Patrons.OnEvicted(func(src string, count interface{}) {
q.debugPrintf("ratelimit (expired): %s | last count [%d]", src, count)
})
}
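// SetDebug enables or disables debug mode. Disabling it also unregisters the
// eviction callback.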
func (q *Limiter) SetDebug(on bool) {
q.debugMutex.Lock()
if !on {
q.debug = false
q.Patrons.OnEvicted(nil)
q.debugMutex.Unlock()
return
}
q.debug = on
q.setDebugEvict()
q.debugMutex.Unlock()
q.debugPrintf("rate5 debug enabled")
}
// DebugChannel enables debug mode and returns a buffered channel where debug messages are sent.
// NOTE: You must read from this channel; once its buffer fills, further messages fall back to println.
func (q *Limiter) DebugChannel() chan string {
defer func() {
q.debugMutex.Lock()
q.debug = true
q.debugMutex.Unlock()
}()
q.debugMutex.RLock()
if q.debugChannel != nil {
q.debugMutex.RUnlock()
return q.debugChannel
}
q.debugMutex.RUnlock()
q.debugMutex.Lock()
defer q.debugMutex.Unlock()
q.debugChannel = make(chan string, 25)
q.setDebugEvict()
return q.debugChannel
}
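Because the channel is buffered at 25 messages, a consumer should drain it promptly. A minimal sketch of such a consumer (illustrative only, not part of this commit):

// Hypothetical consumer: drain the debug channel so the 25-message
// buffer never forces the println fallback.
limiter := NewDefaultLimiter()
msgs := limiter.DebugChannel()
go func() {
	for msg := range msgs {
		fmt.Println("rate5 debug:", msg) // or route to your logger of choice
	}
}()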

View File

@ -42,6 +42,9 @@ type Policy struct {
Window int64
// Burst is the amount of times that Check will not trigger a limit within the duration defined by Window.
Burst int64
// Strict mode punishes those who trigger the ratelimiter by increasing the wait time with every subsequent trigger.
Strict bool
// Hardcore mode implies Strict mode, but extends the wait time using multiplication instead of addition.
// This causes the penalty window to grow multiplicatively with repeat offenses.
Hardcore bool
}
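To make the difference concrete, here is a small illustrative calculation (not from the source) of how the extended wait grows per repeat trigger, mirroring the arithmetic in strictLogic below, assuming a 5-second Window:

package main

import "fmt"

// Illustration only: penalty growth for a 5-second Window over four repeat hits.
func main() {
	window := int64(5)
	for hits := int64(1); hits <= 4; hits++ {
		strict := window + hits   // Strict: 6s, 7s, 8s, 9s (additive)
		hardcore := window * hits // Hardcore: 5s, 10s, 15s, 20s (multiplicative)
		fmt.Printf("hits=%d strict=%ds hardcore=%ds\n", hits, strict, hardcore)
	}
}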

View File

@ -1,7 +1,6 @@
package rate5
import (
"fmt"
"sync"
"sync/atomic"
"time"
@ -47,7 +46,7 @@ func NewDefaultStrictLimiter() *Limiter {
})
}
/*NewStrictLimiter returns a custom limiter with Strict mode enabled.
* Window is the time in seconds that the limiter will cache requests.
* Burst is the number of requests that can be made in the window.*/
func NewStrictLimiter(window int, burst int) *Limiter {
@ -58,10 +57,26 @@ func NewStrictLimiter(window int, burst int) *Limiter {
})
}
/*NewHardcoreLimiter returns a custom limiter with Strict + Hardcore modes enabled.
Hardcore mode causes the wait time to be multiplied by the number of hits.
This differs from Strict mode, which uses addition instead of multiplication.*/
func NewHardcoreLimiter(window int, burst int) *Limiter {
l := NewStrictLimiter(window, burst)
l.Ruleset.Hardcore = true
return l
}
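// ResetItem deletes an Identity's key from the Patrons cache, immediately
// clearing any ratelimit state held against it.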
func (q *Limiter) ResetItem(from Identity) {
q.Patrons.Delete(from.UniqueKey())
q.debugPrintf("ratelimit for %s has been reset", from.UniqueKey())
}
func newLimiter(policy Policy) *Limiter {
window := time.Duration(policy.Window) * time.Second
return &Limiter{
Ruleset: policy,
Patrons: cache.New(window, 1*time.Second),
known: make(map[interface{}]*int64),
RWMutex: &sync.RWMutex{},
debugMutex: &sync.RWMutex{},
@ -69,34 +84,6 @@ func newLimiter(policy Policy) *Limiter {
}
}
func intPtr(i int64) *int64 {
return &i
}
@ -119,9 +106,20 @@ func (q *Limiter) getHitsPtr(src string) *int64 {
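// strictLogic applies the Strict/Hardcore penalty: it increments the known hit
// count for the source and extends the TTL of its cache entry accordingly.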
func (q *Limiter) strictLogic(src string, count int64) {
knownHits := q.getHitsPtr(src)
atomic.AddInt64(knownHits, 1)
var extwindow int64
prefix := "hardcore"
switch {
case q.Ruleset.Hardcore && q.Ruleset.Window > 1:
extwindow = atomic.LoadInt64(knownHits) * q.Ruleset.Window
case q.Ruleset.Hardcore && q.Ruleset.Window <= 1:
extwindow = atomic.LoadInt64(knownHits) * 2
case !q.Ruleset.Hardcore:
prefix = "strict"
extwindow = atomic.LoadInt64(knownHits) + q.Ruleset.Window
}
exttime := time.Duration(extwindow) * time.Second
_ = q.Patrons.Replace(src, count, exttime)
q.debugPrintf("%s ratelimit for %s: last count %d. time: %s", prefix, src, count, exttime)
}
// Check checks and increments an Identity's UniqueKey() output against a list of cached strings to determine and raise its ratelimiting status.
@ -129,10 +127,11 @@ func (q *Limiter) Check(from Identity) (limited bool) {
var count int64
var err error
src := from.UniqueKey()
q.Patrons.DeleteExpired()
count, err = q.Patrons.IncrementInt64(src, 1)
if err != nil {
// IncrementInt64 should only error if the value is not an int64; we can't
// reproduce that here, so we assume the key is new.
q.debugPrintf("ratelimit %s (new)", src)
_ = q.Patrons.Add(src, int64(1), time.Duration(q.Ruleset.Window)*time.Second)
return false
@ -143,13 +142,15 @@ func (q *Limiter) Check(from Identity) (limited bool) {
if q.Ruleset.Strict {
q.strictLogic(src, count)
} else {
q.debugPrint("ratelimit (limited): ", count, " ", src)
q.debugPrintf("ratelimit %s: last count %d. time: %s",
src, count, time.Duration(q.Ruleset.Window)*time.Second)
}
return true
}
// Peek checks an Identity's UniqueKey() output against a list of cached strings to determine its ratelimiting status without adding to its request count.
func (q *Limiter) Peek(from Identity) bool {
q.Patrons.DeleteExpired()
if ct, ok := q.Patrons.Get(from.UniqueKey()); ok {
count := ct.(int64)
if count > q.Ruleset.Burst {
@ -158,14 +159,3 @@ func (q *Limiter) Peek(from Identity) bool {
}
return false
}
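Putting the new pieces together, a minimal usage sketch (hypothetical; the client type and parameter values are assumptions, not from this commit):

// Sketch: a hypothetical Identity implementation keyed by IP address.
type client struct{ ip string }

func (c *client) UniqueKey() string { return c.ip }

func example() {
	limiter := NewHardcoreLimiter(5, 10) // assumed: 5-second window, burst of 10
	c := &client{ip: "10.0.0.1"}
	if limiter.Check(c) {
		// limited; each further trigger multiplies the wait window
	}
	limiter.ResetItem(c) // clears the ratelimit state for this client
	_ = limiter.Peek(c)  // Peek inspects without incrementing the count
}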

View File

@ -1,6 +1,7 @@
package rate5
import (
"context"
"crypto/rand"
"encoding/binary"
"runtime"
@ -11,7 +12,6 @@ import (
var (
dummyTicker *ticker
)
type randomPatron struct {
@ -42,153 +42,202 @@ func (rp *randomPatron) GenerateKey() {
rp.key = string(buf)
}
var (
forCoverage = &sync.Once{}
watchDebugMutex = &sync.Mutex{}
)
func watchDebug(ctx context.Context, r *Limiter, t *testing.T) {
t.Helper()
watchDebugMutex.Lock()
defer watchDebugMutex.Unlock()
rd := r.DebugChannel()
forCoverage.Do(func() {
r.SetDebug(true)
rd = r.DebugChannel()
})
for {
select {
case <-ctx.Done():
r = nil
return
case msg := <-rd:
t.Logf("%s \n", msg)
default:
}
}
}
func peekCheckLimited(t *testing.T, limiter *Limiter, shouldbe bool) {
t.Helper()
switch {
case limiter.Peek(dummyTicker) && !shouldbe:
if ct, ok := limiter.Patrons.Get(dummyTicker.UniqueKey()); ok {
t.Errorf("Should not have been limited. Ratelimiter count: %d", ct)
} else {
t.Fatalf("dummyTicker does not exist in ratelimiter at all!")
}
case !limiter.Peek(dummyTicker) && shouldbe:
if ct, ok := limiter.Patrons.Get(dummyTicker.UniqueKey()); ok {
t.Errorf("Should have been limited. Ratelimiter count: %d", ct)
} else {
t.Fatalf("dummyTicker does not exist in ratelimiter at all!")
}
case limiter.Peek(dummyTicker) && shouldbe:
t.Logf("dummyTicker is limited")
case !limiter.Peek(dummyTicker) && !shouldbe:
t.Logf("dummyTicker is not limited")
}
}
// This test exists for coverage: we simulate the debug channel overflowing, which causes debugPrintf to fall back to println().
func Test_debugPrintf(t *testing.T) {
limiter := NewLimiter(1, 1)
_ = limiter.DebugChannel()
for n := 0; n < 50; n++ {
limiter.Check(dummyTicker)
}
}
type ticker struct{}
func (tick *ticker) UniqueKey() string {
return "tick"
return "TestItem"
}
func Test_ResetItem(t *testing.T) {
limiter := NewLimiter(500, 1)
ctx, cancel := context.WithCancel(context.Background())
go watchDebug(ctx, limiter, t)
time.Sleep(25 * time.Millisecond)
for n := 0; n < 10; n++ {
limiter.Check(dummyTicker)
}
limiter.ResetItem(dummyTicker)
peekCheckLimited(t, limiter, false)
cancel()
}
func Test_NewDefaultLimiter(t *testing.T) {
limiter := NewDefaultLimiter()
limiter.Check(dummyTicker)
peekCheckLimited(t, limiter, false)
for n := 0; n != DefaultBurst; n++ {
limiter.Check(dummyTicker)
}
peekCheckLimited(t, limiter, true)
}
func Test_NewLimiter(t *testing.T) {
limiter := NewLimiter(5, 1)
limiter.Check(dummyTicker)
peekCheckLimited(t, limiter, false)
limiter.Check(dummyTicker)
peekCheckLimited(t, limiter, true)
}
func Test_NewCustomLimiter(t *testing.T) {
limiter := NewCustomLimiter(Policy{
Window: 5,
Burst: 10,
Strict: false,
})
ctx, cancel := context.WithCancel(context.Background())
go watchDebug(ctx, limiter, t)
time.Sleep(25 * time.Millisecond)
for n := 0; n < 9; n++ {
limiter.Check(dummyTicker)
}
peekCheckLimited(t, limiter, false)
limiter.Check(dummyTicker)
peekCheckLimited(t, limiter, true)
cancel()
limiter = nil
}
func Test_NewDefaultStrictLimiter(t *testing.T) {
// DefaultBurst = 25
// DefaultWindow = 5
limiter := NewDefaultStrictLimiter()
ctx, cancel := context.WithCancel(context.Background())
go watchDebug(ctx, limiter, t)
time.Sleep(25 * time.Millisecond)
for n := 0; n < 25; n++ {
limiter.Check(dummyTicker)
}
peekCheckLimited(t, limiter, false)
limiter.Check(dummyTicker)
peekCheckLimited(t, limiter, true)
cancel()
limiter = nil
}
func Test_NewStrictLimiter(t *testing.T) {
limiter := NewStrictLimiter(5, 1)
ctx, cancel := context.WithCancel(context.Background())
go watchDebug(ctx, limiter, t)
limiter.Check(dummyTicker)
peekCheckLimited(t, limiter, false)
limiter.Check(dummyTicker)
peekCheckLimited(t, limiter, true)
limiter.Check(dummyTicker)
// for coverage, first we give the debug messages a couple seconds to be safe,
// then we wait for the cache eviction to trigger a debug message.
time.Sleep(2 * time.Second)
t.Logf(<-limiter.DebugChannel())
peekCheckLimited(t, limiter, false)
for n := 0; n != 6; n++ {
limiter.Check(dummyTicker)
}
peekCheckLimited(t, limiter, true)
time.Sleep(5 * time.Second)
peekCheckLimited(t, limiter, true)
time.Sleep(8 * time.Second)
peekCheckLimited(t, limiter, false)
cancel()
limiter = nil
}
func Test_NewHardcoreLimiter(t *testing.T) {
limiter := NewHardcoreLimiter(1, 5)
ctx, cancel := context.WithCancel(context.Background())
go watchDebug(ctx, limiter, t)
for n := 0; n != 4; n++ {
limiter.Check(dummyTicker)
}
peekCheckLimited(t, limiter, false)
if !limiter.Check(dummyTicker) {
t.Errorf("Should have been limited")
}
t.Logf("limited once, waiting for cache eviction")
time.Sleep(2 * time.Second)
peekCheckLimited(t, limiter, false)
for n := 0; n != 4; n++ {
limiter.Check(dummyTicker)
}
peekCheckLimited(t, limiter, false)
if !limiter.Check(dummyTicker) {
t.Errorf("Should have been limited")
}
limiter.Check(dummyTicker)
limiter.Check(dummyTicker)
time.Sleep(3 * time.Second)
peekCheckLimited(t, limiter, true)
time.Sleep(5 * time.Second)
peekCheckLimited(t, limiter, false)
for n := 0; n != 4; n++ {
limiter.Check(dummyTicker)
}
peekCheckLimited(t, limiter, false)
for n := 0; n != 10; n++ {
limiter.Check(dummyTicker)
}
time.Sleep(10 * time.Second)
peekCheckLimited(t, limiter, true)
cancel()
// for coverage, triggering the switch statement case for hardcore logic
limiter2 := NewHardcoreLimiter(2, 5)
ctx2, cancel2 := context.WithCancel(context.Background())
go watchDebug(ctx2, limiter2, t)
for n := 0; n != 6; n++ {
limiter2.Check(dummyTicker)
}
peekCheckLimited(t, limiter2, true)
time.Sleep(4 * time.Second)
peekCheckLimited(t, limiter2, false)
cancel2()
}
func concurrentTest(t *testing.T, jobs int, iterCount int, burst int64, shouldLimit bool) { //nolint:funlen
randos := make(map[int]*randomPatron)
limiter := NewCustomLimiter(Policy{
Window: 240,
Burst: burst,
Strict: true,
})
limitNotice := sync.Once{}
limiter.SetDebug(false)
usedkeys := make(map[string]interface{})
for n := 0; n != jobs; n++ {
randos[n] = new(randomPatron)
ok := true
@ -200,7 +249,6 @@ func concurrentTest(t *testing.T, jobs int, iterCount int, burst int64, shouldLi
}
}
}
t.Logf("generated %d Patrons with unique keys, running Check() with them %d times concurrently with a burst limit of %d...",
len(randos), iterCount, burst)
@ -220,7 +268,6 @@ func concurrentTest(t *testing.T, jobs int, iterCount int, burst int64, shouldLi
}
}(rp)
}
testloop:
for {
select {