mirror of https://github.com/yunginnanet/Rate5
Feat: Stringer Identity + Optimize
- Reduce potential debug contention by using cmpandswap atomics - Add the ability to use fmt.Stringers for Identity functionality (not sure why i ever did anything else tbh) - Complete test coverage
This commit is contained in:
parent
e008c0560f
commit
d7d8a0ce87
39
debug.go
39
debug.go
|
@ -1,18 +1,21 @@
|
|||
package rate5
|
||||
|
||||
import "fmt"
|
||||
import (
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
func (q *Limiter) debugPrintf(format string, a ...interface{}) {
|
||||
q.debugMutex.RLock()
|
||||
defer q.debugMutex.RUnlock()
|
||||
if !q.debug {
|
||||
if atomic.CompareAndSwapUint32(&q.debug, DebugDisabled, DebugDisabled) {
|
||||
return
|
||||
}
|
||||
msg := fmt.Sprintf(format, a...)
|
||||
select {
|
||||
case q.debugChannel <- msg:
|
||||
//
|
||||
default:
|
||||
println(msg)
|
||||
// drop the message but increment the lost counter
|
||||
atomic.AddInt64(&q.debugLost, 1)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -23,26 +26,22 @@ func (q *Limiter) setDebugEvict() {
|
|||
}
|
||||
|
||||
func (q *Limiter) SetDebug(on bool) {
|
||||
q.debugMutex.Lock()
|
||||
if !on {
|
||||
q.debug = false
|
||||
q.Patrons.OnEvicted(nil)
|
||||
q.debugMutex.Unlock()
|
||||
return
|
||||
switch on {
|
||||
case true:
|
||||
atomic.CompareAndSwapUint32(&q.debug, DebugDisabled, DebugEnabled)
|
||||
q.debugPrintf("rate5 debug enabled")
|
||||
case false:
|
||||
atomic.CompareAndSwapUint32(&q.debug, DebugEnabled, DebugDisabled)
|
||||
}
|
||||
q.debug = on
|
||||
q.setDebugEvict()
|
||||
q.debugMutex.Unlock()
|
||||
q.debugPrintf("rate5 debug enabled")
|
||||
}
|
||||
|
||||
// DebugChannel enables debug mode and returns a channel where debug messages are sent.
|
||||
// NOTE: You must read from this channel if created via this function or it will block
|
||||
//
|
||||
// NOTE: If you do not read from this channel, the debug messages will eventually be lost.
|
||||
// If this happens,
|
||||
func (q *Limiter) DebugChannel() chan string {
|
||||
defer func() {
|
||||
q.debugMutex.Lock()
|
||||
q.debug = true
|
||||
q.debugMutex.Unlock()
|
||||
atomic.CompareAndSwapUint32(&q.debug, DebugDisabled, DebugEnabled)
|
||||
}()
|
||||
q.debugMutex.RLock()
|
||||
if q.debugChannel != nil {
|
||||
|
@ -52,7 +51,7 @@ func (q *Limiter) DebugChannel() chan string {
|
|||
q.debugMutex.RUnlock()
|
||||
q.debugMutex.Lock()
|
||||
defer q.debugMutex.Unlock()
|
||||
q.debugChannel = make(chan string, 25)
|
||||
q.debugChannel = make(chan string, 55)
|
||||
q.setDebugEvict()
|
||||
return q.debugChannel
|
||||
}
|
||||
|
|
21
models.go
21
models.go
|
@ -1,6 +1,7 @@
|
|||
package rate5
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/patrickmn/go-cache"
|
||||
|
@ -18,19 +19,33 @@ type Identity interface {
|
|||
UniqueKey() string
|
||||
}
|
||||
|
||||
// IdentityStringer is an implentation of Identity that acts as a shim for types that implement fmt.Stringer.
|
||||
type IdentityStringer struct {
|
||||
stringer fmt.Stringer
|
||||
}
|
||||
|
||||
func (i IdentityStringer) UniqueKey() string {
|
||||
return i.stringer.String()
|
||||
}
|
||||
|
||||
const (
|
||||
DebugDisabled uint32 = iota
|
||||
DebugEnabled
|
||||
)
|
||||
|
||||
// Limiter implements an Enforcer to create an arbitrary ratelimiter.
|
||||
type Limiter struct {
|
||||
// Source is the implementation of the Identity interface. It is used to create a unique key for each request.
|
||||
Source Identity
|
||||
// Patrons gives access to the underlying cache type that powers the ratelimiter.
|
||||
// It is exposed for testing purposes.
|
||||
Patrons *cache.Cache
|
||||
|
||||
// Ruleset determines the Policy which is used to determine whether or not to ratelimit.
|
||||
// It consists of a Window and Burst, see Policy for more details.
|
||||
Ruleset Policy
|
||||
|
||||
debug bool
|
||||
debug uint32
|
||||
debugChannel chan string
|
||||
debugLost int64
|
||||
known map[interface{}]*int64
|
||||
debugMutex *sync.RWMutex
|
||||
*sync.RWMutex
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package rate5
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
@ -57,10 +58,12 @@ func NewStrictLimiter(window int, burst int) *Limiter {
|
|||
})
|
||||
}
|
||||
|
||||
/*NewHardcoreLimiter returns a custom limiter with Strict + Hardcore modes enabled.
|
||||
/*
|
||||
NewHardcoreLimiter returns a custom limiter with Strict + Hardcore modes enabled.
|
||||
|
||||
Hardcore mode causes the time limited to be multiplied by the number of hits.
|
||||
This differs from strict mode which is only using addition instead of multiplication.*/
|
||||
This differs from strict mode which is only using addition instead of multiplication.
|
||||
*/
|
||||
func NewHardcoreLimiter(window int, burst int) *Limiter {
|
||||
l := NewStrictLimiter(window, burst)
|
||||
l.Ruleset.Hardcore = true
|
||||
|
@ -80,7 +83,7 @@ func newLimiter(policy Policy) *Limiter {
|
|||
known: make(map[interface{}]*int64),
|
||||
RWMutex: &sync.RWMutex{},
|
||||
debugMutex: &sync.RWMutex{},
|
||||
debug: false,
|
||||
debug: DebugDisabled,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -122,6 +125,11 @@ func (q *Limiter) strictLogic(src string, count int64) {
|
|||
q.debugPrintf("%s ratelimit for %s: last count %d. time: %s", prefix, src, count, exttime)
|
||||
}
|
||||
|
||||
func (q *Limiter) CheckStringer(from fmt.Stringer) bool {
|
||||
targ := IdentityStringer{stringer: from}
|
||||
return q.Check(targ)
|
||||
}
|
||||
|
||||
// Check checks and increments an Identities UniqueKey() output against a list of cached strings to determine and raise it's ratelimitting status.
|
||||
func (q *Limiter) Check(from Identity) (limited bool) {
|
||||
var count int64
|
||||
|
@ -159,3 +167,8 @@ func (q *Limiter) Peek(from Identity) bool {
|
|||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (q *Limiter) PeekStringer(from fmt.Stringer) bool {
|
||||
targ := IdentityStringer{stringer: from}
|
||||
return q.Peek(targ)
|
||||
}
|
||||
|
|
|
@ -48,7 +48,6 @@ var (
|
|||
)
|
||||
|
||||
func watchDebug(ctx context.Context, r *Limiter, t *testing.T) {
|
||||
t.Helper()
|
||||
watchDebugMutex.Lock()
|
||||
defer watchDebugMutex.Unlock()
|
||||
rd := r.DebugChannel()
|
||||
|
@ -68,25 +67,28 @@ func watchDebug(ctx context.Context, r *Limiter, t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func peekCheckLimited(t *testing.T, limiter *Limiter, shouldbe bool) {
|
||||
t.Helper()
|
||||
func peekCheckLimited(t *testing.T, limiter *Limiter, shouldbe, stringer bool) {
|
||||
limited := limiter.Peek(dummyTicker)
|
||||
if stringer {
|
||||
limited = limiter.PeekStringer(dummyTicker)
|
||||
}
|
||||
switch {
|
||||
case limiter.Peek(dummyTicker) && !shouldbe:
|
||||
case limited && !shouldbe:
|
||||
if ct, ok := limiter.Patrons.Get(dummyTicker.UniqueKey()); ok {
|
||||
t.Errorf("Should not have been limited. Ratelimiter count: %d", ct)
|
||||
} else {
|
||||
t.Fatalf("dummyTicker does not exist in ratelimiter at all!")
|
||||
}
|
||||
case !limiter.Peek(dummyTicker) && shouldbe:
|
||||
case !limited && shouldbe:
|
||||
if ct, ok := limiter.Patrons.Get(dummyTicker.UniqueKey()); ok {
|
||||
t.Errorf("Should have been limited. Ratelimiter count: %d", ct)
|
||||
} else {
|
||||
t.Fatalf("dummyTicker does not exist in ratelimiter at all!")
|
||||
}
|
||||
case limiter.Peek(dummyTicker) && shouldbe:
|
||||
t.Logf("dummyTicker is limited")
|
||||
case !limiter.Peek(dummyTicker) && !shouldbe:
|
||||
t.Logf("dummyTicker is not limited")
|
||||
case limited && shouldbe:
|
||||
t.Logf("dummyTicker is limited as expected.")
|
||||
case !limited && !shouldbe:
|
||||
t.Logf("dummyTicker is not limited as expected.")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -105,6 +107,10 @@ func (tick *ticker) UniqueKey() string {
|
|||
return "TestItem"
|
||||
}
|
||||
|
||||
func (tick *ticker) String() string {
|
||||
return "TestItem"
|
||||
}
|
||||
|
||||
func Test_ResetItem(t *testing.T) {
|
||||
limiter := NewLimiter(500, 1)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
@ -114,26 +120,36 @@ func Test_ResetItem(t *testing.T) {
|
|||
limiter.Check(dummyTicker)
|
||||
}
|
||||
limiter.ResetItem(dummyTicker)
|
||||
peekCheckLimited(t, limiter, false)
|
||||
peekCheckLimited(t, limiter, false, false)
|
||||
cancel()
|
||||
}
|
||||
|
||||
func Test_NewDefaultLimiter(t *testing.T) {
|
||||
limiter := NewDefaultLimiter()
|
||||
limiter.Check(dummyTicker)
|
||||
peekCheckLimited(t, limiter, false)
|
||||
peekCheckLimited(t, limiter, false, false)
|
||||
for n := 0; n != DefaultBurst; n++ {
|
||||
limiter.Check(dummyTicker)
|
||||
}
|
||||
peekCheckLimited(t, limiter, true)
|
||||
peekCheckLimited(t, limiter, true, false)
|
||||
}
|
||||
|
||||
func Test_CheckAndPeekStringer(t *testing.T) {
|
||||
limiter := NewDefaultLimiter()
|
||||
limiter.CheckStringer(dummyTicker)
|
||||
peekCheckLimited(t, limiter, false, true)
|
||||
for n := 0; n != DefaultBurst; n++ {
|
||||
limiter.CheckStringer(dummyTicker)
|
||||
}
|
||||
peekCheckLimited(t, limiter, true, true)
|
||||
}
|
||||
|
||||
func Test_NewLimiter(t *testing.T) {
|
||||
limiter := NewLimiter(5, 1)
|
||||
limiter.Check(dummyTicker)
|
||||
peekCheckLimited(t, limiter, false)
|
||||
peekCheckLimited(t, limiter, false, false)
|
||||
limiter.Check(dummyTicker)
|
||||
peekCheckLimited(t, limiter, true)
|
||||
peekCheckLimited(t, limiter, true, false)
|
||||
}
|
||||
|
||||
func Test_NewDefaultStrictLimiter(t *testing.T) {
|
||||
|
@ -144,9 +160,9 @@ func Test_NewDefaultStrictLimiter(t *testing.T) {
|
|||
for n := 0; n < 25; n++ {
|
||||
limiter.Check(dummyTicker)
|
||||
}
|
||||
peekCheckLimited(t, limiter, false)
|
||||
peekCheckLimited(t, limiter, false, false)
|
||||
limiter.Check(dummyTicker)
|
||||
peekCheckLimited(t, limiter, true)
|
||||
peekCheckLimited(t, limiter, true, false)
|
||||
cancel()
|
||||
limiter = nil
|
||||
}
|
||||
|
@ -156,23 +172,23 @@ func Test_NewStrictLimiter(t *testing.T) {
|
|||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go watchDebug(ctx, limiter, t)
|
||||
limiter.Check(dummyTicker)
|
||||
peekCheckLimited(t, limiter, false)
|
||||
peekCheckLimited(t, limiter, false, false)
|
||||
limiter.Check(dummyTicker)
|
||||
peekCheckLimited(t, limiter, true)
|
||||
peekCheckLimited(t, limiter, true, false)
|
||||
limiter.Check(dummyTicker)
|
||||
// for coverage, first we give the debug messages a couple seconds to be safe,
|
||||
// then we wait for the cache eviction to trigger a debug message.
|
||||
time.Sleep(2 * time.Second)
|
||||
t.Logf(<-limiter.DebugChannel())
|
||||
peekCheckLimited(t, limiter, false)
|
||||
peekCheckLimited(t, limiter, false, false)
|
||||
for n := 0; n != 6; n++ {
|
||||
limiter.Check(dummyTicker)
|
||||
}
|
||||
peekCheckLimited(t, limiter, true)
|
||||
peekCheckLimited(t, limiter, true, false)
|
||||
time.Sleep(5 * time.Second)
|
||||
peekCheckLimited(t, limiter, true)
|
||||
peekCheckLimited(t, limiter, true, false)
|
||||
time.Sleep(8 * time.Second)
|
||||
peekCheckLimited(t, limiter, false)
|
||||
peekCheckLimited(t, limiter, false, false)
|
||||
cancel()
|
||||
limiter = nil
|
||||
}
|
||||
|
@ -184,35 +200,35 @@ func Test_NewHardcoreLimiter(t *testing.T) {
|
|||
for n := 0; n != 4; n++ {
|
||||
limiter.Check(dummyTicker)
|
||||
}
|
||||
peekCheckLimited(t, limiter, false)
|
||||
peekCheckLimited(t, limiter, false, false)
|
||||
if !limiter.Check(dummyTicker) {
|
||||
t.Errorf("Should have been limited")
|
||||
}
|
||||
t.Logf("limited once, waiting for cache eviction")
|
||||
time.Sleep(2 * time.Second)
|
||||
peekCheckLimited(t, limiter, false)
|
||||
peekCheckLimited(t, limiter, false, false)
|
||||
for n := 0; n != 4; n++ {
|
||||
limiter.Check(dummyTicker)
|
||||
}
|
||||
peekCheckLimited(t, limiter, false)
|
||||
peekCheckLimited(t, limiter, false, false)
|
||||
if !limiter.Check(dummyTicker) {
|
||||
t.Errorf("Should have been limited")
|
||||
}
|
||||
limiter.Check(dummyTicker)
|
||||
limiter.Check(dummyTicker)
|
||||
time.Sleep(3 * time.Second)
|
||||
peekCheckLimited(t, limiter, true)
|
||||
peekCheckLimited(t, limiter, true, false)
|
||||
time.Sleep(5 * time.Second)
|
||||
peekCheckLimited(t, limiter, false)
|
||||
peekCheckLimited(t, limiter, false, false)
|
||||
for n := 0; n != 4; n++ {
|
||||
limiter.Check(dummyTicker)
|
||||
}
|
||||
peekCheckLimited(t, limiter, false)
|
||||
peekCheckLimited(t, limiter, false, false)
|
||||
for n := 0; n != 10; n++ {
|
||||
limiter.Check(dummyTicker)
|
||||
}
|
||||
time.Sleep(10 * time.Second)
|
||||
peekCheckLimited(t, limiter, true)
|
||||
peekCheckLimited(t, limiter, true, false)
|
||||
cancel()
|
||||
// for coverage, triggering the switch statement case for hardcore logic
|
||||
limiter2 := NewHardcoreLimiter(2, 5)
|
||||
|
@ -221,9 +237,9 @@ func Test_NewHardcoreLimiter(t *testing.T) {
|
|||
for n := 0; n != 6; n++ {
|
||||
limiter2.Check(dummyTicker)
|
||||
}
|
||||
peekCheckLimited(t, limiter2, true)
|
||||
peekCheckLimited(t, limiter2, true, false)
|
||||
time.Sleep(4 * time.Second)
|
||||
peekCheckLimited(t, limiter2, false)
|
||||
peekCheckLimited(t, limiter2, false, false)
|
||||
cancel2()
|
||||
}
|
||||
|
||||
|
@ -314,3 +330,18 @@ func Test_ConcurrentShouldLimit(t *testing.T) {
|
|||
concurrentTest(t, 50, 21, 20, true)
|
||||
concurrentTest(t, 50, 51, 50, true)
|
||||
}
|
||||
|
||||
func Test_debugChannelOverflow(t *testing.T) {
|
||||
limiter := NewDefaultLimiter()
|
||||
_ = limiter.DebugChannel()
|
||||
for n := 0; n != 78; n++ {
|
||||
limiter.Check(dummyTicker)
|
||||
if limiter.debugLost > 0 {
|
||||
t.Fatalf("debug channel overflowed")
|
||||
}
|
||||
}
|
||||
limiter.Check(dummyTicker)
|
||||
if limiter.debugLost == 0 {
|
||||
t.Fatalf("debug channel did not overflow")
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue