mirror of https://github.com/yunginnanet/Rate5
Merge pull request #2 from yunginnanet/overhaul
This commit is contained in:
commit
78c11c9a7a
|
@ -1,376 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
rate5 "github.com/yunginnanet/Rate5"
|
||||
)
|
||||
|
||||
// characters used for registration IDs
|
||||
const charset = "abcdefghijklmnopqrstuvwxyz1234567890"
|
||||
|
||||
var (
|
||||
// Rater is our connection ratelimiter using default limiter settings.
|
||||
Rater *rate5.Limiter
|
||||
// RegRater will only allow one registration per 50 seconds and will add to the wait each time you get limited. (by IP)
|
||||
RegRater *rate5.Limiter
|
||||
// CmdRater will slow down commands sent, if not logged in by IP, if logged in by ID.
|
||||
CmdRater *rate5.Limiter
|
||||
|
||||
srv *Server
|
||||
keySize = 8
|
||||
)
|
||||
|
||||
// Server is an instance of our concurrent TCP server including a map of active clients
|
||||
type Server struct {
|
||||
Map map[string]*Client
|
||||
AuthLog map[string][]Login
|
||||
Exempt map[string]bool
|
||||
mu *sync.RWMutex
|
||||
}
|
||||
|
||||
// Login represents a successful login by a user
|
||||
type Login struct {
|
||||
IP string
|
||||
Time time.Time
|
||||
}
|
||||
|
||||
// Client represents a known patron of our Server
|
||||
type Client struct {
|
||||
ID string
|
||||
Conn net.Conn
|
||||
|
||||
loggedin bool
|
||||
connected bool
|
||||
authlog []Login
|
||||
|
||||
deadline time.Duration
|
||||
read *bufio.Reader
|
||||
}
|
||||
|
||||
// UniqueKey is an implementation of our Identity interface, in short: Rate5 doesn't care where you derive the string used for ratelimiting
|
||||
func (c Client) UniqueKey() string {
|
||||
var err error
|
||||
var host string
|
||||
if c.loggedin {
|
||||
return c.ID
|
||||
}
|
||||
if host, _, err = net.SplitHostPort(c.Conn.RemoteAddr().String()); err == nil {
|
||||
return host
|
||||
}
|
||||
panic(err)
|
||||
}
|
||||
|
||||
func argParse() {
|
||||
if len(os.Args) < 1 {
|
||||
return
|
||||
}
|
||||
for i, arg := range os.Args {
|
||||
switch arg {
|
||||
case "-e":
|
||||
fallthrough
|
||||
case "--exempt":
|
||||
if len(os.Args) <= i+1 {
|
||||
return
|
||||
}
|
||||
srv.Exempt[os.Args[i+1]] = true
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
// Rater is our connection ratelimiter
|
||||
Rater = rate5.NewDefaultLimiter()
|
||||
// RegRater will only allow one registration per 50 seconds and will add to the wait each time you get limited
|
||||
RegRater = rate5.NewStrictLimiter(50, 1)
|
||||
// CmdRater will slow down commands send when connected
|
||||
CmdRater = rate5.NewLimiter(10, 20)
|
||||
|
||||
srv = &Server{
|
||||
Map: make(map[string]*Client),
|
||||
AuthLog: make(map[string][]Login),
|
||||
Exempt: make(map[string]bool),
|
||||
|
||||
mu: &sync.RWMutex{},
|
||||
}
|
||||
|
||||
argParse()
|
||||
|
||||
rd := Rater.DebugChannel()
|
||||
rrd := RegRater.DebugChannel()
|
||||
crd := CmdRater.DebugChannel()
|
||||
go watchDebug(rd, rrd, crd)
|
||||
}
|
||||
|
||||
func watchDebug(rd, rrd, crd chan string) {
|
||||
pre := "[Rate5] "
|
||||
var lastcount = 0
|
||||
var count = 0
|
||||
for {
|
||||
select {
|
||||
case msg := <-rd:
|
||||
fmt.Printf("%s Limit: %s \n", pre, msg)
|
||||
count++
|
||||
case msg := <-rrd:
|
||||
fmt.Printf("%s RegLimit: %s \n", pre, msg)
|
||||
count++
|
||||
case msg := <-crd:
|
||||
fmt.Printf("%s CmdLimit: %s \n", pre, msg)
|
||||
count++
|
||||
default:
|
||||
if count-lastcount >= 25 {
|
||||
lastcount = count
|
||||
fmt.Println("Rater: ", Rater.GetGrandTotalRated())
|
||||
fmt.Println("RegRater: ", RegRater.GetGrandTotalRated())
|
||||
fmt.Println("CmdRater: ", CmdRater.GetGrandTotalRated())
|
||||
}
|
||||
time.Sleep(time.Duration(10) * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) preLogin(c *Client) {
|
||||
c.send("Auth: ")
|
||||
in := c.recv()
|
||||
switch {
|
||||
case s.authCheck(c, in):
|
||||
c.loggedin = true
|
||||
c.deadline = time.Duration(480) * time.Second
|
||||
c.send("successful login")
|
||||
return
|
||||
case in == "register":
|
||||
// no exemption for strict ratelimiter (rate5 testing)
|
||||
if RegRater.Check(c) {
|
||||
c.send("you already registered recently\n")
|
||||
return
|
||||
}
|
||||
println("new registration from " + c.UniqueKey())
|
||||
s.setID(c, s.getUnusedID())
|
||||
c.send("\nregistration success\n[New ID]: " + c.ID)
|
||||
return
|
||||
default:
|
||||
c.send("invalid. type 'REGISTER' to register a new ID\n")
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (s *Server) mainPrompt(c *Client) {
|
||||
for c.connected {
|
||||
c.send("\nRate5 > ")
|
||||
switch c.recv() {
|
||||
case "history":
|
||||
c.send("account logins:\n")
|
||||
for _, login := range s.AuthLog[c.ID] {
|
||||
c.send(login.Time.Format("Mon, 02 Jan 2006 15:04:05 MST") + ": " + login.IP + "\n")
|
||||
}
|
||||
case "help":
|
||||
c.send("history, whoami, logout\n")
|
||||
case "whoami":
|
||||
c.send(c.ID + "\n")
|
||||
case "quit":
|
||||
fallthrough
|
||||
case "exit":
|
||||
case "logout":
|
||||
c.loggedin = false
|
||||
return
|
||||
default:
|
||||
c.send("unknown command, are you lost?")
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isExempt(c *Client) bool {
|
||||
srv.mu.RLock()
|
||||
_, exempt := srv.Exempt[c.UniqueKey()]
|
||||
srv.mu.RUnlock()
|
||||
return exempt
|
||||
}
|
||||
|
||||
func connRateCheck(c *Client) bool {
|
||||
if isExempt(c) {
|
||||
return false
|
||||
}
|
||||
if Rater.Check(c) {
|
||||
c.send("too many connections")
|
||||
println(c.UniqueKey() + " ratelimited")
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func closeConn(c *Client) {
|
||||
if err := c.Conn.Close(); err != nil {
|
||||
println(err.Error())
|
||||
}
|
||||
println("closed: " + c.Conn.RemoteAddr().String())
|
||||
}
|
||||
|
||||
func (s *Server) handleTCP(c *Client) {
|
||||
if err := c.Conn.(*net.TCPConn).SetLinger(0); err != nil {
|
||||
fmt.Println("error while setting setlinger:", err.Error())
|
||||
}
|
||||
defer closeConn(c)
|
||||
if rated := connRateCheck(c); rated {
|
||||
return
|
||||
}
|
||||
c.read = bufio.NewReader(c.Conn)
|
||||
if _, err := c.Conn.Write(loginBanner()); err != nil {
|
||||
return
|
||||
}
|
||||
for !c.loggedin {
|
||||
if !c.connected {
|
||||
return
|
||||
}
|
||||
s.preLogin(c)
|
||||
}
|
||||
s.mainPrompt(c)
|
||||
}
|
||||
|
||||
func (c *Client) send(data string) {
|
||||
if err := c.Conn.SetReadDeadline(time.Now().Add(c.deadline)); err != nil {
|
||||
fmt.Println("error while setting deadline:", err.Error())
|
||||
}
|
||||
if _, err := c.Conn.Write([]byte(data)); err != nil {
|
||||
c.connected = false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) recv() string {
|
||||
if err := c.Conn.SetReadDeadline(time.Now().Add(c.deadline)); err != nil {
|
||||
fmt.Println("error while setting deadline:", err.Error())
|
||||
}
|
||||
|
||||
if !isExempt(c) {
|
||||
if CmdRater.Check(c) {
|
||||
if !c.loggedin {
|
||||
// if they hit the ratelimiter during log-in, disconnect them
|
||||
c.connected = false
|
||||
}
|
||||
time.Sleep(time.Duration(1250) * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
in, err := c.read.ReadString('\n')
|
||||
if err != nil {
|
||||
println(c.UniqueKey() + ": " + err.Error())
|
||||
c.connected = false
|
||||
return in
|
||||
}
|
||||
c.read.Reset(c.Conn)
|
||||
return strings.ToLower(strings.TrimRight(in, "\n"))
|
||||
}
|
||||
|
||||
func randUint32() uint32 {
|
||||
b := make([]byte, 4096)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
|
||||
panic(err)
|
||||
}
|
||||
return binary.BigEndian.Uint32(b)
|
||||
}
|
||||
|
||||
func keygen() string {
|
||||
chrlen := len(charset)
|
||||
b := make([]byte, keySize)
|
||||
for i := 0; i != keySize; i++ {
|
||||
b[i] = charset[randUint32()%uint32(chrlen)]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// getUnusedKey assures that our newly generated ID is not in use
|
||||
func (s *Server) getUnusedID() string {
|
||||
s.mu.RLock()
|
||||
var newkey string
|
||||
for {
|
||||
newkey = keygen()
|
||||
if _, ok := s.Map[newkey]; !ok {
|
||||
break
|
||||
} else {
|
||||
println("key already exists! generating new...")
|
||||
}
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
return newkey
|
||||
}
|
||||
|
||||
// setID sets the clients ID safely
|
||||
func (s *Server) setID(c *Client, id string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
c.ID = id
|
||||
s.Map[id] = c
|
||||
}
|
||||
|
||||
func (s *Server) replaceSession(c *Client, id string) {
|
||||
s.mu.Lock()
|
||||
|
||||
s.AuthLog[id] = append(s.AuthLog[id], Login{
|
||||
// we're not logged in so UniqueKey is still the IP address
|
||||
IP: c.UniqueKey(),
|
||||
Time: time.Now(),
|
||||
})
|
||||
defer s.mu.Unlock()
|
||||
delete(s.Map, id)
|
||||
s.Map[id] = c
|
||||
c.ID = id
|
||||
}
|
||||
|
||||
func (s *Server) authCheck(c *Client, id string) bool {
|
||||
s.mu.RLock()
|
||||
if old, ok := s.Map[id]; ok {
|
||||
s.mu.RUnlock()
|
||||
old.connected = false
|
||||
closeConn(old)
|
||||
s.replaceSession(c, id)
|
||||
return true
|
||||
}
|
||||
|
||||
s.mu.RUnlock()
|
||||
return false
|
||||
}
|
||||
|
||||
func loginBanner() []byte {
|
||||
var data []byte
|
||||
var err error
|
||||
login := "CnwgG1s5MDs0MG1SG1swbRtbMG0gG1s5Nzs0MG3DhhtbMG0bWzBtIBtbOTc7NDBtzpMbWzBtG1swbSAbWzk3OzQwbc6jG1swbRtbMG0gG1swbRtbOTc7MzJtNRtbMG0bWzBtIHwKCg=="
|
||||
if data, err = base64.StdEncoding.DecodeString(login); err == nil {
|
||||
return data
|
||||
}
|
||||
panic(err)
|
||||
}
|
||||
|
||||
func main() {
|
||||
l, err := net.Listen("tcp", "127.0.0.1:4444")
|
||||
if err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
println("listening...")
|
||||
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
println(err.Error())
|
||||
}
|
||||
go srv.handleTCP(&Client{
|
||||
Conn: conn,
|
||||
connected: true,
|
||||
loggedin: false,
|
||||
deadline: time.Duration(12) * time.Second,
|
||||
})
|
||||
}
|
||||
}
|
2
go.mod
2
go.mod
|
@ -1,5 +1,5 @@
|
|||
module github.com/yunginnanet/Rate5
|
||||
|
||||
go 1.17
|
||||
go 1.18
|
||||
|
||||
require github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||
|
|
25
models.go
25
models.go
|
@ -1,8 +1,8 @@
|
|||
package rate5
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"sync"
|
||||
|
||||
"github.com/patrickmn/go-cache"
|
||||
)
|
||||
|
||||
|
@ -13,11 +13,6 @@ const (
|
|||
DefaultBurst = 25
|
||||
)
|
||||
|
||||
const (
|
||||
stateUnlocked uint32 = iota
|
||||
stateLocked
|
||||
)
|
||||
|
||||
var debugChannel chan string
|
||||
|
||||
// Identity is an interface that allows any arbitrary type to be used for a unique key in ratelimit checks when implemented.
|
||||
|
@ -25,11 +20,6 @@ type Identity interface {
|
|||
UniqueKey() string
|
||||
}
|
||||
|
||||
type rated struct {
|
||||
seen *atomic.Value
|
||||
locker uint32
|
||||
}
|
||||
|
||||
// Limiter implements an Enforcer to create an arbitrary ratelimiter.
|
||||
type Limiter struct {
|
||||
Source Identity
|
||||
|
@ -37,23 +27,22 @@ type Limiter struct {
|
|||
Patrons *cache.Cache
|
||||
// Ruleset is the actual ratelimitting model.
|
||||
Ruleset Policy
|
||||
/* Debug mode (toggled here) enables debug messages
|
||||
/* debug mode (toggled here) enables debug messages
|
||||
delivered through a channel. See: DebugChannel() */
|
||||
Debug bool
|
||||
debug bool
|
||||
|
||||
locker uint32
|
||||
count atomic.Value
|
||||
known map[interface{}]rated
|
||||
known map[interface{}]*int64
|
||||
|
||||
dmu *sync.RWMutex
|
||||
*sync.RWMutex
|
||||
}
|
||||
|
||||
// Policy defines the mechanics of our ratelimiter.
|
||||
type Policy struct {
|
||||
// Window defines the duration in seconds that we should keep track of ratelimit triggers,
|
||||
Window int
|
||||
Window int64
|
||||
// Burst is the amount of times that Check will not trigger a limit within the duration defined by Window.
|
||||
Burst int
|
||||
Burst int64
|
||||
// Strict mode punishes triggers of the ratelimitby increasing the amount of time they have to wait every time they trigger the limitter.
|
||||
Strict bool
|
||||
}
|
||||
|
|
131
ratelimiter.go
131
ratelimiter.go
|
@ -2,6 +2,7 @@ package rate5
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
@ -26,8 +27,8 @@ func NewCustomLimiter(policy Policy) *Limiter {
|
|||
// NewLimiter returns a custom limiter witout Strict mode
|
||||
func NewLimiter(window int, burst int) *Limiter {
|
||||
return newLimiter(Policy{
|
||||
Window: window,
|
||||
Burst: burst,
|
||||
Window: int64(window),
|
||||
Burst: int64(burst),
|
||||
Strict: false,
|
||||
})
|
||||
}
|
||||
|
@ -44,92 +45,90 @@ func NewDefaultStrictLimiter() *Limiter {
|
|||
// NewStrictLimiter returns a custom limiter with Strict mode.
|
||||
func NewStrictLimiter(window int, burst int) *Limiter {
|
||||
return newLimiter(Policy{
|
||||
Window: window,
|
||||
Burst: burst,
|
||||
Window: int64(window),
|
||||
Burst: int64(burst),
|
||||
Strict: true,
|
||||
})
|
||||
}
|
||||
|
||||
func newLimiter(policy Policy) *Limiter {
|
||||
q := new(Limiter)
|
||||
q.Ruleset = policy
|
||||
q.Patrons = cache.New(time.Duration(q.Ruleset.Window)*time.Second, 5*time.Second)
|
||||
q.known = make(map[interface{}]rated)
|
||||
q.dmu = &sync.RWMutex{}
|
||||
q.SetDebug(false)
|
||||
|
||||
return q
|
||||
return &Limiter{
|
||||
Ruleset: policy,
|
||||
Patrons: cache.New(time.Duration(policy.Window)*time.Second, 5*time.Second),
|
||||
known: make(map[interface{}]*int64),
|
||||
RWMutex: &sync.RWMutex{},
|
||||
debug: false,
|
||||
}
|
||||
}
|
||||
|
||||
func (q *Limiter) SetDebug(on bool) {
|
||||
q.dmu.Lock()
|
||||
q.Debug = on
|
||||
q.dmu.Unlock()
|
||||
q.Lock()
|
||||
q.debug = on
|
||||
q.Unlock()
|
||||
}
|
||||
|
||||
// DebugChannel enables Debug mode and returns a channel where debug messages are sent (NOTE: You must read from this channel if created via this function or it will block)
|
||||
// DebugChannel enables debug mode and returns a channel where debug messages are sent.
|
||||
// NOTE: You must read from this channel if created via this function or it will block
|
||||
func (q *Limiter) DebugChannel() chan string {
|
||||
q.dmu.Lock()
|
||||
q.Lock()
|
||||
q.Patrons.OnEvicted(func(src string, count interface{}) {
|
||||
q.debugPrint("ratelimit (expired): ", src, " ", count)
|
||||
})
|
||||
q.Debug = true
|
||||
q.debug = true
|
||||
debugChannel = make(chan string, 20)
|
||||
q.dmu.Unlock()
|
||||
q.Unlock()
|
||||
return debugChannel
|
||||
}
|
||||
|
||||
func (s rated) inc() {
|
||||
for !atomic.CompareAndSwapUint32(&s.locker, stateUnlocked, stateLocked) {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
defer atomic.StoreUint32(&s.locker, stateUnlocked)
|
||||
|
||||
if s.seen.Load() == nil {
|
||||
s.seen.Store(1)
|
||||
return
|
||||
}
|
||||
s.seen.Store(s.seen.Load().(int) + 1)
|
||||
func intPtr(i int64) *int64 {
|
||||
return &i
|
||||
}
|
||||
|
||||
func (q *Limiter) strictLogic(src string, count int) {
|
||||
for !atomic.CompareAndSwapUint32(&q.locker, stateUnlocked, stateLocked) {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
func (q *Limiter) getHitsPtr(src string) *int64 {
|
||||
q.RLock()
|
||||
defer q.RUnlock()
|
||||
if _, ok := q.known[src]; ok {
|
||||
return q.known[src]
|
||||
}
|
||||
defer atomic.StoreUint32(&q.locker, stateUnlocked)
|
||||
q.RUnlock()
|
||||
q.Lock()
|
||||
q.known[src] = intPtr(0)
|
||||
q.Unlock()
|
||||
q.RLock()
|
||||
return q.known[src]
|
||||
}
|
||||
|
||||
if _, ok := q.known[src]; !ok {
|
||||
q.known[src] = rated{
|
||||
seen: &atomic.Value{},
|
||||
locker: stateUnlocked,
|
||||
}
|
||||
}
|
||||
q.known[src].inc()
|
||||
extwindow := q.Ruleset.Window + q.known[src].seen.Load().(int)
|
||||
func (q *Limiter) strictLogic(src string, count int64) {
|
||||
knownHits := q.getHitsPtr(src)
|
||||
atomic.AddInt64(knownHits, 1)
|
||||
extwindow := q.Ruleset.Window + atomic.LoadInt64(knownHits)
|
||||
if err := q.Patrons.Replace(src, count, time.Duration(extwindow)*time.Second); err != nil {
|
||||
q.debugPrint("ratelimit: " + err.Error())
|
||||
q.debugPrint("ratelimit (strict) error: " + err.Error())
|
||||
}
|
||||
q.debugPrint("ratelimit (strictly limited): ", count, " ", src)
|
||||
q.increment()
|
||||
q.debugPrint("ratelimit (strict) limited: ", count, " ", src)
|
||||
}
|
||||
|
||||
// Check checks and increments an Identities UniqueKey() output against a list of cached strings to determine and raise it's ratelimitting status.
|
||||
func (q *Limiter) Check(from Identity) bool {
|
||||
var count int
|
||||
func (q *Limiter) Check(from Identity) (limited bool) {
|
||||
var count int64
|
||||
var err error
|
||||
src := from.UniqueKey()
|
||||
if count, err = q.Patrons.IncrementInt(src, 1); err != nil {
|
||||
q.debugPrint("ratelimit (new): ", src)
|
||||
if err := q.Patrons.Add(src, 1, time.Duration(q.Ruleset.Window)*time.Second); err != nil {
|
||||
q.debugPrint("ratelimit: " + err.Error())
|
||||
count, err = q.Patrons.IncrementInt64(src, 1)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
q.debugPrint("ratelimit (new): ", src)
|
||||
if cacheErr := q.Patrons.Add(src, int64(1), time.Duration(q.Ruleset.Window)*time.Second); cacheErr != nil {
|
||||
q.debugPrint("ratelimit error: " + cacheErr.Error())
|
||||
}
|
||||
return false
|
||||
}
|
||||
return false
|
||||
q.debugPrint("ratelimit error: " + err.Error())
|
||||
return true
|
||||
}
|
||||
if count < q.Ruleset.Burst {
|
||||
return false
|
||||
}
|
||||
if !q.Ruleset.Strict {
|
||||
q.increment()
|
||||
q.debugPrint("ratelimit (limited): ", count, " ", src)
|
||||
return true
|
||||
}
|
||||
|
@ -140,7 +139,7 @@ func (q *Limiter) Check(from Identity) bool {
|
|||
// Peek checks an Identities UniqueKey() output against a list of cached strings to determine ratelimitting status without adding to its request count.
|
||||
func (q *Limiter) Peek(from Identity) bool {
|
||||
if ct, ok := q.Patrons.Get(from.UniqueKey()); ok {
|
||||
count := ct.(int)
|
||||
count := ct.(int64)
|
||||
if count > q.Ruleset.Burst {
|
||||
return true
|
||||
}
|
||||
|
@ -148,28 +147,12 @@ func (q *Limiter) Peek(from Identity) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func (q *Limiter) increment() {
|
||||
if q.count.Load() == nil {
|
||||
q.count.Store(1)
|
||||
return
|
||||
}
|
||||
q.count.Store(q.count.Load().(int) + 1)
|
||||
}
|
||||
|
||||
// GetGrandTotalRated returns the historic total amount of times we have ever reported something as ratelimited.
|
||||
func (q *Limiter) GetGrandTotalRated() int {
|
||||
if q.count.Load() == nil {
|
||||
return 0
|
||||
}
|
||||
return q.count.Load().(int)
|
||||
}
|
||||
|
||||
func (q *Limiter) debugPrint(a ...interface{}) {
|
||||
q.dmu.RLock()
|
||||
if q.Debug {
|
||||
q.dmu.RUnlock()
|
||||
q.RLock()
|
||||
if q.debug {
|
||||
q.RUnlock()
|
||||
debugChannel <- fmt.Sprint(a...)
|
||||
return
|
||||
}
|
||||
q.dmu.RUnlock()
|
||||
q.RUnlock()
|
||||
}
|
||||
|
|
|
@ -3,21 +3,16 @@ package rate5
|
|||
import (
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
dummyTicker *ticker
|
||||
testDebug = true
|
||||
|
||||
stopDebug chan bool
|
||||
stopDebug = make(chan bool)
|
||||
)
|
||||
|
||||
func init() {
|
||||
stopDebug = make(chan bool)
|
||||
}
|
||||
|
||||
type randomPatron struct {
|
||||
key string
|
||||
Identity
|
||||
|
@ -38,7 +33,7 @@ func randomUint32() uint32 {
|
|||
}
|
||||
|
||||
func (rp *randomPatron) GenerateKey() {
|
||||
var keylen = 8
|
||||
var keylen = 10
|
||||
buf := make([]byte, keylen)
|
||||
for n := 0; n != keylen; n++ {
|
||||
buf[n] = charset[randomUint32()%uint32(len(charset))]
|
||||
|
@ -49,33 +44,19 @@ func (rp *randomPatron) GenerateKey() {
|
|||
func watchDebug(r *Limiter, t *testing.T) {
|
||||
t.Logf("debug enabled")
|
||||
rd := r.DebugChannel()
|
||||
|
||||
pre := "[Rate5] "
|
||||
var lastcount = 0
|
||||
var count = 0
|
||||
for {
|
||||
select {
|
||||
case msg := <-rd:
|
||||
t.Logf("%s Limit: %s \n", pre, msg)
|
||||
case <-stopDebug:
|
||||
return
|
||||
default:
|
||||
count++
|
||||
if count-lastcount >= 10 {
|
||||
lastcount = count
|
||||
t.Logf("Times limited: %d", r.GetGrandTotalRated())
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type ticker struct{}
|
||||
|
||||
func init() {
|
||||
dummyTicker = &ticker{}
|
||||
}
|
||||
|
||||
func (tick *ticker) UniqueKey() string {
|
||||
return "tick"
|
||||
}
|
||||
|
@ -87,11 +68,8 @@ func Test_NewCustomLimiter(t *testing.T) {
|
|||
Strict: false,
|
||||
})
|
||||
|
||||
//goland:noinspection GoBoolExpressions
|
||||
if testDebug {
|
||||
go watchDebug(limiter, t)
|
||||
time.Sleep(25 * time.Millisecond)
|
||||
}
|
||||
go watchDebug(limiter, t)
|
||||
time.Sleep(25 * time.Millisecond)
|
||||
|
||||
for n := 0; n < 9; n++ {
|
||||
limiter.Check(dummyTicker)
|
||||
|
@ -111,11 +89,6 @@ func Test_NewCustomLimiter(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
//goland:noinspection GoBoolExpressions
|
||||
if testDebug {
|
||||
t.Logf("[Finished NewCustomLimiter] Times ratelimited: %d", limiter.GetGrandTotalRated())
|
||||
}
|
||||
|
||||
stopDebug <- true
|
||||
limiter = nil
|
||||
}
|
||||
|
@ -125,11 +98,8 @@ func Test_NewDefaultStrictLimiter(t *testing.T) {
|
|||
// DefaultWindow = 5
|
||||
limiter := NewDefaultStrictLimiter()
|
||||
|
||||
//goland:noinspection GoBoolExpressions
|
||||
if testDebug {
|
||||
go watchDebug(limiter, t)
|
||||
time.Sleep(25 * time.Millisecond)
|
||||
}
|
||||
go watchDebug(limiter, t)
|
||||
time.Sleep(25 * time.Millisecond)
|
||||
|
||||
for n := 0; n < 24; n++ {
|
||||
limiter.Check(dummyTicker)
|
||||
|
@ -150,33 +120,27 @@ func Test_NewDefaultStrictLimiter(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
//goland:noinspection GoBoolExpressions
|
||||
if testDebug {
|
||||
t.Logf("[Finished NewCustomLimiter] Times ratelimited: %d", limiter.GetGrandTotalRated())
|
||||
}
|
||||
|
||||
stopDebug <- true
|
||||
limiter = nil
|
||||
}
|
||||
|
||||
// This test is only here for safety, if the package is not safe, this will often panic.
|
||||
// We give this a healthy amount of padding in terms of our checks as this is far beyond the tolerances we expect during runtime.
|
||||
// At the end of the day, not panicing here is passing.
|
||||
func Test_ConcurrentSafetyTest(t *testing.T) {
|
||||
func concurrentTest(t *testing.T, jobs int, iterCount int, burst int64, shouldLimit bool) { //nolint:funlen
|
||||
var randos map[int]*randomPatron
|
||||
randos = make(map[int]*randomPatron)
|
||||
|
||||
limiter := NewCustomLimiter(Policy{
|
||||
Window: 240,
|
||||
Burst: 5000,
|
||||
Burst: burst,
|
||||
Strict: true,
|
||||
})
|
||||
|
||||
limitNotice := sync.Once{}
|
||||
|
||||
limiter.SetDebug(false)
|
||||
|
||||
usedkeys := make(map[string]interface{})
|
||||
|
||||
const jobs = 1000
|
||||
|
||||
for n := 0; n != jobs ; n++ {
|
||||
for n := 0; n != jobs; n++ {
|
||||
randos[n] = new(randomPatron)
|
||||
ok := true
|
||||
for ok {
|
||||
|
@ -188,56 +152,59 @@ func Test_ConcurrentSafetyTest(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
t.Logf("generated %d Patrons with unique keys", len(randos))
|
||||
t.Logf("generated %d Patrons with unique keys, running Check() with them %d times concurrently with a burst limit of %d...",
|
||||
len(randos), iterCount, burst)
|
||||
|
||||
doneChan := make(chan bool, 10)
|
||||
finChan := make(chan bool, 10)
|
||||
var finished = 0
|
||||
|
||||
for _, rp := range randos {
|
||||
for n := 0; n != 5; n++ {
|
||||
for n := 0; n != iterCount; n++ {
|
||||
go func(randomp *randomPatron) {
|
||||
limiter.Check(randomp)
|
||||
limiter.Peek(randomp)
|
||||
if limiter.Peek(randomp) {
|
||||
limitNotice.Do(func() {
|
||||
t.Logf("(sync.Once) %s limited", randomp.UniqueKey())
|
||||
})
|
||||
}
|
||||
finChan <- true
|
||||
}(rp)
|
||||
}
|
||||
}
|
||||
|
||||
var done = false
|
||||
testloop:
|
||||
for {
|
||||
select {
|
||||
case <-debugChannel:
|
||||
t.Logf("[debug] %s", <-debugChannel)
|
||||
case <-finChan:
|
||||
finished++
|
||||
default:
|
||||
if finished == (jobs * 5) {
|
||||
done = true
|
||||
break
|
||||
if finished >= (jobs * iterCount) {
|
||||
break testloop
|
||||
}
|
||||
}
|
||||
if done {
|
||||
go func() {
|
||||
doneChan <- true
|
||||
}()
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
<-doneChan
|
||||
println("done")
|
||||
|
||||
for _, rp := range randos {
|
||||
if limiter.Peek(rp) {
|
||||
if limiter.Peek(rp) && !shouldLimit {
|
||||
if ct, ok := limiter.Patrons.Get(rp.UniqueKey()); ok {
|
||||
t.Logf("WARN: Should not have been limited. Ratelimiter count: %d, policy: %d", ct, limiter.Ruleset.Burst)
|
||||
t.Errorf("WARN: Should not have been limited. Ratelimiter count: %d, policy: %d", ct, limiter.Ruleset.Burst)
|
||||
} else {
|
||||
t.Errorf("randomPatron does not exist in ratelimiter at all!")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//goland:noinspection GoBoolExpressions
|
||||
if testDebug {
|
||||
t.Logf("[Finished StrictConcurrentStressTest] Times ratelimited: %d", limiter.GetGrandTotalRated())
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ConcurrentShouldNotLimit(t *testing.T) {
|
||||
concurrentTest(t, 100, 20, 20, false)
|
||||
concurrentTest(t, 100, 50, 50, false)
|
||||
}
|
||||
|
||||
func Test_ConcurrentShouldLimit(t *testing.T) {
|
||||
concurrentTest(t, 100, 21, 20, true)
|
||||
concurrentTest(t, 100, 51, 50, true)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue