mirror of
https://github.com/yunginnanet/Rate5
synced 2024-06-28 10:00:52 +00:00
Fix: Major bug with Peek. New: test cases
This commit is contained in:
parent
27822b0603
commit
5aea254233
@ -47,4 +47,4 @@ func (s *Server) handleTCP(c *Client) {
|
|||||||
## To-Do
|
## To-Do
|
||||||
More Documentation
|
More Documentation
|
||||||
More To-Dos
|
More To-Dos
|
||||||
Test Cases
|
~~Test Cases~~
|
||||||
|
@ -51,7 +51,7 @@ type Client struct {
|
|||||||
|
|
||||||
loggedin bool
|
loggedin bool
|
||||||
connected bool
|
connected bool
|
||||||
autlog []Login
|
authlog []Login
|
||||||
|
|
||||||
deadline time.Duration
|
deadline time.Duration
|
||||||
read *bufio.Reader
|
read *bufio.Reader
|
||||||
@ -108,7 +108,6 @@ func init() {
|
|||||||
|
|
||||||
argParse()
|
argParse()
|
||||||
|
|
||||||
|
|
||||||
rd := Rater.DebugChannel()
|
rd := Rater.DebugChannel()
|
||||||
rrd := RegRater.DebugChannel()
|
rrd := RegRater.DebugChannel()
|
||||||
crd := CmdRater.DebugChannel()
|
crd := CmdRater.DebugChannel()
|
||||||
@ -278,6 +277,7 @@ func (c *Client) recv() string {
|
|||||||
func randUint32() uint32 {
|
func randUint32() uint32 {
|
||||||
b := make([]byte, 4096)
|
b := make([]byte, 4096)
|
||||||
if _, err := rand.Read(b); err != nil {
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
return binary.BigEndian.Uint32(b)
|
return binary.BigEndian.Uint32(b)
|
||||||
@ -318,6 +318,7 @@ func (s *Server) setID(c *Client, id string) {
|
|||||||
|
|
||||||
func (s *Server) replaceSession(c *Client, id string) {
|
func (s *Server) replaceSession(c *Client, id string) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
|
|
||||||
s.AuthLog[id] = append(s.AuthLog[id], Login{
|
s.AuthLog[id] = append(s.AuthLog[id], Login{
|
||||||
// we're not logged in so UniqueKey is still the IP address
|
// we're not logged in so UniqueKey is still the IP address
|
||||||
IP: c.UniqueKey(),
|
IP: c.UniqueKey(),
|
||||||
|
10
models.go
10
models.go
@ -1,7 +1,6 @@
|
|||||||
package rate5
|
package rate5
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/patrickmn/go-cache"
|
"github.com/patrickmn/go-cache"
|
||||||
@ -14,6 +13,11 @@ const (
|
|||||||
DefaultBurst = 25
|
DefaultBurst = 25
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
stateUnlocked uint32 = iota
|
||||||
|
stateLocked
|
||||||
|
)
|
||||||
|
|
||||||
var debugChannel chan string
|
var debugChannel chan string
|
||||||
|
|
||||||
// Identity is an interface that allows any arbitrary type to be used for a unique key in ratelimit checks when implemented.
|
// Identity is an interface that allows any arbitrary type to be used for a unique key in ratelimit checks when implemented.
|
||||||
@ -23,6 +27,7 @@ type Identity interface {
|
|||||||
|
|
||||||
type rated struct {
|
type rated struct {
|
||||||
seen atomic.Value
|
seen atomic.Value
|
||||||
|
locker *uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
// Limiter implements an Enforcer to create an arbitrary ratelimiter.
|
// Limiter implements an Enforcer to create an arbitrary ratelimiter.
|
||||||
@ -36,9 +41,8 @@ type Limiter struct {
|
|||||||
delivered through a channel. See: DebugChannel() */
|
delivered through a channel. See: DebugChannel() */
|
||||||
Debug bool
|
Debug bool
|
||||||
|
|
||||||
count int
|
count atomic.Value
|
||||||
known map[interface{}]*rated
|
known map[interface{}]*rated
|
||||||
mu *sync.RWMutex
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Policy defines the mechanics of our ratelimiter.
|
// Policy defines the mechanics of our ratelimiter.
|
||||||
|
@ -2,13 +2,13 @@ package rate5
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/patrickmn/go-cache"
|
"github.com/patrickmn/go-cache"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
// NewDefaultLimiter returns a ratelimiter with default settings without Strict mode.
|
// NewDefaultLimiter returns a ratelimiter with default settings without Strict mode.
|
||||||
func NewDefaultLimiter() *Limiter {
|
func NewDefaultLimiter() *Limiter {
|
||||||
return newLimiter(Policy{
|
return newLimiter(Policy{
|
||||||
@ -55,7 +55,6 @@ func newLimiter(policy Policy) *Limiter {
|
|||||||
q.Ruleset = policy
|
q.Ruleset = policy
|
||||||
q.Patrons = cache.New(time.Duration(q.Ruleset.Window)*time.Second, 5*time.Second)
|
q.Patrons = cache.New(time.Duration(q.Ruleset.Window)*time.Second, 5*time.Second)
|
||||||
q.known = make(map[interface{}]*rated)
|
q.known = make(map[interface{}]*rated)
|
||||||
q.mu = &sync.RWMutex{}
|
|
||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -70,6 +69,11 @@ func (q *Limiter) DebugChannel() chan string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *rated) inc() {
|
func (s *rated) inc() {
|
||||||
|
for !atomic.CompareAndSwapUint32(s.locker, stateUnlocked, stateLocked) {
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
}
|
||||||
|
defer atomic.StoreUint32(s.locker, stateUnlocked)
|
||||||
|
|
||||||
if s.seen.Load() == nil {
|
if s.seen.Load() == nil {
|
||||||
s.seen.Store(1)
|
s.seen.Store(1)
|
||||||
return
|
return
|
||||||
@ -78,20 +82,15 @@ func (s *rated) inc() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (q *Limiter) strictLogic(src string, count int) {
|
func (q *Limiter) strictLogic(src string, count int) {
|
||||||
q.mu.Lock()
|
|
||||||
if _, ok := q.known[src]; !ok {
|
if _, ok := q.known[src]; !ok {
|
||||||
q.known[src]=&rated{
|
atomic.StoreUint32(q.known[src].locker, stateUnlocked)
|
||||||
seen: atomic.Value{},
|
q.known[src]=&rated{seen: atomic.Value{}}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
q.known[src].inc()
|
q.known[src].inc()
|
||||||
extwindow := q.Ruleset.Window + q.known[src].seen.Load().(int)
|
extwindow := q.Ruleset.Window + q.known[src].seen.Load().(int)
|
||||||
|
|
||||||
if err := q.Patrons.Replace(src, count, time.Duration(extwindow)*time.Second); err != nil {
|
if err := q.Patrons.Replace(src, count, time.Duration(extwindow)*time.Second); err != nil {
|
||||||
q.debugPrint("Rate5: " + err.Error())
|
q.debugPrint("Rate5: " + err.Error())
|
||||||
}
|
}
|
||||||
q.mu.Unlock()
|
|
||||||
q.debugPrint("ratelimit (strictly limited): ", count, " ", src)
|
q.debugPrint("ratelimit (strictly limited): ", count, " ", src)
|
||||||
q.increment()
|
q.increment()
|
||||||
}
|
}
|
||||||
@ -122,24 +121,29 @@ func (q *Limiter) Check(from Identity) bool {
|
|||||||
|
|
||||||
// Peek checks an Identities UniqueKey() output against a list of cached strings to determine ratelimitting status without adding to its request count.
|
// Peek checks an Identities UniqueKey() output against a list of cached strings to determine ratelimitting status without adding to its request count.
|
||||||
func (q *Limiter) Peek(from Identity) bool {
|
func (q *Limiter) Peek(from Identity) bool {
|
||||||
if _, ok := q.Patrons.Get(from.UniqueKey()); ok {
|
if ct, ok := q.Patrons.Get(from.UniqueKey()); ok {
|
||||||
return true
|
count := ct.(int)
|
||||||
|
if count > q.Ruleset.Burst {
|
||||||
|
return true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Limiter) increment() {
|
func (q *Limiter) increment() {
|
||||||
q.mu.Lock()
|
if q.count.Load() == nil {
|
||||||
defer q.mu.Unlock()
|
q.count.Store(1)
|
||||||
q.count++
|
return
|
||||||
|
}
|
||||||
|
q.count.Store(q.count.Load().(int) + 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetGrandTotalRated returns the historic total amount of times we have ever reported something as ratelimited.
|
// GetGrandTotalRated returns the historic total amount of times we have ever reported something as ratelimited.
|
||||||
func (q *Limiter) GetGrandTotalRated() int {
|
func (q *Limiter) GetGrandTotalRated() int {
|
||||||
q.mu.RLock()
|
if q.count.Load() == nil {
|
||||||
defer q.mu.RUnlock()
|
return 0
|
||||||
return q.count
|
}
|
||||||
|
return q.count.Load().(int)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Limiter) debugPrint(a ...interface{}) {
|
func (q *Limiter) debugPrint(a ...interface{}) {
|
||||||
|
115
ratelimiter_test.go
Normal file
115
ratelimiter_test.go
Normal file
@ -0,0 +1,115 @@
|
|||||||
|
package rate5
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
dummyTicker *ticker
|
||||||
|
testDebug = true
|
||||||
|
)
|
||||||
|
|
||||||
|
func watchDebug(r *Limiter, t *testing.T) {
|
||||||
|
t.Logf("debug enabled")
|
||||||
|
rd := r.DebugChannel()
|
||||||
|
|
||||||
|
pre := "[Rate5] "
|
||||||
|
var lastcount = 0
|
||||||
|
var count = 0
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case msg := <-rd:
|
||||||
|
t.Logf("%s Limit: %s \n", pre, msg)
|
||||||
|
default:
|
||||||
|
count++
|
||||||
|
if count-lastcount >= 10 {
|
||||||
|
lastcount = count
|
||||||
|
t.Logf("Times limited: %d", r.GetGrandTotalRated())
|
||||||
|
}
|
||||||
|
time.Sleep(time.Duration(10) * time.Millisecond)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ticker struct{}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
dummyTicker = &ticker{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tick *ticker) UniqueKey() string {
|
||||||
|
return "tick"
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_NewCustomLimiter(t *testing.T) {
|
||||||
|
limiter := NewCustomLimiter(Policy{
|
||||||
|
Window: 5,
|
||||||
|
Burst: 10,
|
||||||
|
Strict: false,
|
||||||
|
})
|
||||||
|
|
||||||
|
//goland:noinspection GoBoolExpressions
|
||||||
|
if testDebug {
|
||||||
|
go watchDebug(limiter, t)
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
for n := 0; n < 9; n++ {
|
||||||
|
limiter.Check(dummyTicker)
|
||||||
|
}
|
||||||
|
if limiter.Peek(dummyTicker) {
|
||||||
|
if ct, ok := limiter.Patrons.Get(dummyTicker.UniqueKey()); ok {
|
||||||
|
t.Errorf("Should not have been limited. Ratelimiter count: %d", ct)
|
||||||
|
} else {
|
||||||
|
t.Errorf("dummyTicker does not exist in ratelimiter at all!")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !limiter.Check(dummyTicker) {
|
||||||
|
if ct, ok := limiter.Patrons.Get(dummyTicker.UniqueKey()); ok {
|
||||||
|
t.Errorf("Should have been limited. Ratelimiter count: %d", ct)
|
||||||
|
} else {
|
||||||
|
t.Errorf("dummyTicker does not exist in ratelimiter at all!")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//goland:noinspection GoBoolExpressions
|
||||||
|
if testDebug {
|
||||||
|
t.Logf("[Finished NewCustomLimiter] Times ratelimited: %d", limiter.GetGrandTotalRated())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_NewDefaultStrictLimiter(t *testing.T) {
|
||||||
|
// DefaultBurst = 25
|
||||||
|
// DefaultWindow = 5
|
||||||
|
limiter := NewDefaultStrictLimiter()
|
||||||
|
|
||||||
|
//goland:noinspection GoBoolExpressions
|
||||||
|
if testDebug {
|
||||||
|
go watchDebug(limiter, t)
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
for n := 0; n < 24; n++ {
|
||||||
|
limiter.Check(dummyTicker)
|
||||||
|
}
|
||||||
|
if limiter.Peek(dummyTicker) {
|
||||||
|
if ct, ok := limiter.Patrons.Get(dummyTicker.UniqueKey()); ok {
|
||||||
|
t.Errorf("Should not have been limited. Ratelimiter count: %d", ct)
|
||||||
|
} else {
|
||||||
|
t.Errorf("dummyTicker does not exist in ratelimiter at all!")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !limiter.Check(dummyTicker) {
|
||||||
|
if ct, ok := limiter.Patrons.Get(dummyTicker.UniqueKey()); ok {
|
||||||
|
t.Errorf("Should have been limited. Ratelimiter count: %d, policy: %d", ct, limiter.Ruleset.Burst)
|
||||||
|
} else {
|
||||||
|
t.Errorf("dummyTicker does not exist in ratelimiter at all!")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//goland:noinspection GoBoolExpressions
|
||||||
|
if testDebug {
|
||||||
|
t.Logf("[Finished NewCustomLimiter] Times ratelimited: %d", limiter.GetGrandTotalRated())
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user