Mild overhaul and testing improvements

This commit is contained in:
kayos@tcp.direct 2022-07-06 04:05:23 -07:00
parent c31c832330
commit 82b40accd4
Signed by: kayos
GPG Key ID: 4B841471B4BEE979
7 changed files with 132 additions and 176 deletions

13
_examples/go.mod Normal file
View File

@ -0,0 +1,13 @@
module rated
go 1.18
require (
git.tcp.direct/kayos/common v0.5.5
github.com/yunginnanet/Rate5 v0.4.4
)
require (
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
nullprogram.com/x/rng v1.1.0 // indirect
)

8
_examples/go.sum Normal file
View File

@ -0,0 +1,8 @@
git.tcp.direct/kayos/common v0.5.5 h1:ZLM7Q82acnSQmrWSQ98W4EKaszsf9JUYIsZgVr8V5ME=
git.tcp.direct/kayos/common v0.5.5/go.mod h1:jG1yXbN+5PrRZwGe32qIGWgLC4x5JWdyNBbMj1gIWB0=
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
github.com/yunginnanet/Rate5 v0.4.4 h1:NkQcBK6wD9RQ6AQ/rv+Cp4LasZlB38AkHnrN3X3tHF8=
github.com/yunginnanet/Rate5 v0.4.4/go.mod h1:aaaV1FLFmdBk1AD7uGQF53hgfPQg9yfBmIfDxtJuYZs=
nullprogram.com/x/rng v1.1.0 h1:SMU7DHaQSWtKJNTpNFIFt8Wd/KSmOuSDPXrMFp/UMro=
nullprogram.com/x/rng v1.1.0/go.mod h1:glGw6V87vyfawxCzqOABL3WfL95G65az9Z2JZCylCkg=

View File

