From c61f6b4a9c52327e241481e3bbe81fa5e8bf7c70 Mon Sep 17 00:00:00 2001 From: "kayos@tcp.direct" Date: Sun, 5 Mar 2023 03:56:26 -0800 Subject: [PATCH] Feat: Fix and finish bandwidth limiting io.Writer --- internal/util/speedometer.go | 205 +++++++++++++++---------- internal/util/speedometer_test.go | 245 ++++++++++++++++++++++++------ 2 files changed, 324 insertions(+), 126 deletions(-) diff --git a/internal/util/speedometer.go b/internal/util/speedometer.go index 8c736e6..8fbf090 100644 --- a/internal/util/speedometer.go +++ b/internal/util/speedometer.go @@ -1,7 +1,6 @@ package util import ( - "bufio" "errors" "io" "sync" @@ -11,84 +10,110 @@ import ( var ErrLimitReached = errors.New("limit reached") +// Speedometer is a wrapper around an io.Writer that will limit the rate at which data is written to the underlying writer. +// +// It is safe for concurrent use, but writers will block when slowed down. +// +// Optionally, it can be given; +// +// - a capacity, which will cause it to return an error if the capacity is exceeded. +// +// - a speed limit, causing slow downs of data written to the underlying writer if the speed limit is exceeded. type Speedometer struct { - cap int64 - speedLimit SpeedLimit - birth *time.Time - duration time.Duration - internal atomics - w io.Writer - BufioWriter *bufio.Writer + cap int64 + speedLimit *SpeedLimit + internal atomics + hardLock *sync.RWMutex + w io.Writer } type atomics struct { - closed *atomic.Bool - count *int64 - start *sync.Once - stop *sync.Once -} - -func NewSpeedometer(w io.Writer) *Speedometer { - z := int64(0) - speedo := &Speedometer{ - w: w, - cap: -1, - internal: atomics{ - count: &z, - closed: new(atomic.Bool), - stop: new(sync.Once), - start: new(sync.Once), - }, - } - speedo.internal.closed.Store(false) - speedo.BufioWriter = bufio.NewWriter(speedo) - return speedo + closed *atomic.Bool + count *int64 + start *sync.Once + stop *sync.Once + birth *atomic.Pointer[time.Time] + duration *atomic.Pointer[time.Duration] + slow *atomic.Bool } +// SpeedLimit is used to limit the rate at which data is written to the underlying writer. type SpeedLimit struct { - Bytes int - Per time.Duration + // Burst is the number of bytes that can be written to the underlying writer per Frame. + Burst int + // Frame is the duration of the frame in which Burst can be written to the underlying writer. + Frame time.Duration + // CheckEveryBytes is the number of bytes written before checking if the speed limit has been exceeded. CheckEveryBytes int - Delay time.Duration + // Delay is the duration to delay writing if the speed limit has been exceeded during a Write call. (blocking) + Delay time.Duration } -func NewLimitedSpeedometer(w io.Writer, speedLimit SpeedLimit) *Speedometer { +func regulateSpeedLimit(speedLimit *SpeedLimit) (*SpeedLimit, error) { + if speedLimit.Burst <= 0 || speedLimit.Frame <= 0 { + return nil, errors.New("invalid speed limit") + } + if speedLimit.CheckEveryBytes <= 0 { + speedLimit.CheckEveryBytes = speedLimit.Burst + } + if speedLimit.Delay <= 0 { + speedLimit.Delay = 100 * time.Millisecond + } + return speedLimit, nil +} + +func newSpeedometer(w io.Writer, speedLimit *SpeedLimit, cap int64) (*Speedometer, error) { + if w == nil { + return nil, errors.New("writer cannot be nil") + } + var err error + if speedLimit != nil { + if speedLimit, err = regulateSpeedLimit(speedLimit); err != nil { + return nil, err + } + } z := int64(0) speedo := &Speedometer{ w: w, - cap: -1, + cap: cap, speedLimit: speedLimit, + hardLock: &sync.RWMutex{}, internal: atomics{ - count: &z, - closed: new(atomic.Bool), - stop: new(sync.Once), - start: new(sync.Once), + count: &z, + birth: new(atomic.Pointer[time.Time]), + duration: new(atomic.Pointer[time.Duration]), + closed: new(atomic.Bool), + stop: new(sync.Once), + start: new(sync.Once), + slow: new(atomic.Bool), }, } + speedo.internal.birth.Store(&time.Time{}) speedo.internal.closed.Store(false) - speedo.BufioWriter = bufio.NewWriterSize(speedo, speedLimit.CheckEveryBytes) - return speedo + return speedo, nil } -func NewCappedSpeedometer(w io.Writer, cap int64) *Speedometer { - z := int64(0) - speedo := &Speedometer{ - w: w, - cap: cap, - internal: atomics{ - count: &z, - closed: new(atomic.Bool), - stop: new(sync.Once), - start: new(sync.Once), - }, - } - speedo.internal.closed.Store(false) - speedo.BufioWriter = bufio.NewWriterSize(speedo, int(cap)/10) - return speedo +// NewSpeedometer creates a new Speedometer that wraps the given io.Writer. +// It will not limit the rate at which data is written to the underlying writer, it only measures it. +func NewSpeedometer(w io.Writer) (*Speedometer, error) { + return newSpeedometer(w, nil, -1) +} + +// NewLimitedSpeedometer creates a new Speedometer that wraps the given io.Writer. +// If the speed limit is exceeded, writes to the underlying writer will be limited. +// See SpeedLimit for more information. +func NewLimitedSpeedometer(w io.Writer, speedLimit *SpeedLimit) (*Speedometer, error) { + return newSpeedometer(w, speedLimit, -1) +} + +// NewCappedSpeedometer creates a new Speedometer that wraps the given io.Writer. +// If len(written) bytes exceeds cap, writes to the underlying writer will be ceased permanently for the Speedometer. +func NewCappedSpeedometer(w io.Writer, cap int64) (*Speedometer, error) { + return newSpeedometer(w, nil, cap) } func (s *Speedometer) increment(inc int64) (int, error) { - if s.internal.closed.Load() { + if s.internal.closed.Load() || !s.hardLock.TryRLock() { return 0, io.ErrClosedPipe } var err error @@ -101,43 +126,73 @@ func (s *Speedometer) increment(inc int64) (int, error) { return int(inc), err } +// Running returns true if the Speedometer is still running. func (s *Speedometer) Running() bool { return !s.internal.closed.Load() } +// Total returns the total number of bytes written to the underlying writer. func (s *Speedometer) Total() int64 { return atomic.LoadInt64(s.internal.count) } -func (s *Speedometer) Reset() { - s.internal.count = new(int64) - s.internal.closed = new(atomic.Bool) - s.internal.start = new(sync.Once) - s.internal.stop = new(sync.Once) - s.BufioWriter = bufio.NewWriter(s) - s.internal.closed.Store(false) -} - +// Close stops the Speedometer. No additional writes will be accepted. func (s *Speedometer) Close() error { + s.hardLock.TryLock() + if s.internal.closed.Load() { + return io.ErrClosedPipe + } s.internal.stop.Do(func() { s.internal.closed.Store(true) stopped := time.Now() - s.duration = stopped.Sub(*s.birth) + birth := s.internal.birth.Load() + duration := stopped.Sub(*birth) + s.internal.duration.Store(&duration) }) return nil } +/*func (s *Speedometer) IsSlow() bool { + return s.internal.slow.Load() +}*/ + +// Rate returns the rate at which data is being written to the underlying writer per second. func (s *Speedometer) Rate() float64 { if s.internal.closed.Load() { - return float64(s.Total()) / s.duration.Seconds() + return float64(s.Total()) / s.internal.duration.Load().Seconds() } - return float64(s.Total()) / time.Since(*s.birth).Seconds() + return float64(s.Total()) / time.Since(*s.internal.birth.Load()).Seconds() } +func (s *Speedometer) slowDown() error { + switch { + case s.speedLimit == nil: + return nil + case s.speedLimit.Burst <= 0 || s.speedLimit.Frame <= 0, + s.speedLimit.CheckEveryBytes <= 0, s.speedLimit.Delay <= 0: + return errors.New("invalid speed limit") + default: + // + } + if s.Total()%int64(s.speedLimit.CheckEveryBytes) != 0 { + return nil + } + s.internal.slow.Store(true) + for s.Rate() > float64(s.speedLimit.Burst)/s.speedLimit.Frame.Seconds() { + time.Sleep(s.speedLimit.Delay) + } + s.internal.slow.Store(false) + return nil +} + +// Write writes p to the underlying writer, following all defined speed limits. func (s *Speedometer) Write(p []byte) (n int, err error) { + if !s.hardLock.TryRLock() { + return 0, io.ErrClosedPipe + } s.internal.start.Do(func() { - tn := time.Now() - s.birth = &tn + now := time.Now() + s.internal.birth.Store(&now) }) accepted, err := s.increment(int64(len(p))) if err != nil { @@ -145,14 +200,10 @@ func (s *Speedometer) Write(p []byte) (n int, err error) { if innerErr != nil { err = innerErr } - if wn < accepted { - err = io.ErrShortWrite - } return wn, err } + if err = s.slowDown(); err != nil { + return 0, err + } return s.w.Write(p) } - -/*func BenchmarkHeffalump_WriteHell(b *testing.B) { - -}*/ diff --git a/internal/util/speedometer_test.go b/internal/util/speedometer_test.go index d69115a..0a2959e 100644 --- a/internal/util/speedometer_test.go +++ b/internal/util/speedometer_test.go @@ -1,16 +1,27 @@ package util import ( - "bufio" "errors" "io" + "sync" + "sync/atomic" "testing" "time" ) -func writeStuff(target *bufio.Writer, count int) error { +type testWriter struct { + t *testing.T + total int64 +} + +func (w *testWriter) Write(p []byte) (n int, err error) { + atomic.AddInt64(&w.total, int64(len(p))) + return len(p), nil +} + +func writeStuff(target io.Writer, count int) error { write := func() error { - _, err := target.WriteString("a") + _, err := target.Write([]byte("a")) if err != nil { return err } @@ -18,11 +29,11 @@ func writeStuff(target *bufio.Writer, count int) error { } if count < 0 { - for { - if err := write(); err != nil { - return err - } + var err error + for err = write(); err == nil; err = write() { + time.Sleep(5 * time.Millisecond) } + return err } for i := 0; i < count; i++ { if err := write(); err != nil { @@ -55,47 +66,183 @@ func Test_Speedometer(t *testing.T) { } } - sp := NewSpeedometer(io.Discard) - errChan := make(chan error) - go func() { - errChan <- writeStuff(sp.BufioWriter, -1) - }() - time.Sleep(1 * time.Second) - _ = sp.Close() - err := <-errChan - cnt, err := sp.Write([]byte("a")) - isIt(results{err: io.ErrClosedPipe, written: 0}, results{err: err, written: cnt}) - sp.Reset() - cnt, err = sp.Write([]byte("a")) - isIt(results{err: nil, written: 1, total: 1}, results{err: err, written: cnt, total: sp.Total()}) - cnt, err = sp.Write([]byte("aa")) - isIt(results{err: nil, written: 2, total: 3}, results{err: err, written: cnt, total: sp.Total()}) - cnt, err = sp.Write([]byte("a")) - isIt(results{err: nil, written: 1, total: 4}, results{err: err, written: cnt, total: sp.Total()}) - cnt, err = sp.Write([]byte("a")) - isIt(results{err: nil, written: 1, total: 5}, results{err: err, written: cnt, total: sp.Total()}) - go func() { - errChan <- writeStuff(sp.BufioWriter, -1) - }() - var count = 0 - for sp.Running() { - select { - case err = <-errChan: - if !errors.Is(err, io.ErrClosedPipe) { - t.Errorf("unexpected error: %v", err) - } else { - if count < 5 { - t.Errorf("too few iterations: %d", count) - } - t.Logf("final rate: %v per second", sp.Rate()) - } - default: - if count > 5 { - _ = sp.Close() - } - t.Logf("rate: %v per second", sp.Rate()) - time.Sleep(1 * time.Second) - count++ + var ( + errChan = make(chan error, 10) + err error + cnt int + ) + + t.Run("EarlyClose", func(t *testing.T) { + sp, nerr := NewSpeedometer(&testWriter{t: t}) + if nerr != nil { + t.Errorf("unexpected error: %v", nerr) } - } + go func() { + errChan <- writeStuff(sp, -1) + }() + time.Sleep(1 * time.Second) + if closeErr := sp.Close(); closeErr != nil { + t.Errorf("wantErr: want %v, have %v", nil, closeErr) + } + err = <-errChan + if !errors.Is(err, io.ErrClosedPipe) { + t.Errorf("wantErr: want %v, have %v", io.ErrClosedPipe, err) + } + cnt, err = sp.Write([]byte("a")) + isIt(results{err: io.ErrClosedPipe, written: 0}, results{err: err, written: cnt}) + }) + + t.Run("Basic", func(t *testing.T) { + sp, nerr := NewSpeedometer(&testWriter{t: t}) + if nerr != nil { + t.Errorf("unexpected error: %v", nerr) + } + cnt, err = sp.Write([]byte("a")) + isIt(results{err: nil, written: 1, total: 1}, results{err: err, written: cnt, total: sp.Total()}) + cnt, err = sp.Write([]byte("aa")) + isIt(results{err: nil, written: 2, total: 3}, results{err: err, written: cnt, total: sp.Total()}) + cnt, err = sp.Write([]byte("a")) + isIt(results{err: nil, written: 1, total: 4}, results{err: err, written: cnt, total: sp.Total()}) + cnt, err = sp.Write([]byte("a")) + isIt(results{err: nil, written: 1, total: 5}, results{err: err, written: cnt, total: sp.Total()}) + }) + + t.Run("ConcurrentWrites", func(t *testing.T) { + count := int64(0) + sp, nerr := NewSpeedometer(&testWriter{t: t}) + if nerr != nil { + t.Errorf("unexpected error: %v", nerr) + } + wg := &sync.WaitGroup{} + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + var counted int + var err error + counted, err = sp.Write([]byte("a")) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + atomic.AddInt64(&count, int64(counted)) + wg.Done() + }() + } + wg.Wait() + isIt(results{err: nil, written: 1, total: 100}, results{err: err, written: cnt, total: sp.Total()}) + }) + + t.Run("GottaGoFast", func(t *testing.T) { + sp, nerr := NewSpeedometer(&testWriter{t: t}) + if nerr != nil { + t.Errorf("unexpected error: %v", nerr) + } + go func() { + errChan <- writeStuff(sp, -1) + }() + var count = 0 + for sp.Running() { + select { + case err = <-errChan: + if !errors.Is(err, io.ErrClosedPipe) { + t.Errorf("unexpected error: %v", err) + } else { + if count < 5 { + t.Errorf("too few iterations: %d", count) + } + t.Logf("final rate: %v per second", sp.Rate()) + } + default: + if count > 5 { + _ = sp.Close() + } + time.Sleep(100 * time.Millisecond) + t.Logf("rate: %v per second", sp.Rate()) + count++ + } + } + }) + + // test limiter with speedlimit + t.Run("CantGoFast", func(t *testing.T) { + t.Run("10BytesASecond", func(t *testing.T) { + sp, nerr := NewLimitedSpeedometer(&testWriter{t: t}, &SpeedLimit{ + Burst: 10, + Frame: time.Second, + CheckEveryBytes: 1, + Delay: 100 * time.Millisecond, + }) + if nerr != nil { + t.Errorf("unexpected error: %v", nerr) + } + for i := 0; i < 15; i++ { + if _, err = sp.Write([]byte("a")); err != nil { + t.Errorf("unexpected error: %v", err) + } + /*if sp.IsSlow() { + t.Errorf("unexpected slow state") + }*/ + t.Logf("rate: %v per second", sp.Rate()) + if sp.Rate() > 10 { + t.Errorf("speeding in a school zone (expected under %d): %v", sp.speedLimit.Burst, sp.Rate()) + } + } + }) + + t.Run("1000BytesPer5SecondsMeasuredEvery5000Bytes", func(t *testing.T) { + sp, nerr := NewLimitedSpeedometer(&testWriter{t: t}, &SpeedLimit{ + Burst: 1000, + Frame: 2 * time.Second, + CheckEveryBytes: 5000, + Delay: 500 * time.Millisecond, + }) + + if nerr != nil { + t.Errorf("unexpected error: %v", nerr) + } + + for i := 0; i < 4999; i++ { + if _, err = sp.Write([]byte("a")); err != nil { + t.Errorf("unexpected error: %v", err) + } + if i%1000 == 0 { + t.Logf("rate: %v per second", sp.Rate()) + } + if sp.Rate() < 1000 { + t.Errorf("shouldn't have slowed down yet (expected over %d): %v", sp.speedLimit.Burst, sp.Rate()) + } + } + if _, err = sp.Write([]byte("a")); err != nil { + t.Errorf("unexpected error: %v", err) + } + for i := 0; i < 10; i++ { + if _, err = sp.Write([]byte("a")); err != nil { + t.Errorf("unexpected error: %v", err) + } + t.Logf("rate: %v per second", sp.Rate()) + if sp.Rate() > 1000 { + t.Errorf("speeding in a school zone (expected under %d): %v", sp.speedLimit.Burst, sp.Rate()) + } + } + }) + }) + + // test capped speedometer + t.Run("OnlyALittle", func(t *testing.T) { + sp, nerr := NewCappedSpeedometer(&testWriter{t: t}, 1024) + if nerr != nil { + t.Errorf("unexpected error: %v", nerr) + } + for i := 0; i < 1024; i++ { + if _, err = sp.Write([]byte("a")); err != nil { + t.Errorf("unexpected error: %v", err) + } + if sp.Total() > 1024 { + t.Errorf("shouldn't have written more than 1024 bytes") + } + } + if _, err = sp.Write([]byte("a")); err == nil { + t.Errorf("expected error when writing over capacity") + } + }) + }