1
2
mirror of https://github.com/yunginnanet/Rate5 synced 2024-06-29 18:40:48 +00:00
Rate5/ratelimiter.go

171 lines
4.2 KiB
Go

package rate5
import (
"fmt"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/patrickmn/go-cache"
)
// NewDefaultLimiter builds a non-strict Limiter configured with DefaultWindow
// and DefaultBurst.
func NewDefaultLimiter() *Limiter {
	defaults := Policy{
		Burst:  DefaultBurst,
		Window: DefaultWindow,
		Strict: false,
	}
	return newLimiter(defaults)
}
// NewCustomLimiter builds a Limiter whose Ruleset is exactly the supplied
// Policy, including its Strict flag.
func NewCustomLimiter(policy Policy) *Limiter {
	lim := newLimiter(policy)
	return lim
}
// NewLimiter builds a non-strict Limiter with the given window (in seconds)
// and burst threshold.
func NewLimiter(window int, burst int) *Limiter {
	var p Policy
	p.Window = int64(window)
	p.Burst = int64(burst)
	p.Strict = false
	return newLimiter(p)
}
// NewDefaultStrictLimiter builds a Limiter configured with DefaultWindow and
// DefaultBurst, with Strict mode enabled.
func NewDefaultStrictLimiter() *Limiter {
	strictDefaults := Policy{
		Strict: true,
		Burst:  DefaultBurst,
		Window: DefaultWindow,
	}
	return newLimiter(strictDefaults)
}
// NewStrictLimiter builds a Limiter with the given window (in seconds) and
// burst threshold, with Strict mode enabled.
func NewStrictLimiter(window int, burst int) *Limiter {
	var p Policy
	p.Strict = true
	p.Window = int64(window)
	p.Burst = int64(burst)
	return newLimiter(p)
}
// newLimiter is the shared constructor used by the exported New*Limiter
// functions.
//
// It stores policy as the Ruleset and backs hit counts with a go-cache.
// The cache's default expiration is policy.Window seconds and its cleanup
// interval is 5 seconds. Debug output starts disabled.
func newLimiter(policy Policy) *Limiter {
	window := time.Duration(policy.Window) * time.Second
	const sweepInterval = 5 * time.Second

	lim := &Limiter{
		Ruleset:    policy,
		RWMutex:    &sync.RWMutex{},
		debugMutex: &sync.RWMutex{},
		known:      make(map[interface{}]*int64),
		debug:      false,
	}
	lim.Patrons = cache.New(window, sweepInterval)
	return lim
}
// SetDebug switches debug output on or off.
//
// Switching it off also clears the cache eviction callback that DebugChannel
// installs. Switching it on neither creates a debug channel nor installs that
// callback; DebugChannel does both.
func (q *Limiter) SetDebug(on bool) {
	if !on {
		// Detach the expiry logger before the flag flips.
		q.Patrons.OnEvicted(nil)
	}
	q.debugMutex.Lock()
	defer q.debugMutex.Unlock()
	q.debug = on
}
// DebugChannel enables debug mode and returns the channel that debug messages
// are sent on. It also installs a cache eviction callback that reports expired
// entries.
//
// If debug mode is already on and a channel exists, that same channel is
// returned. Otherwise a new channel (buffer 25) is created, including when
// debug was enabled earlier through SetDebug alone.
//
// NOTE: messages are delivered by goroutines. If nothing reads from the
// returned channel, those goroutines stay blocked and accumulate.
func (q *Limiter) DebugChannel() chan string {
	q.debugMutex.RLock()
	if q.debug && q.debugChannel != nil {
		// Capture while still holding the lock so the read is not racy.
		ch := q.debugChannel
		q.debugMutex.RUnlock()
		return ch
	}
	q.debugMutex.RUnlock()

	q.debugMutex.Lock()
	defer q.debugMutex.Unlock()
	// Re-check: another goroutine may have created the channel while we were
	// waiting for the write lock; creating a second one would orphan readers
	// of the first.
	if q.debug && q.debugChannel != nil {
		return q.debugChannel
	}
	q.debug = true
	q.debugChannel = make(chan string, 25)
	q.Patrons.OnEvicted(func(src string, count interface{}) {
		q.debugPrint("ratelimit (expired): ", src, " ", count)
	})
	return q.debugChannel
}
// intPtr allocates a new int64 holding i and returns its address; each call
// yields a distinct pointer.
func intPtr(i int64) *int64 {
	p := new(int64)
	*p = i
	return p
}
// getHitsPtr returns the strict-mode hit counter for src, creating a zero
// counter on first use. Callers update the counter through sync/atomic.
//
// NOTE(review): nothing here removes entries from q.known, so counters
// persist for the Limiter's lifetime. Confirm whether another file resets them.
func (q *Limiter) getHitsPtr(src string) *int64 {
	q.RLock()
	ptr, ok := q.known[src]
	q.RUnlock()
	if ok {
		return ptr
	}

	q.Lock()
	defer q.Unlock()
	// Re-check under the write lock: another goroutine may have registered src
	// between RUnlock and Lock. Overwriting its counter would lose their hits.
	if ptr, ok = q.known[src]; ok {
		return ptr
	}
	ptr = intPtr(0)
	q.known[src] = ptr
	return ptr
}
// strictLogic applies Strict-mode handling to src once Check finds its count
// at or above Ruleset.Burst.
//
// It increments src's strict-hit counter, then replaces the cache entry
// (keeping count) with an expiry of Ruleset.Window plus that hit total, in
// seconds. Each further limited hit therefore lengthens the entry's lifetime.
func (q *Limiter) strictLogic(src string, count int64) {
	knownHits := q.getHitsPtr(src)
	// Use AddInt64's return value: a separate Load could also observe
	// increments made concurrently by other callers.
	hits := atomic.AddInt64(knownHits, 1)
	extwindow := q.Ruleset.Window + hits
	if err := q.Patrons.Replace(src, count, time.Duration(extwindow)*time.Second); err != nil {
		q.debugPrint("ratelimit (strict) error: " + err.Error())
	}
	q.debugPrint("ratelimit (strict) limited: ", count, " ", src)
}
// Check records one hit for from.UniqueKey() and reports whether that key is
// now rate limited.
//
// A key not yet in Patrons is inserted with a count of 1, expiring after
// Ruleset.Window seconds, and is not limited on that hit. An existing key is
// limited once its incremented count reaches Ruleset.Burst. In Strict mode,
// strictLogic then replaces the entry with a longer expiry. Any increment
// error other than "not found" is treated as limited.
func (q *Limiter) Check(from Identity) (limited bool) {
	key := from.UniqueKey()

	hits, incErr := q.Patrons.IncrementInt64(key, 1)
	if incErr != nil {
		if !strings.Contains(incErr.Error(), "not found") {
			q.debugPrint("ratelimit error: " + incErr.Error())
			return true
		}
		q.debugPrint("ratelimit (new): ", key)
		ttl := time.Duration(q.Ruleset.Window) * time.Second
		if addErr := q.Patrons.Add(key, int64(1), ttl); addErr != nil {
			q.debugPrint("ratelimit error: " + addErr.Error())
		}
		return false
	}

	switch {
	case hits < q.Ruleset.Burst:
		return false
	case q.Ruleset.Strict:
		q.strictLogic(key, hits)
	default:
		q.debugPrint("ratelimit (limited): ", hits, " ", key)
	}
	return true
}
// Peek reports whether from.UniqueKey() is currently rate limited, without
// recording a hit.
//
// It applies the same threshold as Check (count >= Ruleset.Burst), so a key
// that Check has just limited is also reported as limited here. Keys that are
// absent, expired, or whose cached value is not an int64 are reported as not
// limited.
func (q *Limiter) Peek(from Identity) bool {
	ct, ok := q.Patrons.Get(from.UniqueKey())
	if !ok {
		return false
	}
	// Comma-ok assertion: a non-int64 value must not panic the caller.
	count, isInt := ct.(int64)
	return isInt && count >= q.Ruleset.Burst
}
// debugPrint formats a with fmt.Sprint and sends the result to the debug
// channel when debug mode is on.
//
// The send runs on a new goroutine so a full channel never blocks the caller.
// If nobody drains the channel, those goroutines stay blocked.
//
// Nothing is sent when debug is off or no channel exists yet, for example
// after SetDebug(true) without DebugChannel. A send on a nil channel would
// block its goroutine forever.
func (q *Limiter) debugPrint(a ...interface{}) {
	q.debugMutex.RLock()
	on := q.debug
	// Capture the channel under the lock; reading the field from the spawned
	// goroutine would race with DebugChannel replacing it.
	ch := q.debugChannel
	q.debugMutex.RUnlock()

	if !on || ch == nil {
		return
	}
	// Format now so the message reflects the arguments at call time.
	go func(msg string) {
		ch <- msg
	}(fmt.Sprint(a...))
}