mirror of
https://github.com/yunginnanet/HellPot
synced 2024-06-28 16:50:51 +00:00
Feat: Fix and finish bandwidth limiting io.Writer
This commit is contained in:
parent
b8b4b56cba
commit
c61f6b4a9c
@ -1,7 +1,6 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
@ -11,14 +10,21 @@ import (
|
||||
|
||||
var ErrLimitReached = errors.New("limit reached")
|
||||
|
||||
// Speedometer is a wrapper around an io.Writer that will limit the rate at which data is written to the underlying writer.
|
||||
//
|
||||
// It is safe for concurrent use, but writers will block when slowed down.
|
||||
//
|
||||
// Optionally, it can be given;
|
||||
//
|
||||
// - a capacity, which will cause it to return an error if the capacity is exceeded.
|
||||
//
|
||||
// - a speed limit, causing slow downs of data written to the underlying writer if the speed limit is exceeded.
|
||||
type Speedometer struct {
|
||||
cap int64
|
||||
speedLimit SpeedLimit
|
||||
birth *time.Time
|
||||
duration time.Duration
|
||||
speedLimit *SpeedLimit
|
||||
internal atomics
|
||||
hardLock *sync.RWMutex
|
||||
w io.Writer
|
||||
BufioWriter *bufio.Writer
|
||||
}
|
||||
|
||||
type atomics struct {
|
||||
@ -26,69 +32,88 @@ type atomics struct {
|
||||
count *int64
|
||||
start *sync.Once
|
||||
stop *sync.Once
|
||||
birth *atomic.Pointer[time.Time]
|
||||
duration *atomic.Pointer[time.Duration]
|
||||
slow *atomic.Bool
|
||||
}
|
||||
|
||||
func NewSpeedometer(w io.Writer) *Speedometer {
|
||||
z := int64(0)
|
||||
speedo := &Speedometer{
|
||||
w: w,
|
||||
cap: -1,
|
||||
internal: atomics{
|
||||
count: &z,
|
||||
closed: new(atomic.Bool),
|
||||
stop: new(sync.Once),
|
||||
start: new(sync.Once),
|
||||
},
|
||||
}
|
||||
speedo.internal.closed.Store(false)
|
||||
speedo.BufioWriter = bufio.NewWriter(speedo)
|
||||
return speedo
|
||||
}
|
||||
|
||||
// SpeedLimit is used to limit the rate at which data is written to the underlying writer.
|
||||
type SpeedLimit struct {
|
||||
Bytes int
|
||||
Per time.Duration
|
||||
// Burst is the number of bytes that can be written to the underlying writer per Frame.
|
||||
Burst int
|
||||
// Frame is the duration of the frame in which Burst can be written to the underlying writer.
|
||||
Frame time.Duration
|
||||
// CheckEveryBytes is the number of bytes written before checking if the speed limit has been exceeded.
|
||||
CheckEveryBytes int
|
||||
// Delay is the duration to delay writing if the speed limit has been exceeded during a Write call. (blocking)
|
||||
Delay time.Duration
|
||||
}
|
||||
|
||||
func NewLimitedSpeedometer(w io.Writer, speedLimit SpeedLimit) *Speedometer {
|
||||
z := int64(0)
|
||||
speedo := &Speedometer{
|
||||
w: w,
|
||||
cap: -1,
|
||||
speedLimit: speedLimit,
|
||||
internal: atomics{
|
||||
count: &z,
|
||||
closed: new(atomic.Bool),
|
||||
stop: new(sync.Once),
|
||||
start: new(sync.Once),
|
||||
},
|
||||
func regulateSpeedLimit(speedLimit *SpeedLimit) (*SpeedLimit, error) {
|
||||
if speedLimit.Burst <= 0 || speedLimit.Frame <= 0 {
|
||||
return nil, errors.New("invalid speed limit")
|
||||
}
|
||||
speedo.internal.closed.Store(false)
|
||||
speedo.BufioWriter = bufio.NewWriterSize(speedo, speedLimit.CheckEveryBytes)
|
||||
return speedo
|
||||
if speedLimit.CheckEveryBytes <= 0 {
|
||||
speedLimit.CheckEveryBytes = speedLimit.Burst
|
||||
}
|
||||
if speedLimit.Delay <= 0 {
|
||||
speedLimit.Delay = 100 * time.Millisecond
|
||||
}
|
||||
return speedLimit, nil
|
||||
}
|
||||
|
||||
func NewCappedSpeedometer(w io.Writer, cap int64) *Speedometer {
|
||||
func newSpeedometer(w io.Writer, speedLimit *SpeedLimit, cap int64) (*Speedometer, error) {
|
||||
if w == nil {
|
||||
return nil, errors.New("writer cannot be nil")
|
||||
}
|
||||
var err error
|
||||
if speedLimit != nil {
|
||||
if speedLimit, err = regulateSpeedLimit(speedLimit); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
z := int64(0)
|
||||
speedo := &Speedometer{
|
||||
w: w,
|
||||
cap: cap,
|
||||
speedLimit: speedLimit,
|
||||
hardLock: &sync.RWMutex{},
|
||||
internal: atomics{
|
||||
count: &z,
|
||||
birth: new(atomic.Pointer[time.Time]),
|
||||
duration: new(atomic.Pointer[time.Duration]),
|
||||
closed: new(atomic.Bool),
|
||||
stop: new(sync.Once),
|
||||
start: new(sync.Once),
|
||||
slow: new(atomic.Bool),
|
||||
},
|
||||
}
|
||||
speedo.internal.birth.Store(&time.Time{})
|
||||
speedo.internal.closed.Store(false)
|
||||
speedo.BufioWriter = bufio.NewWriterSize(speedo, int(cap)/10)
|
||||
return speedo
|
||||
return speedo, nil
|
||||
}
|
||||
|
||||
// NewSpeedometer creates a new Speedometer that wraps the given io.Writer.
|
||||
// It will not limit the rate at which data is written to the underlying writer, it only measures it.
|
||||
func NewSpeedometer(w io.Writer) (*Speedometer, error) {
|
||||
return newSpeedometer(w, nil, -1)
|
||||
}
|
||||
|
||||
// NewLimitedSpeedometer creates a new Speedometer that wraps the given io.Writer.
|
||||
// If the speed limit is exceeded, writes to the underlying writer will be limited.
|
||||
// See SpeedLimit for more information.
|
||||
func NewLimitedSpeedometer(w io.Writer, speedLimit *SpeedLimit) (*Speedometer, error) {
|
||||
return newSpeedometer(w, speedLimit, -1)
|
||||
}
|
||||
|
||||
// NewCappedSpeedometer creates a new Speedometer that wraps the given io.Writer.
|
||||
// If len(written) bytes exceeds cap, writes to the underlying writer will be ceased permanently for the Speedometer.
|
||||
func NewCappedSpeedometer(w io.Writer, cap int64) (*Speedometer, error) {
|
||||
return newSpeedometer(w, nil, cap)
|
||||
}
|
||||
|
||||
func (s *Speedometer) increment(inc int64) (int, error) {
|
||||
if s.internal.closed.Load() {
|
||||
if s.internal.closed.Load() || !s.hardLock.TryRLock() {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
var err error
|
||||
@ -101,43 +126,73 @@ func (s *Speedometer) increment(inc int64) (int, error) {
|
||||
return int(inc), err
|
||||
}
|
||||
|
||||
// Running returns true if the Speedometer is still running.
|
||||
func (s *Speedometer) Running() bool {
|
||||
return !s.internal.closed.Load()
|
||||
}
|
||||
|
||||
// Total returns the total number of bytes written to the underlying writer.
|
||||
func (s *Speedometer) Total() int64 {
|
||||
return atomic.LoadInt64(s.internal.count)
|
||||
}
|
||||
|
||||
func (s *Speedometer) Reset() {
|
||||
s.internal.count = new(int64)
|
||||
s.internal.closed = new(atomic.Bool)
|
||||
s.internal.start = new(sync.Once)
|
||||
s.internal.stop = new(sync.Once)
|
||||
s.BufioWriter = bufio.NewWriter(s)
|
||||
s.internal.closed.Store(false)
|
||||
}
|
||||
|
||||
// Close stops the Speedometer. No additional writes will be accepted.
|
||||
func (s *Speedometer) Close() error {
|
||||
s.hardLock.TryLock()
|
||||
if s.internal.closed.Load() {
|
||||
return io.ErrClosedPipe
|
||||
}
|
||||
s.internal.stop.Do(func() {
|
||||
s.internal.closed.Store(true)
|
||||
stopped := time.Now()
|
||||
s.duration = stopped.Sub(*s.birth)
|
||||
birth := s.internal.birth.Load()
|
||||
duration := stopped.Sub(*birth)
|
||||
s.internal.duration.Store(&duration)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
/*func (s *Speedometer) IsSlow() bool {
|
||||
return s.internal.slow.Load()
|
||||
}*/
|
||||
|
||||
// Rate returns the rate at which data is being written to the underlying writer per second.
|
||||
func (s *Speedometer) Rate() float64 {
|
||||
if s.internal.closed.Load() {
|
||||
return float64(s.Total()) / s.duration.Seconds()
|
||||
return float64(s.Total()) / s.internal.duration.Load().Seconds()
|
||||
}
|
||||
return float64(s.Total()) / time.Since(*s.birth).Seconds()
|
||||
return float64(s.Total()) / time.Since(*s.internal.birth.Load()).Seconds()
|
||||
}
|
||||
|
||||
func (s *Speedometer) slowDown() error {
|
||||
switch {
|
||||
case s.speedLimit == nil:
|
||||
return nil
|
||||
case s.speedLimit.Burst <= 0 || s.speedLimit.Frame <= 0,
|
||||
s.speedLimit.CheckEveryBytes <= 0, s.speedLimit.Delay <= 0:
|
||||
return errors.New("invalid speed limit")
|
||||
default:
|
||||
//
|
||||
}
|
||||
if s.Total()%int64(s.speedLimit.CheckEveryBytes) != 0 {
|
||||
return nil
|
||||
}
|
||||
s.internal.slow.Store(true)
|
||||
for s.Rate() > float64(s.speedLimit.Burst)/s.speedLimit.Frame.Seconds() {
|
||||
time.Sleep(s.speedLimit.Delay)
|
||||
}
|
||||
s.internal.slow.Store(false)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write writes p to the underlying writer, following all defined speed limits.
|
||||
func (s *Speedometer) Write(p []byte) (n int, err error) {
|
||||
if !s.hardLock.TryRLock() {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
s.internal.start.Do(func() {
|
||||
tn := time.Now()
|
||||
s.birth = &tn
|
||||
now := time.Now()
|
||||
s.internal.birth.Store(&now)
|
||||
})
|
||||
accepted, err := s.increment(int64(len(p)))
|
||||
if err != nil {
|
||||
@ -145,14 +200,10 @@ func (s *Speedometer) Write(p []byte) (n int, err error) {
|
||||
if innerErr != nil {
|
||||
err = innerErr
|
||||
}
|
||||
if wn < accepted {
|
||||
err = io.ErrShortWrite
|
||||
}
|
||||
return wn, err
|
||||
}
|
||||
if err = s.slowDown(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return s.w.Write(p)
|
||||
}
|
||||
|
||||
/*func BenchmarkHeffalump_WriteHell(b *testing.B) {
|
||||
|
||||
}*/
|
||||
|
@ -1,16 +1,27 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func writeStuff(target *bufio.Writer, count int) error {
|
||||
type testWriter struct {
|
||||
t *testing.T
|
||||
total int64
|
||||
}
|
||||
|
||||
func (w *testWriter) Write(p []byte) (n int, err error) {
|
||||
atomic.AddInt64(&w.total, int64(len(p)))
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func writeStuff(target io.Writer, count int) error {
|
||||
write := func() error {
|
||||
_, err := target.WriteString("a")
|
||||
_, err := target.Write([]byte("a"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -18,12 +29,12 @@ func writeStuff(target *bufio.Writer, count int) error {
|
||||
}
|
||||
|
||||
if count < 0 {
|
||||
for {
|
||||
if err := write(); err != nil {
|
||||
var err error
|
||||
for err = write(); err == nil; err = write() {
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
for i := 0; i < count; i++ {
|
||||
if err := write(); err != nil {
|
||||
return err
|
||||
@ -55,17 +66,37 @@ func Test_Speedometer(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
sp := NewSpeedometer(io.Discard)
|
||||
errChan := make(chan error)
|
||||
var (
|
||||
errChan = make(chan error, 10)
|
||||
err error
|
||||
cnt int
|
||||
)
|
||||
|
||||
t.Run("EarlyClose", func(t *testing.T) {
|
||||
sp, nerr := NewSpeedometer(&testWriter{t: t})
|
||||
if nerr != nil {
|
||||
t.Errorf("unexpected error: %v", nerr)
|
||||
}
|
||||
go func() {
|
||||
errChan <- writeStuff(sp.BufioWriter, -1)
|
||||
errChan <- writeStuff(sp, -1)
|
||||
}()
|
||||
time.Sleep(1 * time.Second)
|
||||
_ = sp.Close()
|
||||
err := <-errChan
|
||||
cnt, err := sp.Write([]byte("a"))
|
||||
if closeErr := sp.Close(); closeErr != nil {
|
||||
t.Errorf("wantErr: want %v, have %v", nil, closeErr)
|
||||
}
|
||||
err = <-errChan
|
||||
if !errors.Is(err, io.ErrClosedPipe) {
|
||||
t.Errorf("wantErr: want %v, have %v", io.ErrClosedPipe, err)
|
||||
}
|
||||
cnt, err = sp.Write([]byte("a"))
|
||||
isIt(results{err: io.ErrClosedPipe, written: 0}, results{err: err, written: cnt})
|
||||
sp.Reset()
|
||||
})
|
||||
|
||||
t.Run("Basic", func(t *testing.T) {
|
||||
sp, nerr := NewSpeedometer(&testWriter{t: t})
|
||||
if nerr != nil {
|
||||
t.Errorf("unexpected error: %v", nerr)
|
||||
}
|
||||
cnt, err = sp.Write([]byte("a"))
|
||||
isIt(results{err: nil, written: 1, total: 1}, results{err: err, written: cnt, total: sp.Total()})
|
||||
cnt, err = sp.Write([]byte("aa"))
|
||||
@ -74,8 +105,39 @@ func Test_Speedometer(t *testing.T) {
|
||||
isIt(results{err: nil, written: 1, total: 4}, results{err: err, written: cnt, total: sp.Total()})
|
||||
cnt, err = sp.Write([]byte("a"))
|
||||
isIt(results{err: nil, written: 1, total: 5}, results{err: err, written: cnt, total: sp.Total()})
|
||||
})
|
||||
|
||||
t.Run("ConcurrentWrites", func(t *testing.T) {
|
||||
count := int64(0)
|
||||
sp, nerr := NewSpeedometer(&testWriter{t: t})
|
||||
if nerr != nil {
|
||||
t.Errorf("unexpected error: %v", nerr)
|
||||
}
|
||||
wg := &sync.WaitGroup{}
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
errChan <- writeStuff(sp.BufioWriter, -1)
|
||||
var counted int
|
||||
var err error
|
||||
counted, err = sp.Write([]byte("a"))
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
atomic.AddInt64(&count, int64(counted))
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
isIt(results{err: nil, written: 1, total: 100}, results{err: err, written: cnt, total: sp.Total()})
|
||||
})
|
||||
|
||||
t.Run("GottaGoFast", func(t *testing.T) {
|
||||
sp, nerr := NewSpeedometer(&testWriter{t: t})
|
||||
if nerr != nil {
|
||||
t.Errorf("unexpected error: %v", nerr)
|
||||
}
|
||||
go func() {
|
||||
errChan <- writeStuff(sp, -1)
|
||||
}()
|
||||
var count = 0
|
||||
for sp.Running() {
|
||||
@ -93,9 +155,94 @@ func Test_Speedometer(t *testing.T) {
|
||||
if count > 5 {
|
||||
_ = sp.Close()
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
t.Logf("rate: %v per second", sp.Rate())
|
||||
time.Sleep(1 * time.Second)
|
||||
count++
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// test limiter with speedlimit
|
||||
t.Run("CantGoFast", func(t *testing.T) {
|
||||
t.Run("10BytesASecond", func(t *testing.T) {
|
||||
sp, nerr := NewLimitedSpeedometer(&testWriter{t: t}, &SpeedLimit{
|
||||
Burst: 10,
|
||||
Frame: time.Second,
|
||||
CheckEveryBytes: 1,
|
||||
Delay: 100 * time.Millisecond,
|
||||
})
|
||||
if nerr != nil {
|
||||
t.Errorf("unexpected error: %v", nerr)
|
||||
}
|
||||
for i := 0; i < 15; i++ {
|
||||
if _, err = sp.Write([]byte("a")); err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
/*if sp.IsSlow() {
|
||||
t.Errorf("unexpected slow state")
|
||||
}*/
|
||||
t.Logf("rate: %v per second", sp.Rate())
|
||||
if sp.Rate() > 10 {
|
||||
t.Errorf("speeding in a school zone (expected under %d): %v", sp.speedLimit.Burst, sp.Rate())
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("1000BytesPer5SecondsMeasuredEvery5000Bytes", func(t *testing.T) {
|
||||
sp, nerr := NewLimitedSpeedometer(&testWriter{t: t}, &SpeedLimit{
|
||||
Burst: 1000,
|
||||
Frame: 2 * time.Second,
|
||||
CheckEveryBytes: 5000,
|
||||
Delay: 500 * time.Millisecond,
|
||||
})
|
||||
|
||||
if nerr != nil {
|
||||
t.Errorf("unexpected error: %v", nerr)
|
||||
}
|
||||
|
||||
for i := 0; i < 4999; i++ {
|
||||
if _, err = sp.Write([]byte("a")); err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if i%1000 == 0 {
|
||||
t.Logf("rate: %v per second", sp.Rate())
|
||||
}
|
||||
if sp.Rate() < 1000 {
|
||||
t.Errorf("shouldn't have slowed down yet (expected over %d): %v", sp.speedLimit.Burst, sp.Rate())
|
||||
}
|
||||
}
|
||||
if _, err = sp.Write([]byte("a")); err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
for i := 0; i < 10; i++ {
|
||||
if _, err = sp.Write([]byte("a")); err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
t.Logf("rate: %v per second", sp.Rate())
|
||||
if sp.Rate() > 1000 {
|
||||
t.Errorf("speeding in a school zone (expected under %d): %v", sp.speedLimit.Burst, sp.Rate())
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
// test capped speedometer
|
||||
t.Run("OnlyALittle", func(t *testing.T) {
|
||||
sp, nerr := NewCappedSpeedometer(&testWriter{t: t}, 1024)
|
||||
if nerr != nil {
|
||||
t.Errorf("unexpected error: %v", nerr)
|
||||
}
|
||||
for i := 0; i < 1024; i++ {
|
||||
if _, err = sp.Write([]byte("a")); err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if sp.Total() > 1024 {
|
||||
t.Errorf("shouldn't have written more than 1024 bytes")
|
||||
}
|
||||
}
|
||||
if _, err = sp.Write([]byte("a")); err == nil {
|
||||
t.Errorf("expected error when writing over capacity")
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user