mirror of https://github.com/yunginnanet/Rate5
Fix: Major bug with Peek. New: test cases
This commit is contained in:
parent
27822b0603
commit
5aea254233
|
@ -47,4 +47,4 @@ func (s *Server) handleTCP(c *Client) {
|
|||
## To-Do
|
||||
More Documentation
|
||||
More To-Dos
|
||||
Test Cases
|
||||
~~Test Cases~~
|
||||
|
|
|
@ -51,7 +51,7 @@ type Client struct {
|
|||
|
||||
loggedin bool
|
||||
connected bool
|
||||
autlog []Login
|
||||
authlog []Login
|
||||
|
||||
deadline time.Duration
|
||||
read *bufio.Reader
|
||||
|
@ -108,7 +108,6 @@ func init() {
|
|||
|
||||
argParse()
|
||||
|
||||
|
||||
rd := Rater.DebugChannel()
|
||||
rrd := RegRater.DebugChannel()
|
||||
crd := CmdRater.DebugChannel()
|
||||
|
@ -278,6 +277,7 @@ func (c *Client) recv() string {
|
|||
func randUint32() uint32 {
|
||||
b := make([]byte, 4096)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
|
||||
panic(err)
|
||||
}
|
||||
return binary.BigEndian.Uint32(b)
|
||||
|
@ -318,6 +318,7 @@ func (s *Server) setID(c *Client, id string) {
|
|||
|
||||
func (s *Server) replaceSession(c *Client, id string) {
|
||||
s.mu.Lock()
|
||||
|
||||
s.AuthLog[id] = append(s.AuthLog[id], Login{
|
||||
// we're not logged in so UniqueKey is still the IP address
|
||||
IP: c.UniqueKey(),
|
||||
|
|
10
models.go
10
models.go
|
@ -1,7 +1,6 @@
|
|||
package rate5
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/patrickmn/go-cache"
|
||||
|
@ -14,6 +13,11 @@ const (
|
|||
DefaultBurst = 25
|
||||
)
|
||||
|
||||
const (
|
||||
stateUnlocked uint32 = iota
|
||||
stateLocked
|
||||
)
|
||||
|
||||
var debugChannel chan string
|
||||
|
||||
// Identity is an interface that allows any arbitrary type to be used for a unique key in ratelimit checks when implemented.
|
||||
|
@ -23,6 +27,7 @@ type Identity interface {
|
|||
|
||||
type rated struct {
|
||||
seen atomic.Value
|
||||
locker *uint32
|
||||
}
|
||||
|
||||
// Limiter implements an Enforcer to create an arbitrary ratelimiter.
|
||||
|
@ -36,9 +41,8 @@ type Limiter struct {
|
|||
delivered through a channel. See: DebugChannel() */
|
||||
Debug bool
|
||||
|
||||
count int
|
||||
count atomic.Value
|
||||
known map[interface{}]*rated
|
||||
mu *sync.RWMutex
|
||||
}
|
||||
|
||||
// Policy defines the mechanics of our ratelimiter.
|
||||
|
|
|
@ -2,13 +2,13 @@ package rate5
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/patrickmn/go-cache"
|
||||
)
|
||||
|
||||
|
||||
// NewDefaultLimiter returns a ratelimiter with default settings without Strict mode.
|
||||
func NewDefaultLimiter() *Limiter {
|
||||
return newLimiter(Policy{
|
||||
|
@ -55,7 +55,6 @@ func newLimiter(policy Policy) *Limiter {
|
|||
q.Ruleset = policy
|
||||
q.Patrons = cache.New(time.Duration(q.Ruleset.Window)*time.Second, 5*time.Second)
|
||||
q.known = make(map[interface{}]*rated)
|
||||
q.mu = &sync.RWMutex{}
|
||||
return q
|
||||
}
|
||||
|
||||
|
@ -70,6 +69,11 @@ func (q *Limiter) DebugChannel() chan string {
|
|||
}
|
||||
|
||||
func (s *rated) inc() {
|
||||
for !atomic.CompareAndSwapUint32(s.locker, stateUnlocked, stateLocked) {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
defer atomic.StoreUint32(s.locker, stateUnlocked)
|
||||
|
||||
if s.seen.Load() == nil {
|
||||
s.seen.Store(1)
|
||||
return
|
||||
|
@ -78,20 +82,15 @@ func (s *rated) inc() {
|
|||
}
|
||||
|
||||
func (q *Limiter) strictLogic(src string, count int) {
|
||||
q.mu.Lock()
|
||||
if _, ok := q.known[src]; !ok {
|
||||
q.known[src]=&rated{
|
||||
seen: atomic.Value{},
|
||||
}
|
||||
atomic.StoreUint32(q.known[src].locker, stateUnlocked)
|
||||
q.known[src]=&rated{seen: atomic.Value{}}
|
||||
}
|
||||
|
||||
q.known[src].inc()
|
||||
extwindow := q.Ruleset.Window + q.known[src].seen.Load().(int)
|
||||
|
||||
if err := q.Patrons.Replace(src, count, time.Duration(extwindow)*time.Second); err != nil {
|
||||
q.debugPrint("Rate5: " + err.Error())
|
||||
}
|
||||
q.mu.Unlock()
|
||||
q.debugPrint("ratelimit (strictly limited): ", count, " ", src)
|
||||
q.increment()
|
||||
}
|
||||
|
@ -122,24 +121,29 @@ func (q *Limiter) Check(from Identity) bool {
|
|||
|
||||
// Peek checks an Identities UniqueKey() output against a list of cached strings to determine ratelimitting status without adding to its request count.
|
||||
func (q *Limiter) Peek(from Identity) bool {
|
||||
if _, ok := q.Patrons.Get(from.UniqueKey()); ok {
|
||||
return true
|
||||
if ct, ok := q.Patrons.Get(from.UniqueKey()); ok {
|
||||
count := ct.(int)
|
||||
if count > q.Ruleset.Burst {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (q *Limiter) increment() {
|
||||
q.mu.Lock()
|
||||
defer q.mu.Unlock()
|
||||
q.count++
|
||||
if q.count.Load() == nil {
|
||||
q.count.Store(1)
|
||||
return
|
||||
}
|
||||
q.count.Store(q.count.Load().(int) + 1)
|
||||
}
|
||||
|
||||
// GetGrandTotalRated returns the historic total amount of times we have ever reported something as ratelimited.
|
||||
func (q *Limiter) GetGrandTotalRated() int {
|
||||
q.mu.RLock()
|
||||
defer q.mu.RUnlock()
|
||||
return q.count
|
||||
if q.count.Load() == nil {
|
||||
return 0
|
||||
}
|
||||
return q.count.Load().(int)
|
||||
}
|
||||
|
||||
func (q *Limiter) debugPrint(a ...interface{}) {
|
||||
|
|
|
@ -0,0 +1,115 @@
|
|||
package rate5
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
dummyTicker *ticker
|
||||
testDebug = true
|
||||
)
|
||||
|
||||
func watchDebug(r *Limiter, t *testing.T) {
|
||||
t.Logf("debug enabled")
|
||||
rd := r.DebugChannel()
|
||||
|
||||
pre := "[Rate5] "
|
||||
var lastcount = 0
|
||||
var count = 0
|
||||
for {
|
||||
select {
|
||||
case msg := <-rd:
|
||||
t.Logf("%s Limit: %s \n", pre, msg)
|
||||
default:
|
||||
count++
|
||||
if count-lastcount >= 10 {
|
||||
lastcount = count
|
||||
t.Logf("Times limited: %d", r.GetGrandTotalRated())
|
||||
}
|
||||
time.Sleep(time.Duration(10) * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type ticker struct{}
|
||||
|
||||
func init() {
|
||||
dummyTicker = &ticker{}
|
||||
}
|
||||
|
||||
func (tick *ticker) UniqueKey() string {
|
||||
return "tick"
|
||||
}
|
||||
|
||||
func Test_NewCustomLimiter(t *testing.T) {
|
||||
limiter := NewCustomLimiter(Policy{
|
||||
Window: 5,
|
||||
Burst: 10,
|
||||
Strict: false,
|
||||
})
|
||||
|
||||
//goland:noinspection GoBoolExpressions
|
||||
if testDebug {
|
||||
go watchDebug(limiter, t)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
for n := 0; n < 9; n++ {
|
||||
limiter.Check(dummyTicker)
|
||||
}
|
||||
if limiter.Peek(dummyTicker) {
|
||||
if ct, ok := limiter.Patrons.Get(dummyTicker.UniqueKey()); ok {
|
||||
t.Errorf("Should not have been limited. Ratelimiter count: %d", ct)
|
||||
} else {
|
||||
t.Errorf("dummyTicker does not exist in ratelimiter at all!")
|
||||
}
|
||||
}
|
||||
if !limiter.Check(dummyTicker) {
|
||||
if ct, ok := limiter.Patrons.Get(dummyTicker.UniqueKey()); ok {
|
||||
t.Errorf("Should have been limited. Ratelimiter count: %d", ct)
|
||||
} else {
|
||||
t.Errorf("dummyTicker does not exist in ratelimiter at all!")
|
||||
}
|
||||
}
|
||||
|
||||
//goland:noinspection GoBoolExpressions
|
||||
if testDebug {
|
||||
t.Logf("[Finished NewCustomLimiter] Times ratelimited: %d", limiter.GetGrandTotalRated())
|
||||
}
|
||||
}
|
||||
|
||||
func Test_NewDefaultStrictLimiter(t *testing.T) {
|
||||
// DefaultBurst = 25
|
||||
// DefaultWindow = 5
|
||||
limiter := NewDefaultStrictLimiter()
|
||||
|
||||
//goland:noinspection GoBoolExpressions
|
||||
if testDebug {
|
||||
go watchDebug(limiter, t)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
for n := 0; n < 24; n++ {
|
||||
limiter.Check(dummyTicker)
|
||||
}
|
||||
if limiter.Peek(dummyTicker) {
|
||||
if ct, ok := limiter.Patrons.Get(dummyTicker.UniqueKey()); ok {
|
||||
t.Errorf("Should not have been limited. Ratelimiter count: %d", ct)
|
||||
} else {
|
||||
t.Errorf("dummyTicker does not exist in ratelimiter at all!")
|
||||
}
|
||||
}
|
||||
if !limiter.Check(dummyTicker) {
|
||||
if ct, ok := limiter.Patrons.Get(dummyTicker.UniqueKey()); ok {
|
||||
t.Errorf("Should have been limited. Ratelimiter count: %d, policy: %d", ct, limiter.Ruleset.Burst)
|
||||
} else {
|
||||
t.Errorf("dummyTicker does not exist in ratelimiter at all!")
|
||||
}
|
||||
}
|
||||
|
||||
//goland:noinspection GoBoolExpressions
|
||||
if testDebug {
|
||||
t.Logf("[Finished NewCustomLimiter] Times ratelimited: %d", limiter.GetGrandTotalRated())
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue