mirror of
https://github.com/yunginnanet/Rate5
synced 2024-06-28 10:00:52 +00:00
Enhance: implement atomic.Value for strict logic + update example
This commit is contained in:
parent
fa5679adcf
commit
27822b0603
@ -7,6 +7,7 @@ import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@ -58,12 +59,35 @@ type Client struct {
|
||||
|
||||
// UniqueKey is an implementation of our Identity interface, in short: Rate5 doesn't care where you derive the string used for ratelimiting
|
||||
func (c Client) UniqueKey() string {
|
||||
var err error
|
||||
var host string
|
||||
if c.loggedin {
|
||||
return c.ID
|
||||
}
|
||||
if host, _, err = net.SplitHostPort(c.Conn.RemoteAddr().String()); err == nil {
|
||||
return host
|
||||
}
|
||||
panic(err)
|
||||
}
|
||||
|
||||
host, _, _ := net.SplitHostPort(c.Conn.RemoteAddr().String())
|
||||
return host
|
||||
func argParse() {
|
||||
if len(os.Args) < 1 {
|
||||
return
|
||||
}
|
||||
for i, arg := range os.Args {
|
||||
switch arg {
|
||||
case "-e":
|
||||
fallthrough
|
||||
case "--exempt":
|
||||
if len(os.Args) <= i+1 {
|
||||
return
|
||||
}
|
||||
srv.Exempt[os.Args[i+1]] = true
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
@ -82,97 +106,70 @@ func init() {
|
||||
mu: &sync.RWMutex{},
|
||||
}
|
||||
|
||||
//srv.Exempt["127.0.0.1"] = true
|
||||
argParse()
|
||||
|
||||
|
||||
rd := Rater.DebugChannel()
|
||||
rrd := RegRater.DebugChannel()
|
||||
crd := CmdRater.DebugChannel()
|
||||
|
||||
pre := "[Rate5] "
|
||||
go func() {
|
||||
var lastcount = 0
|
||||
var count = 0
|
||||
for {
|
||||
select {
|
||||
case msg := <-rd:
|
||||
fmt.Printf("%s Limit: %s \n", pre, msg)
|
||||
count++
|
||||
case msg := <-rrd:
|
||||
fmt.Printf("%s RegLimit: %s \n", pre, msg)
|
||||
count++
|
||||
case msg := <-crd:
|
||||
fmt.Printf("%s CmdLimit: %s \n", pre, msg)
|
||||
count++
|
||||
default:
|
||||
if count - lastcount >= 25 {
|
||||
lastcount = count
|
||||
fmt.Println("Rater: ", Rater.GetGrandTotalRated())
|
||||
fmt.Println("RegRater: ", RegRater.GetGrandTotalRated())
|
||||
fmt.Println("CmdRater: ", CmdRater.GetGrandTotalRated())
|
||||
}
|
||||
time.Sleep(time.Duration(10) * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}()
|
||||
go watchDebug(rd, rrd, crd)
|
||||
}
|
||||
|
||||
func (s *Server) handleTCP(c *Client) {
|
||||
if err := c.Conn.(*net.TCPConn).SetLinger(0); err != nil {
|
||||
fmt.Println("error while setting setlinger:", err.Error())
|
||||
func watchDebug(rd, rrd, crd chan string) {
|
||||
pre := "[Rate5] "
|
||||
var lastcount = 0
|
||||
var count = 0
|
||||
for {
|
||||
select {
|
||||
case msg := <-rd:
|
||||
fmt.Printf("%s Limit: %s \n", pre, msg)
|
||||
count++
|
||||
case msg := <-rrd:
|
||||
fmt.Printf("%s RegLimit: %s \n", pre, msg)
|
||||
count++
|
||||
case msg := <-crd:
|
||||
fmt.Printf("%s CmdLimit: %s \n", pre, msg)
|
||||
count++
|
||||
default:
|
||||
if count-lastcount >= 25 {
|
||||
lastcount = count
|
||||
fmt.Println("Rater: ", Rater.GetGrandTotalRated())
|
||||
fmt.Println("RegRater: ", RegRater.GetGrandTotalRated())
|
||||
fmt.Println("CmdRater: ", CmdRater.GetGrandTotalRated())
|
||||
}
|
||||
time.Sleep(time.Duration(10) * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// skip ratelimit checking for exempt clients
|
||||
srv.mu.RLock()
|
||||
_, exempt := srv.Exempt[c.UniqueKey()]
|
||||
srv.mu.RUnlock()
|
||||
|
||||
defer func() {
|
||||
c.Conn.Close()
|
||||
println("closed: " + c.Conn.RemoteAddr().String())
|
||||
}()
|
||||
|
||||
// Returns true if ratelimited
|
||||
if Rater.Check(c) {
|
||||
c.Conn.Write([]byte("too many connections"))
|
||||
println(c.UniqueKey() + " ratelimited")
|
||||
func (s *Server) preLogin(c *Client) {
|
||||
c.send("Auth: ")
|
||||
in := c.recv()
|
||||
switch {
|
||||
case s.authCheck(c, in):
|
||||
c.loggedin = true
|
||||
c.deadline = time.Duration(480) * time.Second
|
||||
c.send("successful login")
|
||||
return
|
||||
case in == "register":
|
||||
// no exemption for strict ratelimiter (rate5 testing)
|
||||
if RegRater.Check(c) {
|
||||
c.send("you already registered recently\n")
|
||||
return
|
||||
}
|
||||
println("new registration from " + c.UniqueKey())
|
||||
s.setID(c, s.getUnusedID())
|
||||
c.send("\nregistration success\n[New ID]: " + c.ID)
|
||||
return
|
||||
default:
|
||||
c.send("invalid. type 'REGISTER' to register a new ID\n")
|
||||
return
|
||||
}
|
||||
|
||||
c.read = bufio.NewReader(c.Conn)
|
||||
|
||||
c.Conn.Write(loginBanner())
|
||||
|
||||
for {
|
||||
if !c.connected {
|
||||
return
|
||||
}
|
||||
|
||||
time.Sleep(time.Duration(25) * time.Millisecond)
|
||||
if !c.loggedin {
|
||||
c.send("Auth: ")
|
||||
in := c.recv()
|
||||
switch {
|
||||
case s.authCheck(c, in):
|
||||
c.loggedin = true
|
||||
c.deadline = time.Duration(480) * time.Second
|
||||
c.send("successful login")
|
||||
continue
|
||||
case in == "register":
|
||||
if !RegRater.Check(c) || exempt {
|
||||
println("new registration from " + c.UniqueKey())
|
||||
s.setID(c, s.getUnusedID())
|
||||
c.send("\nregistration success\n[New ID]: " + c.ID)
|
||||
return
|
||||
} else {
|
||||
c.send("you already registered recently\n")
|
||||
}
|
||||
continue
|
||||
default:
|
||||
c.send("invalid. type 'REGISTER' to register a new ID\n")
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) mainPrompt(c *Client) {
|
||||
for c.connected {
|
||||
c.send("\nRate5 > ")
|
||||
switch c.recv() {
|
||||
case "history":
|
||||
@ -190,10 +187,60 @@ func (s *Server) handleTCP(c *Client) {
|
||||
case "logout":
|
||||
c.loggedin = false
|
||||
return
|
||||
default:
|
||||
c.send("unknown command, are you lost?")
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isExempt(c *Client) bool {
|
||||
srv.mu.RLock()
|
||||
_, exempt := srv.Exempt[c.UniqueKey()]
|
||||
srv.mu.RUnlock()
|
||||
return exempt
|
||||
}
|
||||
|
||||
func connRateCheck(c *Client) bool {
|
||||
if isExempt(c) {
|
||||
return false
|
||||
}
|
||||
if Rater.Check(c) {
|
||||
c.send("too many connections")
|
||||
println(c.UniqueKey() + " ratelimited")
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func closeConn(c *Client) {
|
||||
if err := c.Conn.Close(); err != nil {
|
||||
println(err.Error())
|
||||
}
|
||||
println("closed: " + c.Conn.RemoteAddr().String())
|
||||
}
|
||||
|
||||
func (s *Server) handleTCP(c *Client) {
|
||||
if err := c.Conn.(*net.TCPConn).SetLinger(0); err != nil {
|
||||
fmt.Println("error while setting setlinger:", err.Error())
|
||||
}
|
||||
defer closeConn(c)
|
||||
if rated := connRateCheck(c); rated {
|
||||
return
|
||||
}
|
||||
c.read = bufio.NewReader(c.Conn)
|
||||
if _, err := c.Conn.Write(loginBanner()); err != nil {
|
||||
return
|
||||
}
|
||||
for !c.loggedin {
|
||||
if !c.connected {
|
||||
return
|
||||
}
|
||||
s.preLogin(c)
|
||||
}
|
||||
s.mainPrompt(c)
|
||||
}
|
||||
|
||||
func (c *Client) send(data string) {
|
||||
if err := c.Conn.SetReadDeadline(time.Now().Add(c.deadline)); err != nil {
|
||||
fmt.Println("error while setting deadline:", err.Error())
|
||||
@ -208,18 +255,16 @@ func (c *Client) recv() string {
|
||||
fmt.Println("error while setting deadline:", err.Error())
|
||||
}
|
||||
|
||||
// skip ratelimit checking for exempt clients
|
||||
srv.mu.RLock()
|
||||
_, ok := srv.Exempt[c.UniqueKey()]
|
||||
srv.mu.RUnlock()
|
||||
|
||||
if CmdRater.Check(c) && !ok {
|
||||
if !c.loggedin {
|
||||
// if they hit the ratelimiter during log-in, disconnect them
|
||||
c.connected = false
|
||||
if !isExempt(c) {
|
||||
if CmdRater.Check(c) {
|
||||
if !c.loggedin {
|
||||
// if they hit the ratelimiter during log-in, disconnect them
|
||||
c.connected = false
|
||||
}
|
||||
time.Sleep(time.Duration(1250) * time.Millisecond)
|
||||
}
|
||||
time.Sleep(time.Duration(1250) * time.Millisecond)
|
||||
}
|
||||
|
||||
in, err := c.read.ReadString('\n')
|
||||
if err != nil {
|
||||
println(c.UniqueKey() + ": " + err.Error())
|
||||
@ -289,20 +334,23 @@ func (s *Server) authCheck(c *Client, id string) bool {
|
||||
if old, ok := s.Map[id]; ok {
|
||||
s.mu.RUnlock()
|
||||
old.connected = false
|
||||
old.Conn.Close()
|
||||
closeConn(old)
|
||||
s.replaceSession(c, id)
|
||||
return true
|
||||
}
|
||||
|
||||
s.mu.RUnlock()
|
||||
return false
|
||||
|
||||
}
|
||||
|
||||
func loginBanner() []byte {
|
||||
var data []byte
|
||||
var err error
|
||||
login := "CnwgG1s5MDs0MG1SG1swbRtbMG0gG1s5Nzs0MG3DhhtbMG0bWzBtIBtbOTc7NDBtzpMbWzBtG1swbSAbWzk3OzQwbc6jG1swbRtbMG0gG1swbRtbOTc7MzJtNRtbMG0bWzBtIHwKCg=="
|
||||
data, _ := base64.StdEncoding.DecodeString(login)
|
||||
return data
|
||||
if data, err = base64.StdEncoding.DecodeString(login); err == nil {
|
||||
return data
|
||||
}
|
||||
panic(err)
|
||||
}
|
||||
|
||||
func main() {
|
||||
|
@ -2,6 +2,7 @@ package rate5
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/patrickmn/go-cache"
|
||||
)
|
||||
@ -20,6 +21,10 @@ type Identity interface {
|
||||
UniqueKey() string
|
||||
}
|
||||
|
||||
type rated struct {
|
||||
seen atomic.Value
|
||||
}
|
||||
|
||||
// Limiter implements an Enforcer to create an arbitrary ratelimiter.
|
||||
type Limiter struct {
|
||||
Source Identity
|
||||
@ -32,7 +37,7 @@ type Limiter struct {
|
||||
Debug bool
|
||||
|
||||
count int
|
||||
known map[interface{}]int
|
||||
known map[interface{}]*rated
|
||||
mu *sync.RWMutex
|
||||
}
|
||||
|
||||
|
@ -3,6 +3,7 @@ package rate5
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/patrickmn/go-cache"
|
||||
@ -53,7 +54,7 @@ func newLimiter(policy Policy) *Limiter {
|
||||
q := new(Limiter)
|
||||
q.Ruleset = policy
|
||||
q.Patrons = cache.New(time.Duration(q.Ruleset.Window)*time.Second, 5*time.Second)
|
||||
q.known = make(map[interface{}]int)
|
||||
q.known = make(map[interface{}]*rated)
|
||||
q.mu = &sync.RWMutex{}
|
||||
return q
|
||||
}
|
||||
@ -68,14 +69,24 @@ func (q *Limiter) DebugChannel() chan string {
|
||||
return debugChannel
|
||||
}
|
||||
|
||||
func (s *rated) inc() {
|
||||
if s.seen.Load() == nil {
|
||||
s.seen.Store(1)
|
||||
return
|
||||
}
|
||||
s.seen.Store(s.seen.Load().(int) + 1)
|
||||
}
|
||||
|
||||
func (q *Limiter) strictLogic(src string, count int) {
|
||||
q.mu.Lock()
|
||||
if _, ok := q.known[src]; !ok {
|
||||
q.known[src] = 1
|
||||
q.known[src]=&rated{
|
||||
seen: atomic.Value{},
|
||||
}
|
||||
}
|
||||
|
||||
q.known[src]++
|
||||
extwindow := q.Ruleset.Window + q.known[src]
|
||||
q.known[src].inc()
|
||||
extwindow := q.Ruleset.Window + q.known[src].seen.Load().(int)
|
||||
|
||||
if err := q.Patrons.Replace(src, count, time.Duration(extwindow)*time.Second); err != nil {
|
||||
q.debugPrint("Rate5: " + err.Error())
|
||||
|
Loading…
Reference in New Issue
Block a user