diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 2e25d5b..72dbc70 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -22,7 +22,7 @@ jobs: export PATH=$PATH:$(go env GOPATH)/bin go install github.com/securego/gosec/v2/cmd/gosec@latest gosec ./... - - name: go test -v ./... - run: go test -v ./... + - name: go test -race -v ./... + run: go test -race -v ./... - name: go build -v ./... run: go build -v ./... diff --git a/cmd/HellPot/HellPot.go b/cmd/HellPot/HellPot.go index 597d0d9..b2b3055 100644 --- a/cmd/HellPot/HellPot.go +++ b/cmd/HellPot/HellPot.go @@ -50,7 +50,7 @@ func main() { signal.Notify(stopChan, syscall.SIGINT, syscall.SIGTERM) go func() { - log.Error().Err(http.Serve()).Msg("HTTP error") + log.Fatal().Err(http.Serve()).Msg("HTTP error") }() <-stopChan // wait for SIGINT diff --git a/internal/config/logger.go b/internal/config/logger.go index 19b7cb0..a499ae3 100644 --- a/internal/config/logger.go +++ b/internal/config/logger.go @@ -18,9 +18,7 @@ var ( logger zerolog.Logger ) -// StartLogger instantiates an instance of our zerolog loggger so we can hook it in our main package. -// While this does return a logger, it should not be used for additional retrievals of the logger. Use GetLogger() -func StartLogger(pretty bool, targets ...io.Writer) zerolog.Logger { +func prepLogDir() { logDir = snek.GetString("logger.directory") if !strings.HasSuffix(logDir, "/") { logDir += "/" @@ -29,7 +27,11 @@ func StartLogger(pretty bool, targets ...io.Writer) zerolog.Logger { println("cannot create log directory: " + logDir + "(" + err.Error() + ")") os.Exit(1) } +} +// StartLogger instantiates an instance of our zerolog loggger so we can hook it in our main package. +// While this does return a logger, it should not be used for additional retrievals of the logger. Use GetLogger(). +func StartLogger(pretty bool, targets ...io.Writer) zerolog.Logger { logFileName := "HellPot" if snek.GetBool("logger.use_date_filename") { @@ -44,9 +46,11 @@ func StartLogger(pretty bool, targets ...io.Writer) zerolog.Logger { case len(targets) > 0: logFile = io.MultiWriter(targets...) default: + prepLogDir() CurrentLogFile = path.Join(logDir, logFileName+".log") - /* #nosec */ - if logFile, err = os.OpenFile(CurrentLogFile, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0o666); err != nil { + //nolint:lll + logFile, err = os.OpenFile(CurrentLogFile, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0o666) // #nosec G304 G302 -- we are not using user input to create the file + if err != nil { println("cannot create log file: " + err.Error()) os.Exit(1) } @@ -62,7 +66,7 @@ func StartLogger(pretty bool, targets ...io.Writer) zerolog.Logger { return logger } -// GetLogger retrieves our global logger object +// GetLogger retrieves our global logger object. func GetLogger() *zerolog.Logger { // future logic here return &logger diff --git a/internal/util/speedometer.go b/internal/util/speedometer.go new file mode 100644 index 0000000..e3cbbae --- /dev/null +++ b/internal/util/speedometer.go @@ -0,0 +1,237 @@ +package util + +import ( + "errors" + "fmt" + "io" + "sync" + "sync/atomic" + "time" +) + +var ErrLimitReached = errors.New("limit reached") + +// Speedometer is an io.Writer wrapper that will limit the rate at which data is written to the underlying target. +// +// It is safe for concurrent use, but writers will block when slowed down. +// +// Optionally, it can be given; +// +// - a capacity, which will cause it to return an error if the capacity is exceeded. +// +// - a speed limit, causing slow downs of data written to the underlying writer if the speed limit is exceeded. +type Speedometer struct { + ceiling int64 + speedLimit *SpeedLimit + internal atomics + w io.Writer +} + +type atomics struct { + count *atomic.Int64 + closed *atomic.Bool + start *sync.Once + stop *sync.Once + birth *atomic.Pointer[time.Time] + duration *atomic.Pointer[time.Duration] + slow *atomic.Bool +} + +func newAtomics() atomics { + manhattan := atomics{ + count: new(atomic.Int64), + closed: new(atomic.Bool), + start: new(sync.Once), + stop: new(sync.Once), + birth: new(atomic.Pointer[time.Time]), + duration: new(atomic.Pointer[time.Duration]), + slow: new(atomic.Bool), + } + manhattan.birth.Store(&time.Time{}) + manhattan.closed.Store(false) + manhattan.count.Store(0) + return manhattan +} + +// SpeedLimit is used to limit the rate at which data is written to the underlying writer. +type SpeedLimit struct { + // Burst is the number of bytes that can be written to the underlying writer per Frame. + Burst int + // Frame is the duration of the frame in which Burst can be written to the underlying writer. + Frame time.Duration + // CheckEveryBytes is the number of bytes written before checking if the speed limit has been exceeded. + CheckEveryBytes int + // Delay is the duration to delay writing if the speed limit has been exceeded during a Write call. (blocking) + Delay time.Duration +} + +const fallbackDelay = 100 + +func regulateSpeedLimit(speedLimit *SpeedLimit) (*SpeedLimit, error) { + if speedLimit.Burst <= 0 || speedLimit.Frame <= 0 { + return nil, errors.New("invalid speed limit") + } + if speedLimit.CheckEveryBytes <= 0 { + speedLimit.CheckEveryBytes = speedLimit.Burst + } + if speedLimit.Delay <= 0 { + speedLimit.Delay = fallbackDelay * time.Millisecond + } + return speedLimit, nil +} + +func newSpeedometer(w io.Writer, speedLimit *SpeedLimit, ceiling int64) (*Speedometer, error) { + if w == nil { + return nil, errors.New("writer cannot be nil") + } + var err error + if speedLimit != nil { + if speedLimit, err = regulateSpeedLimit(speedLimit); err != nil { + return nil, err + } + } + + return &Speedometer{ + w: w, + ceiling: ceiling, + speedLimit: speedLimit, + internal: newAtomics(), + }, nil +} + +// NewSpeedometer creates a new Speedometer that wraps the given io.Writer. +// It will not limit the rate at which data is written to the underlying writer, it only measures it. +func NewSpeedometer(w io.Writer) (*Speedometer, error) { + return newSpeedometer(w, nil, -1) +} + +// NewLimitedSpeedometer creates a new Speedometer that wraps the given io.Writer. +// If the speed limit is exceeded, writes to the underlying writer will be limited. +// See SpeedLimit for more information. +func NewLimitedSpeedometer(w io.Writer, speedLimit *SpeedLimit) (*Speedometer, error) { + return newSpeedometer(w, speedLimit, -1) +} + +// NewCappedSpeedometer creates a new Speedometer that wraps the given io.Writer. +// If len(written) bytes exceeds cap, writes to the underlying writer will be ceased permanently for the Speedometer. +func NewCappedSpeedometer(w io.Writer, capacity int64) (*Speedometer, error) { + return newSpeedometer(w, nil, capacity) +} + +// NewCappedLimitedSpeedometer creates a new Speedometer that wraps the given io.Writer. +// It is a combination of NewLimitedSpeedometer and NewCappedSpeedometer. +func NewCappedLimitedSpeedometer(w io.Writer, speedLimit *SpeedLimit, capacity int64) (*Speedometer, error) { + return newSpeedometer(w, speedLimit, capacity) +} + +func (s *Speedometer) increment(inc int64) (int, error) { + if s.internal.closed.Load() { + return 0, io.ErrClosedPipe + } + var err error + if s.ceiling > 0 && s.Total()+inc > s.ceiling { + _ = s.Close() + err = ErrLimitReached + inc = s.ceiling - s.Total() + } + s.internal.count.Add(inc) + return int(inc), err +} + +// Running returns true if the Speedometer is still running. +func (s *Speedometer) Running() bool { + return !s.internal.closed.Load() +} + +// Total returns the total number of bytes written to the underlying writer. +func (s *Speedometer) Total() int64 { + return s.internal.count.Load() +} + +// Close stops the Speedometer. No additional writes will be accepted. +func (s *Speedometer) Close() error { + if s.internal.closed.Load() { + return io.ErrClosedPipe + } + s.internal.stop.Do(func() { + s.internal.closed.Store(true) + stopped := time.Now() + birth := s.internal.birth.Load() + duration := stopped.Sub(*birth) + s.internal.duration.Store(&duration) + }) + return nil +} + +/*func (s *Speedometer) IsSlow() bool { + return s.internal.slow.Load() +}*/ + +// Rate returns the rate at which data is being written to the underlying writer per second. +func (s *Speedometer) Rate() float64 { + if s.internal.closed.Load() { + return float64(s.Total()) / s.internal.duration.Load().Seconds() + } + return float64(s.Total()) / time.Since(*s.internal.birth.Load()).Seconds() +} + +func (s *Speedometer) slowDown() error { + switch { + case s.speedLimit == nil: + return nil + case s.speedLimit.Burst <= 0 || s.speedLimit.Frame <= 0, + s.speedLimit.CheckEveryBytes <= 0, s.speedLimit.Delay <= 0: + return errors.New("invalid speed limit") + default: + // + } + if s.Total()%int64(s.speedLimit.CheckEveryBytes) != 0 { + return nil + } + s.internal.slow.Store(true) + for s.Rate() > float64(s.speedLimit.Burst)/s.speedLimit.Frame.Seconds() { + time.Sleep(s.speedLimit.Delay) + } + s.internal.slow.Store(false) + return nil +} + +// Write writes p to the underlying writer, following all defined speed limits. +func (s *Speedometer) Write(p []byte) (n int, err error) { + if s.internal.closed.Load() { + return 0, io.ErrClosedPipe + } + s.internal.start.Do(func() { + now := time.Now() + s.internal.birth.Store(&now) + }) + + // if no speed limit, just write and record + if s.speedLimit == nil { + n, err = s.w.Write(p) + if err != nil { + return n, fmt.Errorf("error writing to underlying writer: %w", err) + } + return s.increment(int64(len(p))) + } + + var ( + wErr error + accepted int + ) + accepted, wErr = s.increment(int64(len(p))) + + if wErr != nil { + return 0, fmt.Errorf("error incrementing: %w", wErr) + } + + if sErr := s.slowDown(); sErr != nil { + return 0, fmt.Errorf("error slowing down: %w", sErr) + } + + var iErr error + if n, iErr = s.w.Write(p[:accepted]); iErr != nil { + return n, fmt.Errorf("error writing to underlying writer: %w", iErr) + } + return +} diff --git a/internal/util/speedometer_test.go b/internal/util/speedometer_test.go new file mode 100644 index 0000000..20afaba --- /dev/null +++ b/internal/util/speedometer_test.go @@ -0,0 +1,393 @@ +package util + +import ( + "bytes" + "errors" + "fmt" + "io" + "net" + "sync" + "sync/atomic" + "testing" + "time" +) + +type testWriter struct { + t *testing.T + total int64 +} + +func (w *testWriter) Write(p []byte) (n int, err error) { + atomic.AddInt64(&w.total, int64(len(p))) + return len(p), nil +} + +func writeStuff(t *testing.T, target io.Writer, count int) error { + t.Helper() + write := func() error { + _, err := target.Write([]byte("a")) + if err != nil { + return fmt.Errorf("error writing: %w", err) + } + return nil + } + + if count < 0 { + var err error + for err = write(); err == nil; err = write() { + time.Sleep(5 * time.Millisecond) + } + return err + } + for i := 0; i < count; i++ { + if err := write(); err != nil { + return err + } + } + return nil +} + +//nolint:funlen +func Test_Speedometer(t *testing.T) { + type results struct { + total int64 + written int + rate float64 + err error + } + + isIt := func(want, have results) { + t.Helper() + if have.total != want.total { + t.Errorf("total: want %d, have %d", want.total, have.total) + } + if have.written != want.written { + t.Errorf("written: want %d, have %d", want.written, have.written) + } + if have.rate != want.rate { + t.Errorf("rate: want %f, have %f", want.rate, have.rate) + } + if !errors.Is(have.err, want.err) { + t.Errorf("wantErr: want %v, have %v", want.err, have.err) + } + } + + var ( + errChan = make(chan error, 10) + ) + + t.Run("EarlyClose", func(t *testing.T) { + var ( + err error + cnt int + ) + t.Parallel() + sp, nerr := NewSpeedometer(&testWriter{t: t}) + if nerr != nil { + t.Errorf("unexpected error: %v", nerr) + } + go func() { + errChan <- writeStuff(t, sp, -1) + }() + time.Sleep(1 * time.Second) + if closeErr := sp.Close(); closeErr != nil { + t.Errorf("wantErr: want %v, have %v", nil, closeErr) + } + err = <-errChan + if !errors.Is(err, io.ErrClosedPipe) { + t.Errorf("wantErr: want %v, have %v", io.ErrClosedPipe, err) + } + cnt, err = sp.Write([]byte("a")) + isIt(results{err: io.ErrClosedPipe, written: 0}, results{err: err, written: cnt}) + }) + + t.Run("Basic", func(t *testing.T) { + var ( + err error + cnt int + ) + t.Parallel() + sp, nerr := NewSpeedometer(&testWriter{t: t}) + if nerr != nil { + t.Errorf("unexpected error: %v", nerr) + } + cnt, err = sp.Write([]byte("a")) + isIt(results{err: nil, written: 1, total: 1}, results{err: err, written: cnt, total: sp.Total()}) + cnt, err = sp.Write([]byte("aa")) + isIt(results{err: nil, written: 2, total: 3}, results{err: err, written: cnt, total: sp.Total()}) + cnt, err = sp.Write([]byte("a")) + isIt(results{err: nil, written: 1, total: 4}, results{err: err, written: cnt, total: sp.Total()}) + cnt, err = sp.Write([]byte("a")) + isIt(results{err: nil, written: 1, total: 5}, results{err: err, written: cnt, total: sp.Total()}) + }) + + t.Run("ConcurrentWrites", func(t *testing.T) { + var ( + err error + ) + + count := int64(0) + sp, nerr := NewSpeedometer(&testWriter{t: t}) + if nerr != nil { + t.Errorf("unexpected error: %v", nerr) + } + wg := &sync.WaitGroup{} + wg.Add(100) + for i := 0; i < 100; i++ { + go func() { + var counted int + var gerr error + counted, gerr = sp.Write([]byte("a")) + if gerr != nil { + t.Errorf("unexpected error: %v", err) + } + atomic.AddInt64(&count, int64(counted)) + wg.Done() + }() + } + wg.Wait() + isIt(results{err: nil, written: 100, total: 100}, + results{err: err, written: int(atomic.LoadInt64(&count)), total: sp.Total()}) + }) + + t.Run("GottaGoFast", func(t *testing.T) { + t.Parallel() + var ( + err error + ) + sp, nerr := NewSpeedometer(&testWriter{t: t}) + if nerr != nil { + t.Errorf("unexpected error: %v", nerr) + } + go func() { + errChan <- writeStuff(t, sp, -1) + }() + var count = 0 + for sp.Running() { + select { + case err = <-errChan: + if !errors.Is(err, io.ErrClosedPipe) { + t.Errorf("unexpected error: %v", err) + } else { + if count < 5 { + t.Errorf("too few iterations: %d", count) + } + t.Logf("final rate: %v per second", sp.Rate()) + } + default: + if count > 5 { + _ = sp.Close() + } + time.Sleep(100 * time.Millisecond) + t.Logf("rate: %v per second", sp.Rate()) + count++ + } + } + }) + + // test limiter with speedlimit + t.Run("CantGoFast", func(t *testing.T) { + t.Parallel() + t.Run("10BytesASecond", func(t *testing.T) { + t.Parallel() + var ( + err error + ) + sp, nerr := NewLimitedSpeedometer(&testWriter{t: t}, &SpeedLimit{ + Burst: 10, + Frame: time.Second, + CheckEveryBytes: 1, + Delay: 100 * time.Millisecond, + }) + if nerr != nil { + t.Errorf("unexpected error: %v", nerr) + } + for i := 0; i < 15; i++ { + if _, err = sp.Write([]byte("a")); err != nil { + t.Errorf("unexpected error: %v", err) + } + /*if sp.IsSlow() { + t.Errorf("unexpected slow state") + }*/ + t.Logf("rate: %v per second", sp.Rate()) + if sp.Rate() > 10 { + t.Errorf("speeding in a school zone (expected under %d): %v", sp.speedLimit.Burst, sp.Rate()) + } + } + }) + + t.Run("1000BytesPer5SecondsMeasuredEvery5000Bytes", func(t *testing.T) { + t.Parallel() + var ( + err error + ) + sp, nerr := NewLimitedSpeedometer(&testWriter{t: t}, &SpeedLimit{ + Burst: 1000, + Frame: 2 * time.Second, + CheckEveryBytes: 5000, + Delay: 500 * time.Millisecond, + }) + + if nerr != nil { + t.Errorf("unexpected error: %v", nerr) + } + + for i := 0; i < 4999; i++ { + if _, err = sp.Write([]byte("a")); err != nil { + t.Errorf("unexpected error: %v", err) + } + if i%1000 == 0 { + t.Logf("rate: %v per second", sp.Rate()) + } + if sp.Rate() < 1000 { + t.Errorf("shouldn't have slowed down yet (expected over %d): %v", sp.speedLimit.Burst, sp.Rate()) + } + } + if _, err = sp.Write([]byte("a")); err != nil { + t.Errorf("unexpected error: %v", err) + } + for i := 0; i < 10; i++ { + if _, err = sp.Write([]byte("a")); err != nil { + t.Errorf("unexpected error: %v", err) + } + t.Logf("rate: %v per second", sp.Rate()) + if sp.Rate() > 1000 { + t.Errorf("speeding in a school zone (expected under %d): %v", sp.speedLimit.Burst, sp.Rate()) + } + } + }) + }) + + // test capped speedometer + t.Run("OnlyALittle", func(t *testing.T) { + t.Parallel() + var ( + err error + ) + sp, nerr := NewCappedSpeedometer(&testWriter{t: t}, 1024) + if nerr != nil { + t.Errorf("unexpected error: %v", nerr) + } + for i := 0; i < 1024; i++ { + if _, err = sp.Write([]byte("a")); err != nil { + t.Errorf("unexpected error: %v", err) + } + if sp.Total() > 1024 { + t.Errorf("shouldn't have written more than 1024 bytes") + } + } + if _, err = sp.Write([]byte("a")); err == nil { + t.Errorf("expected error when writing over capacity") + } + }) + + t.Run("SynSynAckAck", func(t *testing.T) { + t.Parallel() + var ( + server net.Listener + err error + ) + //goland:noinspection GoCommentLeadingSpace + if server, err = net.Listen("tcp", ":8080"); err != nil { // #nosec:G102 - this is a unit test. + t.Fatalf("Failed to start server: %v", err) + } + defer func(server net.Listener) { + if cErr := server.Close(); cErr != nil { + t.Errorf("Failed to close server: %v", err) + } + }(server) + + go func() { + var ( + conn net.Conn + aErr error + ) + if conn, aErr = server.Accept(); aErr != nil { + t.Errorf("Failed to accept connection: %v", err) + } + + t.Logf("Accepted connection from %s", conn.RemoteAddr().String()) + + defer func(conn net.Conn) { + if cErr := conn.Close(); cErr != nil { + t.Errorf("Failed to close connection: %v", err) + } + }(conn) + + speedLimit := &SpeedLimit{ + Burst: 512, + Frame: time.Second, + CheckEveryBytes: 1, + Delay: 10 * time.Millisecond, + } + + var ( + speedometer *Speedometer + sErr error + ) + if speedometer, sErr = NewCappedLimitedSpeedometer(conn, speedLimit, 4096); sErr != nil { + t.Errorf("Failed to create speedometer: %v", sErr) + } + + buf := make([]byte, 1024) + for i := range buf { + targ := byte('E') + if i%2 == 0 { + targ = byte('e') + } + buf[i] = targ + } + for { + n, wErr := speedometer.Write(buf) + switch { + case errors.Is(wErr, io.EOF), errors.Is(wErr, ErrLimitReached): + return + case wErr != nil: + t.Errorf("Failed to write: %v", wErr) + case n != len(buf): + t.Errorf("Failed to write all bytes: %d", n) + default: + t.Logf("Wrote %d bytes", n) + } + } + }() + + var ( + client net.Conn + aErr error + ) + + if client, aErr = net.Dial("tcp", "localhost:8080"); aErr != nil { + t.Fatalf("Failed to connect to server: %v", err) + } + + defer func(client net.Conn) { + if clErr := client.Close(); clErr != nil { + t.Errorf("Failed to close client: %v", err) + } + }(client) + + buf := &bytes.Buffer{} + startTime := time.Now() + n, cpErr := io.Copy(buf, client) + if cpErr != nil { + t.Errorf("Failed to copy: %v", cpErr) + } + + duration := time.Since(startTime) + if buf.Len() == 0 || n == 0 { + t.Fatalf("No data received") + } + + rate := measureRate(t, n, duration) + + if rate > 512.0 { + t.Fatalf("Rate exceeded: got %f, expected <= 100.0", rate) + } + }) +} + +func measureRate(t *testing.T, received int64, duration time.Duration) float64 { + t.Helper() + return float64(received) / duration.Seconds() +}