Merge pull request #2 from yunginnanet/overhaul

kayos 2022-07-06 04:36:02 -07:00 committed by GitHub
commit 78c11c9a7a
5 changed files with 104 additions and 541 deletions


@ -1,376 +0,0 @@
package main
import (
"bufio"
"crypto/rand"
"encoding/base64"
"encoding/binary"
"fmt"
"net"
"os"
"strings"
"sync"
"time"
rate5 "github.com/yunginnanet/Rate5"
)
// characters used for registration IDs
const charset = "abcdefghijklmnopqrstuvwxyz1234567890"
var (
// Rater is our connection ratelimiter using default limiter settings.
Rater *rate5.Limiter
// RegRater will only allow one registration per 50 seconds and will add to the wait each time you get limited. (by IP)
RegRater *rate5.Limiter
// CmdRater will slow down commands sent; keyed by IP if not logged in, by ID once logged in.
CmdRater *rate5.Limiter
srv *Server
keySize = 8
)
// Server is an instance of our concurrent TCP server including a map of active clients
type Server struct {
Map map[string]*Client
AuthLog map[string][]Login
Exempt map[string]bool
mu *sync.RWMutex
}
// Login represents a successful login by a user
type Login struct {
IP string
Time time.Time
}
// Client represents a known patron of our Server
type Client struct {
ID string
Conn net.Conn
loggedin bool
connected bool
authlog []Login
deadline time.Duration
read *bufio.Reader
}
// UniqueKey is an implementation of our Identity interface; in short, Rate5 doesn't care where you derive the string used for ratelimiting
func (c Client) UniqueKey() string {
var err error
var host string
if c.loggedin {
return c.ID
}
if host, _, err = net.SplitHostPort(c.Conn.RemoteAddr().String()); err == nil {
return host
}
panic(err)
}
func argParse() {
if len(os.Args) < 1 {
return
}
for i, arg := range os.Args {
switch arg {
case "-e":
fallthrough
case "--exempt":
if len(os.Args) <= i+1 {
return
}
srv.Exempt[os.Args[i+1]] = true
default:
continue
}
}
}
func init() {
// Rater is our connection ratelimiter
Rater = rate5.NewDefaultLimiter()
// RegRater will only allow one registration per 50 seconds and will add to the wait each time you get limited
RegRater = rate5.NewStrictLimiter(50, 1)
// CmdRater will slow down commands sent when connected
CmdRater = rate5.NewLimiter(10, 20)
srv = &Server{
Map: make(map[string]*Client),
AuthLog: make(map[string][]Login),
Exempt: make(map[string]bool),
mu: &sync.RWMutex{},
}
argParse()
rd := Rater.DebugChannel()
rrd := RegRater.DebugChannel()
crd := CmdRater.DebugChannel()
go watchDebug(rd, rrd, crd)
}
func watchDebug(rd, rrd, crd chan string) {
pre := "[Rate5] "
var lastcount = 0
var count = 0
for {
select {
case msg := <-rd:
fmt.Printf("%s Limit: %s \n", pre, msg)
count++
case msg := <-rrd:
fmt.Printf("%s RegLimit: %s \n", pre, msg)
count++
case msg := <-crd:
fmt.Printf("%s CmdLimit: %s \n", pre, msg)
count++
default:
if count-lastcount >= 25 {
lastcount = count
fmt.Println("Rater: ", Rater.GetGrandTotalRated())
fmt.Println("RegRater: ", RegRater.GetGrandTotalRated())
fmt.Println("CmdRater: ", CmdRater.GetGrandTotalRated())
}
time.Sleep(time.Duration(10) * time.Millisecond)
}
}
}
func (s *Server) preLogin(c *Client) {
c.send("Auth: ")
in := c.recv()
switch {
case s.authCheck(c, in):
c.loggedin = true
c.deadline = time.Duration(480) * time.Second
c.send("successful login")
return
case in == "register":
// no exemption for strict ratelimiter (rate5 testing)
if RegRater.Check(c) {
c.send("you already registered recently\n")
return
}
println("new registration from " + c.UniqueKey())
s.setID(c, s.getUnusedID())
c.send("\nregistration success\n[New ID]: " + c.ID)
return
default:
c.send("invalid. type 'REGISTER' to register a new ID\n")
return
}
}
func (s *Server) mainPrompt(c *Client) {
for c.connected {
c.send("\nRate5 > ")
switch c.recv() {
case "history":
c.send("account logins:\n")
for _, login := range s.AuthLog[c.ID] {
c.send(login.Time.Format("Mon, 02 Jan 2006 15:04:05 MST") + ": " + login.IP + "\n")
}
case "help":
c.send("history, whoami, logout\n")
case "whoami":
c.send(c.ID + "\n")
case "quit":
fallthrough
case "exit":
case "logout":
c.loggedin = false
return
default:
c.send("unknown command, are you lost?")
continue
}
}
}
func isExempt(c *Client) bool {
srv.mu.RLock()
_, exempt := srv.Exempt[c.UniqueKey()]
srv.mu.RUnlock()
return exempt
}
func connRateCheck(c *Client) bool {
if isExempt(c) {
return false
}
if Rater.Check(c) {
c.send("too many connections")
println(c.UniqueKey() + " ratelimited")
return true
}
return false
}
func closeConn(c *Client) {
if err := c.Conn.Close(); err != nil {
println(err.Error())
}
println("closed: " + c.Conn.RemoteAddr().String())
}
func (s *Server) handleTCP(c *Client) {
if err := c.Conn.(*net.TCPConn).SetLinger(0); err != nil {
fmt.Println("error while setting setlinger:", err.Error())
}
defer closeConn(c)
if rated := connRateCheck(c); rated {
return
}
c.read = bufio.NewReader(c.Conn)
if _, err := c.Conn.Write(loginBanner()); err != nil {
return
}
for !c.loggedin {
if !c.connected {
return
}
s.preLogin(c)
}
s.mainPrompt(c)
}
func (c *Client) send(data string) {
if err := c.Conn.SetReadDeadline(time.Now().Add(c.deadline)); err != nil {
fmt.Println("error while setting deadline:", err.Error())
}
if _, err := c.Conn.Write([]byte(data)); err != nil {
c.connected = false
}
}
func (c *Client) recv() string {
if err := c.Conn.SetReadDeadline(time.Now().Add(c.deadline)); err != nil {
fmt.Println("error while setting deadline:", err.Error())
}
if !isExempt(c) {
if CmdRater.Check(c) {
if !c.loggedin {
// if they hit the ratelimiter during log-in, disconnect them
c.connected = false
}
time.Sleep(time.Duration(1250) * time.Millisecond)
}
}
in, err := c.read.ReadString('\n')
if err != nil {
println(c.UniqueKey() + ": " + err.Error())
c.connected = false
return in
}
c.read.Reset(c.Conn)
return strings.ToLower(strings.TrimRight(in, "\n"))
}
func randUint32() uint32 {
b := make([]byte, 4096)
if _, err := rand.Read(b); err != nil {
panic(err)
}
return binary.BigEndian.Uint32(b)
}
func keygen() string {
chrlen := len(charset)
b := make([]byte, keySize)
for i := 0; i != keySize; i++ {
b[i] = charset[randUint32()%uint32(chrlen)]
}
return string(b)
}
// getUnusedID ensures that our newly generated ID is not in use
func (s *Server) getUnusedID() string {
s.mu.RLock()
var newkey string
for {
newkey = keygen()
if _, ok := s.Map[newkey]; !ok {
break
} else {
println("key already exists! generating new...")
}
}
s.mu.RUnlock()
return newkey
}
// setID sets the client's ID safely
func (s *Server) setID(c *Client, id string) {
s.mu.Lock()
defer s.mu.Unlock()
c.ID = id
s.Map[id] = c
}
func (s *Server) replaceSession(c *Client, id string) {
s.mu.Lock()
s.AuthLog[id] = append(s.AuthLog[id], Login{
// we're not logged in so UniqueKey is still the IP address
IP: c.UniqueKey(),
Time: time.Now(),
})
defer s.mu.Unlock()
delete(s.Map, id)
s.Map[id] = c
c.ID = id
}
func (s *Server) authCheck(c *Client, id string) bool {
s.mu.RLock()
if old, ok := s.Map[id]; ok {
s.mu.RUnlock()
old.connected = false
closeConn(old)
s.replaceSession(c, id)
return true
}
s.mu.RUnlock()
return false
}
func loginBanner() []byte {
var data []byte
var err error
login := "CnwgG1s5MDs0MG1SG1swbRtbMG0gG1s5Nzs0MG3DhhtbMG0bWzBtIBtbOTc7NDBtzpMbWzBtG1swbSAbWzk3OzQwbc6jG1swbRtbMG0gG1swbRtbOTc7MzJtNRtbMG0bWzBtIHwKCg=="
if data, err = base64.StdEncoding.DecodeString(login); err == nil {
return data
}
panic(err)
}
func main() {
l, err := net.Listen("tcp", "127.0.0.1:4444")
if err != nil {
panic(err.Error())
}
println("listening...")
for {
conn, err := l.Accept()
if err != nil {
println(err.Error())
}
go srv.handleTCP(&Client{
Conn: conn,
connected: true,
loggedin: false,
deadline: time.Duration(12) * time.Second,
})
}
}
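The bundled example server above is removed by this PR. For orientation, here is a minimal sketch of the core usage it demonstrated: implement rate5.Identity on whatever wraps the connection, then gate accepts with Check(). The type name, handler, and listen address below are illustrative only, not part of this change.

package main

import (
	"net"

	rate5 "github.com/yunginnanet/Rate5"
)

// ipIdentity keys the limiter by remote IP; anything with a
// UniqueKey() string method satisfies rate5.Identity.
type ipIdentity struct{ conn net.Conn }

func (i ipIdentity) UniqueKey() string {
	host, _, err := net.SplitHostPort(i.conn.RemoteAddr().String())
	if err != nil {
		return i.conn.RemoteAddr().String()
	}
	return host
}

func main() {
	limiter := rate5.NewDefaultLimiter() // default Window/Burst policy
	l, err := net.Listen("tcp", "127.0.0.1:4444")
	if err != nil {
		panic(err)
	}
	for {
		conn, err := l.Accept()
		if err != nil {
			continue
		}
		// Check increments the hit count for this IP and reports
		// whether the source is currently limited.
		if limiter.Check(ipIdentity{conn: conn}) {
			_ = conn.Close()
			continue
		}
		go handle(conn) // application-specific handling goes here
	}
}

func handle(conn net.Conn) { _ = conn.Close() }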

go.mod

@ -1,5 +1,5 @@
module github.com/yunginnanet/Rate5
go 1.17
go 1.18
require github.com/patrickmn/go-cache v2.1.0+incompatible


@ -1,8 +1,8 @@
package rate5
import (
"sync/atomic"
"sync"
"github.com/patrickmn/go-cache"
)
@ -13,11 +13,6 @@ const (
DefaultBurst = 25
)
const (
stateUnlocked uint32 = iota
stateLocked
)
var debugChannel chan string
// Identity is an interface that allows any arbitrary type to be used for a unique key in ratelimit checks when implemented.
@ -25,11 +20,6 @@ type Identity interface {
UniqueKey() string
}
type rated struct {
seen *atomic.Value
locker uint32
}
// Limiter implements an Enforcer to create an arbitrary ratelimiter.
type Limiter struct {
Source Identity
@ -37,23 +27,22 @@ type Limiter struct {
Patrons *cache.Cache
// Ruleset is the actual ratelimiting model.
Ruleset Policy
/* Debug mode (toggled here) enables debug messages
/* debug mode (toggled here) enables debug messages
delivered through a channel. See: DebugChannel() */
Debug bool
debug bool
locker uint32
count atomic.Value
known map[interface{}]rated
known map[interface{}]*int64
dmu *sync.RWMutex
*sync.RWMutex
}
// Policy defines the mechanics of our ratelimiter.
type Policy struct {
// Window defines the duration in seconds for which we keep track of ratelimit triggers.
Window int
Window int64
// Burst is the amount of times that Check will not trigger a limit within the duration defined by Window.
Burst int
Burst int64
// Strict mode punishes triggers of the ratelimit by increasing the amount of time they have to wait every time they trigger the limiter.
Strict bool
}
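As a usage sketch for the Policy fields above (Window and Burst are now int64): building a limiter from a custom Policy via NewCustomLimiter might look like the following. The key and the numbers are arbitrary, chosen only for illustration.

package main

import (
	"fmt"

	rate5 "github.com/yunginnanet/Rate5"
)

// staticID is a throwaway Identity used only for this illustration.
type staticID string

func (s staticID) UniqueKey() string { return string(s) }

func main() {
	// Track hits for 30 seconds per key and allow 10 before limiting;
	// Strict extends the window each time a limited key keeps hitting.
	limiter := rate5.NewCustomLimiter(rate5.Policy{
		Window: 30,
		Burst:  10,
		Strict: true,
	})
	for n := 1; n <= 12; n++ {
		// Check returns true once the hit count reaches Burst.
		fmt.Printf("hit %d limited: %v\n", n, limiter.Check(staticID("203.0.113.9")))
	}
}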


@ -2,6 +2,7 @@ package rate5
import (
"fmt"
"strings"
"sync"
"sync/atomic"
"time"
@ -26,8 +27,8 @@ func NewCustomLimiter(policy Policy) *Limiter {
// NewLimiter returns a custom limiter without Strict mode
func NewLimiter(window int, burst int) *Limiter {
return newLimiter(Policy{
Window: window,
Burst: burst,
Window: int64(window),
Burst: int64(burst),
Strict: false,
})
}
@ -44,92 +45,90 @@ func NewDefaultStrictLimiter() *Limiter {
// NewStrictLimiter returns a custom limiter with Strict mode.
func NewStrictLimiter(window int, burst int) *Limiter {
return newLimiter(Policy{
Window: window,
Burst: burst,
Window: int64(window),
Burst: int64(burst),
Strict: true,
})
}
func newLimiter(policy Policy) *Limiter {
q := new(Limiter)
q.Ruleset = policy
q.Patrons = cache.New(time.Duration(q.Ruleset.Window)*time.Second, 5*time.Second)
q.known = make(map[interface{}]rated)
q.dmu = &sync.RWMutex{}
q.SetDebug(false)
return q
return &Limiter{
Ruleset: policy,
Patrons: cache.New(time.Duration(policy.Window)*time.Second, 5*time.Second),
known: make(map[interface{}]*int64),
RWMutex: &sync.RWMutex{},
debug: false,
}
}
func (q *Limiter) SetDebug(on bool) {
q.dmu.Lock()
q.Debug = on
q.dmu.Unlock()
q.Lock()
q.debug = on
q.Unlock()
}
// DebugChannel enables Debug mode and returns a channel where debug messages are sent (NOTE: You must read from this channel if created via this function or it will block)
// DebugChannel enables debug mode and returns a channel where debug messages are sent.
// NOTE: You must read from this channel if created via this function or it will block
func (q *Limiter) DebugChannel() chan string {
q.dmu.Lock()
q.Lock()
q.Patrons.OnEvicted(func(src string, count interface{}) {
q.debugPrint("ratelimit (expired): ", src, " ", count)
})
q.Debug = true
q.debug = true
debugChannel = make(chan string, 20)
q.dmu.Unlock()
q.Unlock()
return debugChannel
}
func (s rated) inc() {
for !atomic.CompareAndSwapUint32(&s.locker, stateUnlocked, stateLocked) {
time.Sleep(10 * time.Millisecond)
}
defer atomic.StoreUint32(&s.locker, stateUnlocked)
if s.seen.Load() == nil {
s.seen.Store(1)
return
}
s.seen.Store(s.seen.Load().(int) + 1)
func intPtr(i int64) *int64 {
return &i
}
func (q *Limiter) strictLogic(src string, count int) {
for !atomic.CompareAndSwapUint32(&q.locker, stateUnlocked, stateLocked) {
time.Sleep(10 * time.Millisecond)
func (q *Limiter) getHitsPtr(src string) *int64 {
q.RLock()
defer q.RUnlock()
if _, ok := q.known[src]; ok {
return q.known[src]
}
defer atomic.StoreUint32(&q.locker, stateUnlocked)
q.RUnlock()
q.Lock()
q.known[src] = intPtr(0)
q.Unlock()
q.RLock()
return q.known[src]
}
if _, ok := q.known[src]; !ok {
q.known[src] = rated{
seen: &atomic.Value{},
locker: stateUnlocked,
}
}
q.known[src].inc()
extwindow := q.Ruleset.Window + q.known[src].seen.Load().(int)
func (q *Limiter) strictLogic(src string, count int64) {
knownHits := q.getHitsPtr(src)
atomic.AddInt64(knownHits, 1)
extwindow := q.Ruleset.Window + atomic.LoadInt64(knownHits)
if err := q.Patrons.Replace(src, count, time.Duration(extwindow)*time.Second); err != nil {
q.debugPrint("ratelimit: " + err.Error())
q.debugPrint("ratelimit (strict) error: " + err.Error())
}
q.debugPrint("ratelimit (strictly limited): ", count, " ", src)
q.increment()
q.debugPrint("ratelimit (strict) limited: ", count, " ", src)
}
// Check checks and increments an Identity's UniqueKey() output against a list of cached strings to determine and raise its ratelimiting status.
func (q *Limiter) Check(from Identity) bool {
var count int
func (q *Limiter) Check(from Identity) (limited bool) {
var count int64
var err error
src := from.UniqueKey()
if count, err = q.Patrons.IncrementInt(src, 1); err != nil {
q.debugPrint("ratelimit (new): ", src)
if err := q.Patrons.Add(src, 1, time.Duration(q.Ruleset.Window)*time.Second); err != nil {
q.debugPrint("ratelimit: " + err.Error())
count, err = q.Patrons.IncrementInt64(src, 1)
if err != nil {
if strings.Contains(err.Error(), "not found") {
q.debugPrint("ratelimit (new): ", src)
if cacheErr := q.Patrons.Add(src, int64(1), time.Duration(q.Ruleset.Window)*time.Second); cacheErr != nil {
q.debugPrint("ratelimit error: " + cacheErr.Error())
}
return false
}
return false
q.debugPrint("ratelimit error: " + err.Error())
return true
}
if count < q.Ruleset.Burst {
return false
}
if !q.Ruleset.Strict {
q.increment()
q.debugPrint("ratelimit (limited): ", count, " ", src)
return true
}
@ -140,7 +139,7 @@ func (q *Limiter) Check(from Identity) bool {
// Peek checks an Identity's UniqueKey() output against a list of cached strings to determine ratelimiting status without adding to its request count.
func (q *Limiter) Peek(from Identity) bool {
if ct, ok := q.Patrons.Get(from.UniqueKey()); ok {
count := ct.(int)
count := ct.(int64)
if count > q.Ruleset.Burst {
return true
}
@ -148,28 +147,12 @@ func (q *Limiter) Peek(from Identity) bool {
return false
}
func (q *Limiter) increment() {
if q.count.Load() == nil {
q.count.Store(1)
return
}
q.count.Store(q.count.Load().(int) + 1)
}
// GetGrandTotalRated returns the historic total amount of times we have ever reported something as ratelimited.
func (q *Limiter) GetGrandTotalRated() int {
if q.count.Load() == nil {
return 0
}
return q.count.Load().(int)
}
func (q *Limiter) debugPrint(a ...interface{}) {
q.dmu.RLock()
if q.Debug {
q.dmu.RUnlock()
q.RLock()
if q.debug {
q.RUnlock()
debugChannel <- fmt.Sprint(a...)
return
}
q.dmu.RUnlock()
q.RUnlock()
}
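A practical note on the DebugChannel contract restated above: the returned channel is buffered (20 messages) and nothing drains it for you, so a reader goroutine should be running before traffic is driven through Check. A small sketch under that assumption; the key and limiter settings are arbitrary.

package main

import (
	"log"
	"time"

	rate5 "github.com/yunginnanet/Rate5"
)

type key string

func (k key) UniqueKey() string { return string(k) }

func main() {
	// 5-second window, burst of 2, strict mode.
	limiter := rate5.NewStrictLimiter(5, 2)

	// DebugChannel enables debug mode; drain it or sends will eventually block.
	msgs := limiter.DebugChannel()
	go func() {
		for msg := range msgs {
			log.Println("[rate5 debug]", msg)
		}
	}()

	for n := 0; n < 5; n++ {
		limiter.Check(key("198.51.100.7"))
	}
	time.Sleep(100 * time.Millisecond) // let the drain goroutine flush before exit
}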


@ -3,21 +3,16 @@ package rate5
import (
"crypto/rand"
"encoding/binary"
"sync"
"testing"
"time"
)
var (
dummyTicker *ticker
testDebug = true
stopDebug chan bool
stopDebug = make(chan bool)
)
func init() {
stopDebug = make(chan bool)
}
type randomPatron struct {
key string
Identity
@ -38,7 +33,7 @@ func randomUint32() uint32 {
}
func (rp *randomPatron) GenerateKey() {
var keylen = 8
var keylen = 10
buf := make([]byte, keylen)
for n := 0; n != keylen; n++ {
buf[n] = charset[randomUint32()%uint32(len(charset))]
@ -49,33 +44,19 @@ func (rp *randomPatron) GenerateKey() {
func watchDebug(r *Limiter, t *testing.T) {
t.Logf("debug enabled")
rd := r.DebugChannel()
pre := "[Rate5] "
var lastcount = 0
var count = 0
for {
select {
case msg := <-rd:
t.Logf("%s Limit: %s \n", pre, msg)
case <-stopDebug:
return
default:
count++
if count-lastcount >= 10 {
lastcount = count
t.Logf("Times limited: %d", r.GetGrandTotalRated())
}
time.Sleep(10 * time.Millisecond)
}
}
}
type ticker struct{}
func init() {
dummyTicker = &ticker{}
}
func (tick *ticker) UniqueKey() string {
return "tick"
}
@ -87,11 +68,8 @@ func Test_NewCustomLimiter(t *testing.T) {
Strict: false,
})
//goland:noinspection GoBoolExpressions
if testDebug {
go watchDebug(limiter, t)
time.Sleep(25 * time.Millisecond)
}
go watchDebug(limiter, t)
time.Sleep(25 * time.Millisecond)
for n := 0; n < 9; n++ {
limiter.Check(dummyTicker)
@ -111,11 +89,6 @@ func Test_NewCustomLimiter(t *testing.T) {
}
}
//goland:noinspection GoBoolExpressions
if testDebug {
t.Logf("[Finished NewCustomLimiter] Times ratelimited: %d", limiter.GetGrandTotalRated())
}
stopDebug <- true
limiter = nil
}
@ -125,11 +98,8 @@ func Test_NewDefaultStrictLimiter(t *testing.T) {
// DefaultWindow = 5
limiter := NewDefaultStrictLimiter()
//goland:noinspection GoBoolExpressions
if testDebug {
go watchDebug(limiter, t)
time.Sleep(25 * time.Millisecond)
}
go watchDebug(limiter, t)
time.Sleep(25 * time.Millisecond)
for n := 0; n < 24; n++ {
limiter.Check(dummyTicker)
@ -150,33 +120,27 @@ func Test_NewDefaultStrictLimiter(t *testing.T) {
}
}
//goland:noinspection GoBoolExpressions
if testDebug {
t.Logf("[Finished NewCustomLimiter] Times ratelimited: %d", limiter.GetGrandTotalRated())
}
stopDebug <- true
limiter = nil
}
// This test is only here for safety; if the package is not safe, this will often panic.
// We give this a healthy amount of padding in terms of our checks, as this is far beyond the tolerances we expect during runtime.
// At the end of the day, not panicking here is passing.
func Test_ConcurrentSafetyTest(t *testing.T) {
func concurrentTest(t *testing.T, jobs int, iterCount int, burst int64, shouldLimit bool) { //nolint:funlen
var randos map[int]*randomPatron
randos = make(map[int]*randomPatron)
limiter := NewCustomLimiter(Policy{
Window: 240,
Burst: 5000,
Burst: burst,
Strict: true,
})
limitNotice := sync.Once{}
limiter.SetDebug(false)
usedkeys := make(map[string]interface{})
const jobs = 1000
for n := 0; n != jobs ; n++ {
for n := 0; n != jobs; n++ {
randos[n] = new(randomPatron)
ok := true
for ok {
@ -188,56 +152,59 @@ func Test_ConcurrentSafetyTest(t *testing.T) {
}
}
t.Logf("generated %d Patrons with unique keys", len(randos))
t.Logf("generated %d Patrons with unique keys, running Check() with them %d times concurrently with a burst limit of %d...",
len(randos), iterCount, burst)
doneChan := make(chan bool, 10)
finChan := make(chan bool, 10)
var finished = 0
for _, rp := range randos {
for n := 0; n != 5; n++ {
for n := 0; n != iterCount; n++ {
go func(randomp *randomPatron) {
limiter.Check(randomp)
limiter.Peek(randomp)
if limiter.Peek(randomp) {
limitNotice.Do(func() {
t.Logf("(sync.Once) %s limited", randomp.UniqueKey())
})
}
finChan <- true
}(rp)
}
}
var done = false
testloop:
for {
select {
case msg := <-debugChannel:
t.Logf("[debug] %s", msg)
case <-finChan:
finished++
default:
if finished == (jobs * 5) {
done = true
break
if finished >= (jobs * iterCount) {
break testloop
}
}
if done {
go func() {
doneChan <- true
}()
break
}
}
<-doneChan
println("done")
for _, rp := range randos {
if limiter.Peek(rp) {
if limiter.Peek(rp) && !shouldLimit {
if ct, ok := limiter.Patrons.Get(rp.UniqueKey()); ok {
t.Logf("WARN: Should not have been limited. Ratelimiter count: %d, policy: %d", ct, limiter.Ruleset.Burst)
t.Errorf("WARN: Should not have been limited. Ratelimiter count: %d, policy: %d", ct, limiter.Ruleset.Burst)
} else {
t.Errorf("randomPatron does not exist in ratelimiter at all!")
}
}
}
//goland:noinspection GoBoolExpressions
if testDebug {
t.Logf("[Finished StrictConcurrentStressTest] Times ratelimited: %d", limiter.GetGrandTotalRated())
}
}
func Test_ConcurrentShouldNotLimit(t *testing.T) {
concurrentTest(t, 100, 20, 20, false)
concurrentTest(t, 100, 50, 50, false)
}
func Test_ConcurrentShouldLimit(t *testing.T) {
concurrentTest(t, 100, 21, 20, true)
concurrentTest(t, 100, 51, 50, true)
}
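For context on what these tests now exercise: the per-source hit counter behind strict mode is a shared *int64 incremented with sync/atomic, replacing the old rated struct and its CAS spin. A standalone illustration of that counter pattern follows; it is not part of the diff, only a sketch of the same idea.

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

func main() {
	hits := new(int64) // analogous to a Limiter.known[src] entry
	var wg sync.WaitGroup
	for n := 0; n < 100; n++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			atomic.AddInt64(hits, 1) // lock-free increment, as in strictLogic
		}()
	}
	wg.Wait()
	fmt.Println("total hits:", atomic.LoadInt64(hits)) // prints 100
}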