1
2
mirror of https://github.com/yunginnanet/Rate5 synced 2024-06-25 00:18:37 +00:00

Fix: race condition

This commit is contained in:
kayos@tcp.direct 2022-03-04 13:09:08 -08:00
parent dc296e1d60
commit 3517b17034
Signed by: kayos
GPG Key ID: 4B841471B4BEE979
3 changed files with 30 additions and 12 deletions

@ -2,7 +2,7 @@ package rate5
import (
"sync/atomic"
"sync"
"github.com/patrickmn/go-cache"
)
@ -44,6 +44,8 @@ type Limiter struct {
locker uint32
count atomic.Value
known map[interface{}]rated
dmu *sync.RWMutex
}
// Policy defines the mechanics of our ratelimiter.

@ -2,6 +2,7 @@ package rate5
import (
"fmt"
"sync"
"sync/atomic"
"time"
@ -54,18 +55,27 @@ func newLimiter(policy Policy) *Limiter {
q.Ruleset = policy
q.Patrons = cache.New(time.Duration(q.Ruleset.Window)*time.Second, 5*time.Second)
q.known = make(map[interface{}]rated)
q.Debug = false
q.dmu = &sync.RWMutex{}
q.SetDebug(false)
return q
}
func (q *Limiter) SetDebug(on bool) {
q.dmu.Lock()
q.Debug = on
q.dmu.Unlock()
}
// DebugChannel enables Debug mode and returns a channel where debug messages are sent (NOTE: You must read from this channel if created via this function or it will block)
func (q *Limiter) DebugChannel() chan string {
q.dmu.Lock()
q.Patrons.OnEvicted(func(src string, count interface{}) {
q.debugPrint("ratelimit (expired): ", src, " ", count)
})
q.Debug = true
debugChannel = make(chan string, 20)
q.dmu.Unlock()
return debugChannel
}
@ -97,7 +107,7 @@ func (q *Limiter) strictLogic(src string, count int) {
q.known[src].inc()
extwindow := q.Ruleset.Window + q.known[src].seen.Load().(int)
if err := q.Patrons.Replace(src, count, time.Duration(extwindow)*time.Second); err != nil {
q.debugPrint("Rate5: " + err.Error())
q.debugPrint("ratelimit: " + err.Error())
}
q.debugPrint("ratelimit (strictly limited): ", count, " ", src)
q.increment()
@ -111,7 +121,7 @@ func (q *Limiter) Check(from Identity) bool {
if count, err = q.Patrons.IncrementInt(src, 1); err != nil {
q.debugPrint("ratelimit (new): ", src)
if err := q.Patrons.Add(src, 1, time.Duration(q.Ruleset.Window)*time.Second); err != nil {
q.debugPrint("Rate5: " + err.Error())
q.debugPrint("ratelimit: " + err.Error())
}
return false
}
@ -155,7 +165,11 @@ func (q *Limiter) GetGrandTotalRated() int {
}
func (q *Limiter) debugPrint(a ...interface{}) {
q.dmu.RLock()
if q.Debug {
q.dmu.RUnlock()
debugChannel <- fmt.Sprint(a...)
return
}
q.dmu.RUnlock()
}

@ -174,7 +174,9 @@ func Test_ConcurrentSafetyTest(t *testing.T) {
usedkeys := make(map[string]interface{})
for n := 0; n != 5000; n++ {
const jobs = 1000
for n := 0; n != jobs ; n++ {
randos[n] = new(randomPatron)
ok := true
for ok {
@ -188,17 +190,17 @@ func Test_ConcurrentSafetyTest(t *testing.T) {
t.Logf("generated %d Patrons with unique keys", len(randos))
doneChan := make(chan bool)
finChan := make(chan bool)
doneChan := make(chan bool, 10)
finChan := make(chan bool, 10)
var finished = 0
for _, rp := range randos {
for n := 0; n != 5; n++ {
go func() {
limiter.Check(rp)
limiter.Peek(rp)
go func(randomp *randomPatron) {
limiter.Check(randomp)
limiter.Peek(randomp)
finChan <- true
}()
}(rp)
}
}
@ -208,7 +210,7 @@ func Test_ConcurrentSafetyTest(t *testing.T) {
case <-finChan:
finished++
default:
if finished == 25000 {
if finished == (jobs * 5) {
done = true
break
}