1
4
mirror of https://github.com/yunginnanet/HellPot synced 2024-06-25 23:38:02 +00:00

Feat: Fix and finish bandwidth limiting io.Writer

This commit is contained in:
kayos@tcp.direct 2023-03-05 03:56:26 -08:00
parent b8b4b56cba
commit c61f6b4a9c
Signed by: kayos
GPG Key ID: 4B841471B4BEE979
2 changed files with 324 additions and 126 deletions

@ -1,7 +1,6 @@
package util
import (
"bufio"
"errors"
"io"
"sync"
@ -11,84 +10,110 @@ import (
var ErrLimitReached = errors.New("limit reached")
// Speedometer is a wrapper around an io.Writer that will limit the rate at which data is written to the underlying writer.
//
// It is safe for concurrent use, but writers will block when slowed down.
//
// Optionally, it can be given;
//
// - a capacity, which will cause it to return an error if the capacity is exceeded.
//
// - a speed limit, causing slow downs of data written to the underlying writer if the speed limit is exceeded.
type Speedometer struct {
cap int64
speedLimit SpeedLimit
birth *time.Time
duration time.Duration
internal atomics
w io.Writer
BufioWriter *bufio.Writer
cap int64
speedLimit *SpeedLimit
internal atomics
hardLock *sync.RWMutex
w io.Writer
}
type atomics struct {
closed *atomic.Bool
count *int64
start *sync.Once
stop *sync.Once
}
func NewSpeedometer(w io.Writer) *Speedometer {
z := int64(0)
speedo := &Speedometer{
w: w,
cap: -1,
internal: atomics{
count: &z,
closed: new(atomic.Bool),
stop: new(sync.Once),
start: new(sync.Once),
},
}
speedo.internal.closed.Store(false)
speedo.BufioWriter = bufio.NewWriter(speedo)
return speedo
closed *atomic.Bool
count *int64
start *sync.Once
stop *sync.Once
birth *atomic.Pointer[time.Time]
duration *atomic.Pointer[time.Duration]
slow *atomic.Bool
}
// SpeedLimit is used to limit the rate at which data is written to the underlying writer.
type SpeedLimit struct {
Bytes int
Per time.Duration
// Burst is the number of bytes that can be written to the underlying writer per Frame.
Burst int
// Frame is the duration of the frame in which Burst can be written to the underlying writer.
Frame time.Duration
// CheckEveryBytes is the number of bytes written before checking if the speed limit has been exceeded.
CheckEveryBytes int
Delay time.Duration
// Delay is the duration to delay writing if the speed limit has been exceeded during a Write call. (blocking)
Delay time.Duration
}
func NewLimitedSpeedometer(w io.Writer, speedLimit SpeedLimit) *Speedometer {
func regulateSpeedLimit(speedLimit *SpeedLimit) (*SpeedLimit, error) {
if speedLimit.Burst <= 0 || speedLimit.Frame <= 0 {
return nil, errors.New("invalid speed limit")
}
if speedLimit.CheckEveryBytes <= 0 {
speedLimit.CheckEveryBytes = speedLimit.Burst
}
if speedLimit.Delay <= 0 {
speedLimit.Delay = 100 * time.Millisecond
}
return speedLimit, nil
}
func newSpeedometer(w io.Writer, speedLimit *SpeedLimit, cap int64) (*Speedometer, error) {
if w == nil {
return nil, errors.New("writer cannot be nil")
}
var err error
if speedLimit != nil {
if speedLimit, err = regulateSpeedLimit(speedLimit); err != nil {
return nil, err
}
}
z := int64(0)
speedo := &Speedometer{
w: w,
cap: -1,
cap: cap,
speedLimit: speedLimit,
hardLock: &sync.RWMutex{},
internal: atomics{
count: &z,
closed: new(atomic.Bool),
stop: new(sync.Once),
start: new(sync.Once),
count: &z,
birth: new(atomic.Pointer[time.Time]),
duration: new(atomic.Pointer[time.Duration]),
closed: new(atomic.Bool),
stop: new(sync.Once),
start: new(sync.Once),
slow: new(atomic.Bool),
},
}
speedo.internal.birth.Store(&time.Time{})
speedo.internal.closed.Store(false)
speedo.BufioWriter = bufio.NewWriterSize(speedo, speedLimit.CheckEveryBytes)
return speedo
return speedo, nil
}
func NewCappedSpeedometer(w io.Writer, cap int64) *Speedometer {
z := int64(0)
speedo := &Speedometer{
w: w,
cap: cap,
internal: atomics{
count: &z,
closed: new(atomic.Bool),
stop: new(sync.Once),
start: new(sync.Once),
},
}
speedo.internal.closed.Store(false)
speedo.BufioWriter = bufio.NewWriterSize(speedo, int(cap)/10)
return speedo
// NewSpeedometer creates a new Speedometer that wraps the given io.Writer.
// It will not limit the rate at which data is written to the underlying writer, it only measures it.
func NewSpeedometer(w io.Writer) (*Speedometer, error) {
return newSpeedometer(w, nil, -1)
}
// NewLimitedSpeedometer creates a new Speedometer that wraps the given io.Writer.
// If the speed limit is exceeded, writes to the underlying writer will be limited.
// See SpeedLimit for more information.
func NewLimitedSpeedometer(w io.Writer, speedLimit *SpeedLimit) (*Speedometer, error) {
return newSpeedometer(w, speedLimit, -1)
}
// NewCappedSpeedometer creates a new Speedometer that wraps the given io.Writer.
// If len(written) bytes exceeds cap, writes to the underlying writer will be ceased permanently for the Speedometer.
func NewCappedSpeedometer(w io.Writer, cap int64) (*Speedometer, error) {
return newSpeedometer(w, nil, cap)
}
func (s *Speedometer) increment(inc int64) (int, error) {
if s.internal.closed.Load() {
if s.internal.closed.Load() || !s.hardLock.TryRLock() {
return 0, io.ErrClosedPipe
}
var err error
@ -101,43 +126,73 @@ func (s *Speedometer) increment(inc int64) (int, error) {
return int(inc), err
}
// Running returns true if the Speedometer is still running.
func (s *Speedometer) Running() bool {
return !s.internal.closed.Load()
}
// Total returns the total number of bytes written to the underlying writer.
func (s *Speedometer) Total() int64 {
return atomic.LoadInt64(s.internal.count)
}
func (s *Speedometer) Reset() {
s.internal.count = new(int64)
s.internal.closed = new(atomic.Bool)
s.internal.start = new(sync.Once)
s.internal.stop = new(sync.Once)
s.BufioWriter = bufio.NewWriter(s)
s.internal.closed.Store(false)
}
// Close stops the Speedometer. No additional writes will be accepted.
func (s *Speedometer) Close() error {
s.hardLock.TryLock()
if s.internal.closed.Load() {
return io.ErrClosedPipe
}
s.internal.stop.Do(func() {
s.internal.closed.Store(true)
stopped := time.Now()
s.duration = stopped.Sub(*s.birth)
birth := s.internal.birth.Load()
duration := stopped.Sub(*birth)
s.internal.duration.Store(&duration)
})
return nil
}
/*func (s *Speedometer) IsSlow() bool {
return s.internal.slow.Load()
}*/
// Rate returns the rate at which data is being written to the underlying writer per second.
func (s *Speedometer) Rate() float64 {
if s.internal.closed.Load() {
return float64(s.Total()) / s.duration.Seconds()
return float64(s.Total()) / s.internal.duration.Load().Seconds()
}
return float64(s.Total()) / time.Since(*s.birth).Seconds()
return float64(s.Total()) / time.Since(*s.internal.birth.Load()).Seconds()
}
func (s *Speedometer) slowDown() error {
switch {
case s.speedLimit == nil:
return nil
case s.speedLimit.Burst <= 0 || s.speedLimit.Frame <= 0,
s.speedLimit.CheckEveryBytes <= 0, s.speedLimit.Delay <= 0:
return errors.New("invalid speed limit")
default:
//
}
if s.Total()%int64(s.speedLimit.CheckEveryBytes) != 0 {
return nil
}
s.internal.slow.Store(true)
for s.Rate() > float64(s.speedLimit.Burst)/s.speedLimit.Frame.Seconds() {
time.Sleep(s.speedLimit.Delay)
}
s.internal.slow.Store(false)
return nil
}
// Write writes p to the underlying writer, following all defined speed limits.
func (s *Speedometer) Write(p []byte) (n int, err error) {
if !s.hardLock.TryRLock() {
return 0, io.ErrClosedPipe
}
s.internal.start.Do(func() {
tn := time.Now()
s.birth = &tn
now := time.Now()
s.internal.birth.Store(&now)
})
accepted, err := s.increment(int64(len(p)))
if err != nil {
@ -145,14 +200,10 @@ func (s *Speedometer) Write(p []byte) (n int, err error) {
if innerErr != nil {
err = innerErr
}
if wn < accepted {
err = io.ErrShortWrite
}
return wn, err
}
if err = s.slowDown(); err != nil {
return 0, err
}
return s.w.Write(p)
}
/*func BenchmarkHeffalump_WriteHell(b *testing.B) {
}*/

@ -1,16 +1,27 @@
package util
import (
"bufio"
"errors"
"io"
"sync"
"sync/atomic"
"testing"
"time"
)
func writeStuff(target *bufio.Writer, count int) error {
type testWriter struct {
t *testing.T
total int64
}
func (w *testWriter) Write(p []byte) (n int, err error) {
atomic.AddInt64(&w.total, int64(len(p)))
return len(p), nil
}
func writeStuff(target io.Writer, count int) error {
write := func() error {
_, err := target.WriteString("a")
_, err := target.Write([]byte("a"))
if err != nil {
return err
}
@ -18,11 +29,11 @@ func writeStuff(target *bufio.Writer, count int) error {
}
if count < 0 {
for {
if err := write(); err != nil {
return err
}
var err error
for err = write(); err == nil; err = write() {
time.Sleep(5 * time.Millisecond)
}
return err
}
for i := 0; i < count; i++ {
if err := write(); err != nil {
@ -55,47 +66,183 @@ func Test_Speedometer(t *testing.T) {
}
}
sp := NewSpeedometer(io.Discard)
errChan := make(chan error)
go func() {
errChan <- writeStuff(sp.BufioWriter, -1)
}()
time.Sleep(1 * time.Second)
_ = sp.Close()
err := <-errChan
cnt, err := sp.Write([]byte("a"))
isIt(results{err: io.ErrClosedPipe, written: 0}, results{err: err, written: cnt})
sp.Reset()
cnt, err = sp.Write([]byte("a"))
isIt(results{err: nil, written: 1, total: 1}, results{err: err, written: cnt, total: sp.Total()})
cnt, err = sp.Write([]byte("aa"))
isIt(results{err: nil, written: 2, total: 3}, results{err: err, written: cnt, total: sp.Total()})
cnt, err = sp.Write([]byte("a"))
isIt(results{err: nil, written: 1, total: 4}, results{err: err, written: cnt, total: sp.Total()})
cnt, err = sp.Write([]byte("a"))
isIt(results{err: nil, written: 1, total: 5}, results{err: err, written: cnt, total: sp.Total()})
go func() {
errChan <- writeStuff(sp.BufioWriter, -1)
}()
var count = 0
for sp.Running() {
select {
case err = <-errChan:
if !errors.Is(err, io.ErrClosedPipe) {
t.Errorf("unexpected error: %v", err)
} else {
if count < 5 {
t.Errorf("too few iterations: %d", count)
}
t.Logf("final rate: %v per second", sp.Rate())
}
default:
if count > 5 {
_ = sp.Close()
}
t.Logf("rate: %v per second", sp.Rate())
time.Sleep(1 * time.Second)
count++
var (
errChan = make(chan error, 10)
err error
cnt int
)
t.Run("EarlyClose", func(t *testing.T) {
sp, nerr := NewSpeedometer(&testWriter{t: t})
if nerr != nil {
t.Errorf("unexpected error: %v", nerr)
}
}
go func() {
errChan <- writeStuff(sp, -1)
}()
time.Sleep(1 * time.Second)
if closeErr := sp.Close(); closeErr != nil {
t.Errorf("wantErr: want %v, have %v", nil, closeErr)
}
err = <-errChan
if !errors.Is(err, io.ErrClosedPipe) {
t.Errorf("wantErr: want %v, have %v", io.ErrClosedPipe, err)
}
cnt, err = sp.Write([]byte("a"))
isIt(results{err: io.ErrClosedPipe, written: 0}, results{err: err, written: cnt})
})
t.Run("Basic", func(t *testing.T) {
sp, nerr := NewSpeedometer(&testWriter{t: t})
if nerr != nil {
t.Errorf("unexpected error: %v", nerr)
}
cnt, err = sp.Write([]byte("a"))
isIt(results{err: nil, written: 1, total: 1}, results{err: err, written: cnt, total: sp.Total()})
cnt, err = sp.Write([]byte("aa"))
isIt(results{err: nil, written: 2, total: 3}, results{err: err, written: cnt, total: sp.Total()})
cnt, err = sp.Write([]byte("a"))
isIt(results{err: nil, written: 1, total: 4}, results{err: err, written: cnt, total: sp.Total()})
cnt, err = sp.Write([]byte("a"))
isIt(results{err: nil, written: 1, total: 5}, results{err: err, written: cnt, total: sp.Total()})
})
t.Run("ConcurrentWrites", func(t *testing.T) {
count := int64(0)
sp, nerr := NewSpeedometer(&testWriter{t: t})
if nerr != nil {
t.Errorf("unexpected error: %v", nerr)
}
wg := &sync.WaitGroup{}
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
var counted int
var err error
counted, err = sp.Write([]byte("a"))
if err != nil {
t.Errorf("unexpected error: %v", err)
}
atomic.AddInt64(&count, int64(counted))
wg.Done()
}()
}
wg.Wait()
isIt(results{err: nil, written: 1, total: 100}, results{err: err, written: cnt, total: sp.Total()})
})
t.Run("GottaGoFast", func(t *testing.T) {
sp, nerr := NewSpeedometer(&testWriter{t: t})
if nerr != nil {
t.Errorf("unexpected error: %v", nerr)
}
go func() {
errChan <- writeStuff(sp, -1)
}()
var count = 0
for sp.Running() {
select {
case err = <-errChan:
if !errors.Is(err, io.ErrClosedPipe) {
t.Errorf("unexpected error: %v", err)
} else {
if count < 5 {
t.Errorf("too few iterations: %d", count)
}
t.Logf("final rate: %v per second", sp.Rate())
}
default:
if count > 5 {
_ = sp.Close()
}
time.Sleep(100 * time.Millisecond)
t.Logf("rate: %v per second", sp.Rate())
count++
}
}
})
// test limiter with speedlimit
t.Run("CantGoFast", func(t *testing.T) {
t.Run("10BytesASecond", func(t *testing.T) {
sp, nerr := NewLimitedSpeedometer(&testWriter{t: t}, &SpeedLimit{
Burst: 10,
Frame: time.Second,
CheckEveryBytes: 1,
Delay: 100 * time.Millisecond,
})
if nerr != nil {
t.Errorf("unexpected error: %v", nerr)
}
for i := 0; i < 15; i++ {
if _, err = sp.Write([]byte("a")); err != nil {
t.Errorf("unexpected error: %v", err)
}
/*if sp.IsSlow() {
t.Errorf("unexpected slow state")
}*/
t.Logf("rate: %v per second", sp.Rate())
if sp.Rate() > 10 {
t.Errorf("speeding in a school zone (expected under %d): %v", sp.speedLimit.Burst, sp.Rate())
}
}
})
t.Run("1000BytesPer5SecondsMeasuredEvery5000Bytes", func(t *testing.T) {
sp, nerr := NewLimitedSpeedometer(&testWriter{t: t}, &SpeedLimit{
Burst: 1000,
Frame: 2 * time.Second,
CheckEveryBytes: 5000,
Delay: 500 * time.Millisecond,
})
if nerr != nil {
t.Errorf("unexpected error: %v", nerr)
}
for i := 0; i < 4999; i++ {
if _, err = sp.Write([]byte("a")); err != nil {
t.Errorf("unexpected error: %v", err)
}
if i%1000 == 0 {
t.Logf("rate: %v per second", sp.Rate())
}
if sp.Rate() < 1000 {
t.Errorf("shouldn't have slowed down yet (expected over %d): %v", sp.speedLimit.Burst, sp.Rate())
}
}
if _, err = sp.Write([]byte("a")); err != nil {
t.Errorf("unexpected error: %v", err)
}
for i := 0; i < 10; i++ {
if _, err = sp.Write([]byte("a")); err != nil {
t.Errorf("unexpected error: %v", err)
}
t.Logf("rate: %v per second", sp.Rate())
if sp.Rate() > 1000 {
t.Errorf("speeding in a school zone (expected under %d): %v", sp.speedLimit.Burst, sp.Rate())
}
}
})
})
// test capped speedometer
t.Run("OnlyALittle", func(t *testing.T) {
sp, nerr := NewCappedSpeedometer(&testWriter{t: t}, 1024)
if nerr != nil {
t.Errorf("unexpected error: %v", nerr)
}
for i := 0; i < 1024; i++ {
if _, err = sp.Write([]byte("a")); err != nil {
t.Errorf("unexpected error: %v", err)
}
if sp.Total() > 1024 {
t.Errorf("shouldn't have written more than 1024 bytes")
}
}
if _, err = sp.Write([]byte("a")); err == nil {
t.Errorf("expected error when writing over capacity")
}
})
}