@ -2,9 +2,7 @@ package main
import (
"bufio"
"crypto/rand"
"encoding/base64"
"encoding/binary"
"fmt"
"net"
"os"
@ -12,6 +10,8 @@ import (
"sync"
"time"
"git.tcp.direct/kayos/common/entropy"
rate5 "github.com/yunginnanet/Rate5"
)
@ -86,7 +86,6 @@ func argParse() {
default:
continue
}
}
}
@ -116,7 +115,6 @@ func init() {
func watchDebug(rd, rrd, crd chan string) {
pre := "[Rate5] "
var lastcount = 0
var count = 0
for {
select {
@ -130,12 +128,6 @@ func watchDebug(rd, rrd, crd chan string) {
fmt.Printf("%s CmdLimit: %s \n", pre, msg)
count++
default:
if count-lastcount >= 25 {
lastcount = count
fmt.Println("Rater: ", Rater.GetGrandTotalRated())
fmt.Println("RegRater: ", RegRater.GetGrandTotalRated())
fmt.Println("CmdRater: ", CmdRater.GetGrandTotalRated())
}
time.Sleep(time.Duration(10) * time.Millisecond)
}
}
@ -164,7 +156,6 @@ func (s *Server) preLogin(c *Client) {
c.send("invalid. type 'REGISTER' to register a new ID\n")
return
}
}
func (s *Server) mainPrompt(c *Client) {
@ -275,12 +266,7 @@ func (c *Client) recv() string {
}
func randUint32() uint32 {
b := make([]byte, 4096)
if _, err := rand.Read(b); err != nil {
panic(err)
}
return binary.BigEndian.Uint32(b)
return entropy.GetOptimizedRand().Uint32()
}
func keygen() string {

2
go.mod
View File

@ -1,5 +1,5 @@
module github.com/yunginnanet/Rate5
go 1.17
go 1.18
require github.com/patrickmn/go-cache v2.1.0+incompatible

View File

@ -1,8 +1,8 @@
package rate5
import (
"sync/atomic"
"sync"
"github.com/patrickmn/go-cache"
)
@ -13,11 +13,6 @@ const (
DefaultBurst = 25
)
const (
stateUnlocked uint32 = iota
stateLocked
)
var debugChannel chan string
// Identity is an interface that allows any arbitrary type to be used for a unique key in ratelimit checks when implemented.
@ -26,8 +21,7 @@ type Identity interface {
}
type rated struct {
seen *atomic.Value
locker uint32
seen int64
}
// Limiter implements an Enforcer to create an arbitrary ratelimiter.
@ -37,23 +31,22 @@ type Limiter struct {
Patrons *cache.Cache
// Ruleset is the actual ratelimitting model.
Ruleset Policy
/* Debug mode (toggled here) enables debug messages
/* debug mode (toggled here) enables debug messages
delivered through a channel. See: DebugChannel() */
Debug bool
debug bool
locker uint32
count atomic.Value
known map[interface{}]rated
dmu *sync.RWMutex
*sync.RWMutex
}
// Policy defines the mechanics of our ratelimiter.
type Policy struct {
// Window defines the duration in seconds that we should keep track of ratelimit triggers,
Window int
Window int64
// Burst is the amount of times that Check will not trigger a limit within the duration defined by Window.
Burst int
Burst int64
// Strict mode punishes triggers of the ratelimitby increasing the amount of time they have to wait every time they trigger the limitter.
Strict bool
}

View File

@ -2,6 +2,7 @@ package rate5
import (
"fmt"
"strings"
"sync"
"sync/atomic"
"time"
@ -26,8 +27,8 @@ func NewCustomLimiter(policy Policy) *Limiter {
// NewLimiter returns a custom limiter witout Strict mode
func NewLimiter(window int, burst int) *Limiter {
return newLimiter(Policy{
Window: window,
Burst: burst,
Window: int64(window),
Burst: int64(burst),
Strict: false,
})
}
@ -44,92 +45,96 @@ func NewDefaultStrictLimiter() *Limiter {
// NewStrictLimiter returns a custom limiter with Strict mode.
func NewStrictLimiter(window int, burst int) *Limiter {
return newLimiter(Policy{
Window: window,
Burst: burst,
Window: int64(window),
Burst: int64(burst),
Strict: true,
})
}
func newLimiter(policy Policy) *Limiter {
q := new(Limiter)
q.Ruleset = policy
q.Patrons = cache.New(time.Duration(q.Ruleset.Window)*time.Second, 5*time.Second)
q.known = make(map[interface{}]rated)
q.dmu = &sync.RWMutex{}
q.SetDebug(false)
return q
return &Limiter{
Ruleset: policy,
Patrons: cache.New(time.Duration(policy.Window)*time.Second, 5*time.Second),
known: make(map[interface{}]rated),
RWMutex: &sync.RWMutex{},
debug: false,
}
}
func (q *Limiter) SetDebug(on bool) {
q.dmu.Lock()
q.Debug = on
q.dmu.Unlock()
q.Lock()
q.debug = on
q.Unlock()
}
// DebugChannel enables Debug mode and returns a channel where debug messages are sent (NOTE: You must read from this channel if created via this function or it will block)
// DebugChannel enables debug mode and returns a channel where debug messages are sent.
// NOTE: You must read from this channel if created via this function or it will block
func (q *Limiter) DebugChannel() chan string {
q.dmu.Lock()
q.Lock()
q.Patrons.OnEvicted(func(src string, count interface{}) {
q.debugPrint("ratelimit (expired): ", src, " ", count)
})
q.Debug = true
q.debug = true
debugChannel = make(chan string, 20)
q.dmu.Unlock()
q.Unlock()
return debugChannel
}
func (s rated) inc() {
for !atomic.CompareAndSwapUint32(&s.locker, stateUnlocked, stateLocked) {
time.Sleep(10 * time.Millisecond)
}
defer atomic.StoreUint32(&s.locker, stateUnlocked)
if s.seen.Load() == nil {
s.seen.Store(1)
return
}
s.seen.Store(s.seen.Load().(int) + 1)
func (r rated) count() int64 {
return atomic.LoadInt64(&r.seen)
}
func (q *Limiter) strictLogic(src string, count int) {
for !atomic.CompareAndSwapUint32(&q.locker, stateUnlocked, stateLocked) {
time.Sleep(10 * time.Millisecond)
}
defer atomic.StoreUint32(&q.locker, stateUnlocked)
func (r rated) inc() {
atomic.AddInt64(&r.seen, 1)
}
func (q *Limiter) checkStrictPatron(src interface{}) {
q.RLock()
if _, ok := q.known[src]; !ok {
q.known[src] = rated{
seen: &atomic.Value{},
locker: stateUnlocked,
}
q.RUnlock()
q.newStrictPatron(src)
q.RLock()
}
q.RUnlock()
}
func (q *Limiter) newStrictPatron(src interface{}) {
q.Lock()
q.known[src] = rated{seen: 0}
q.Unlock()
}
func (q *Limiter) strictLogic(src string, count int64) {
q.checkStrictPatron(src)
q.known[src].inc()
extwindow := q.Ruleset.Window + q.known[src].seen.Load().(int)
extwindow := q.Ruleset.Window + q.known[src].count()
if err := q.Patrons.Replace(src, count, time.Duration(extwindow)*time.Second); err != nil {
q.debugPrint("ratelimit: " + err.Error())
q.debugPrint("ratelimit (strict) error: " + err.Error())
}
q.debugPrint("ratelimit (strictly limited): ", count, " ", src)
q.increment()
q.debugPrint("ratelimit (strict) limited: ", count, " ", src)
}
// Check checks and increments an Identities UniqueKey() output against a list of cached strings to determine and raise it's ratelimitting status.
func (q *Limiter) Check(from Identity) bool {
var count int
func (q *Limiter) Check(from Identity) (limited bool) {
var count int64
var err error
src := from.UniqueKey()
if count, err = q.Patrons.IncrementInt(src, 1); err != nil {
q.debugPrint("ratelimit (new): ", src)
if err := q.Patrons.Add(src, 1, time.Duration(q.Ruleset.Window)*time.Second); err != nil {
q.debugPrint("ratelimit: " + err.Error())
count, err = q.Patrons.IncrementInt64(src, 1)
if err != nil {
if strings.Contains(err.Error(), "not found") {
q.debugPrint("ratelimit (new): ", src)
if cacheErr := q.Patrons.Add(src, int64(1), time.Duration(q.Ruleset.Window)*time.Second); cacheErr != nil {
q.debugPrint("ratelimit error: " + cacheErr.Error())
}
return false
}
return false
q.debugPrint("ratelimit error: " + err.Error())
return true
}
if count < q.Ruleset.Burst {
return false
}
if !q.Ruleset.Strict {
q.increment()
q.debugPrint("ratelimit (limited): ", count, " ", src)
return true
}
@ -140,7 +145,7 @@ func (q *Limiter) Check(from Identity) bool {
// Peek checks an Identities UniqueKey() output against a list of cached strings to determine ratelimitting status without adding to its request count.
func (q *Limiter) Peek(from Identity) bool {
if ct, ok := q.Patrons.Get(from.UniqueKey()); ok {
count := ct.(int)
count := ct.(int64)
if count > q.Ruleset.Burst {
return true
}
@ -148,28 +153,12 @@ func (q *Limiter) Peek(from Identity) bool {
return false
}
func (q *Limiter) increment() {
if q.count.Load() == nil {
q.count.Store(1)
return
}
q.count.Store(q.count.Load().(int) + 1)
}
// GetGrandTotalRated returns the historic total amount of times we have ever reported something as ratelimited.
func (q *Limiter) GetGrandTotalRated() int {
if q.count.Load() == nil {
return 0
}
return q.count.Load().(int)
}
func (q *Limiter) debugPrint(a ...interface{}) {
q.dmu.RLock()
if q.Debug {
q.dmu.RUnlock()
q.RLock()
if q.debug {
q.RUnlock()
debugChannel <- fmt.Sprint(a...)
return
}
q.dmu.RUnlock()
q.RUnlock()
}

View File

@ -3,21 +3,16 @@ package rate5
import (
"crypto/rand"
"encoding/binary"
"sync"
"testing"
"time"
)
var (
dummyTicker *ticker
testDebug = true
stopDebug chan bool
stopDebug = make(chan bool)
)
func init() {
stopDebug = make(chan bool)
}
type randomPatron struct {
key string
Identity
@ -38,7 +33,7 @@ func randomUint32() uint32 {
}
func (rp *randomPatron) GenerateKey() {
var keylen = 8
var keylen = 10
buf := make([]byte, keylen)
for n := 0; n != keylen; n++ {
buf[n] = charset[randomUint32()%uint32(len(charset))]
@ -49,33 +44,19 @@ func (rp *randomPatron) GenerateKey() {
func watchDebug(r *Limiter, t *testing.T) {
t.Logf("debug enabled")
rd := r.DebugChannel()
pre := "[Rate5] "
var lastcount = 0
var count = 0
for {
select {
case msg := <-rd:
t.Logf("%s Limit: %s \n", pre, msg)
case <-stopDebug:
return
default:
count++
if count-lastcount >= 10 {
lastcount = count
t.Logf("Times limited: %d", r.GetGrandTotalRated())
}
time.Sleep(10 * time.Millisecond)
}
}
}
type ticker struct{}
func init() {
dummyTicker = &ticker{}
}
func (tick *ticker) UniqueKey() string {
return "tick"
}
@ -87,11 +68,8 @@ func Test_NewCustomLimiter(t *testing.T) {
Strict: false,
})
//goland:noinspection GoBoolExpressions
if testDebug {
go watchDebug(limiter, t)
time.Sleep(25 * time.Millisecond)
}
go watchDebug(limiter, t)
time.Sleep(25 * time.Millisecond)
for n := 0; n < 9; n++ {
limiter.Check(dummyTicker)
@ -111,11 +89,6 @@ func Test_NewCustomLimiter(t *testing.T) {
}
}
//goland:noinspection GoBoolExpressions
if testDebug {
t.Logf("[Finished NewCustomLimiter] Times ratelimited: %d", limiter.GetGrandTotalRated())
}
stopDebug <- true
limiter = nil
}
@ -125,11 +98,8 @@ func Test_NewDefaultStrictLimiter(t *testing.T) {
// DefaultWindow = 5
limiter := NewDefaultStrictLimiter()
//goland:noinspection GoBoolExpressions
if testDebug {
go watchDebug(limiter, t)
time.Sleep(25 * time.Millisecond)
}
go watchDebug(limiter, t)
time.Sleep(25 * time.Millisecond)
for n := 0; n < 24; n++ {
limiter.Check(dummyTicker)
@ -150,33 +120,27 @@ func Test_NewDefaultStrictLimiter(t *testing.T) {
}
}
//goland:noinspection GoBoolExpressions
if testDebug {
t.Logf("[Finished NewCustomLimiter] Times ratelimited: %d", limiter.GetGrandTotalRated())
}
stopDebug <- true
limiter = nil
}
// This test is only here for safety, if the package is not safe, this will often panic.
// We give this a healthy amount of padding in terms of our checks as this is far beyond the tolerances we expect during runtime.
// At the end of the day, not panicing here is passing.
func Test_ConcurrentSafetyTest(t *testing.T) {
func concurrentTest(t *testing.T, jobs int, iterCount int, burst int64, shouldLimit bool) { //nolint:funlen
var randos map[int]*randomPatron
randos = make(map[int]*randomPatron)
limiter := NewCustomLimiter(Policy{
Window: 240,
Burst: 5000,
Burst: burst,
Strict: true,
})
limitNotice := sync.Once{}
limiter.SetDebug(false)
usedkeys := make(map[string]interface{})
const jobs = 1000
for n := 0; n != jobs ; n++ {
for n := 0; n != jobs; n++ {
randos[n] = new(randomPatron)
ok := true
for ok {
@ -188,56 +152,59 @@ func Test_ConcurrentSafetyTest(t *testing.T) {
}
}
t.Logf("generated %d Patrons with unique keys", len(randos))
t.Logf("generated %d Patrons with unique keys, running Check() with them %d times concurrently...",
len(randos), iterCount)
doneChan := make(chan bool, 10)
finChan := make(chan bool, 10)
var finished = 0
for _, rp := range randos {
for n := 0; n != 5; n++ {
for n := 0; n != iterCount; n++ {
go func(randomp *randomPatron) {
limiter.Check(randomp)
limiter.Peek(randomp)
if limiter.Peek(randomp) {
limitNotice.Do(func() {
t.Logf("(sync.Once) %s limited", randomp.UniqueKey())
})
}
finChan <- true
}(rp)
}
}
var done = false
testloop:
for {
select {
case <-debugChannel:
t.Logf("[debug] %s", <-debugChannel)
case <-finChan:
finished++
default:
if finished == (jobs * 5) {
done = true
break
if finished >= (jobs * iterCount) {
break testloop
}
}
if done {
go func() {
doneChan <- true
}()
break
}
}
<-doneChan
println("done")
for _, rp := range randos {
if limiter.Peek(rp) {
if limiter.Peek(rp) && !shouldLimit {
if ct, ok := limiter.Patrons.Get(rp.UniqueKey()); ok {
t.Logf("WARN: Should not have been limited. Ratelimiter count: %d, policy: %d", ct, limiter.Ruleset.Burst)
t.Errorf("WARN: Should not have been limited. Ratelimiter count: %d, policy: %d", ct, limiter.Ruleset.Burst)
} else {
t.Errorf("randomPatron does not exist in ratelimiter at all!")
}
}
}
//goland:noinspection GoBoolExpressions
if testDebug {
t.Logf("[Finished StrictConcurrentStressTest] Times ratelimited: %d", limiter.GetGrandTotalRated())
}
}
func Test_ConcurrentShouldNotLimit(t *testing.T) {
concurrentTest(t, 500, 20, 20, false)
concurrentTest(t, 500, 50, 50, false)
}
func Test_ConcurrentShouldLimit(t *testing.T) {
concurrentTest(t, 500, 21, 20, true)
concurrentTest(t, 500, 51, 50, true)
}