mirror of https://github.com/yunginnanet/HellPot
Begin limiting writer implementation
This commit is contained in:
parent
ecff40185f
commit
7433634ad3
|
@ -7,9 +7,9 @@ import (
|
|||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/yunginnanet/HellPot/config"
|
||||
"github.com/yunginnanet/HellPot/extra"
|
||||
"github.com/yunginnanet/HellPot/http"
|
||||
"github.com/yunginnanet/HellPot/internal/config"
|
||||
"github.com/yunginnanet/HellPot/internal/extra"
|
||||
"github.com/yunginnanet/HellPot/internal/http"
|
||||
)
|
||||
|
||||
var log zerolog.Logger
|
||||
|
@ -17,7 +17,7 @@ var log zerolog.Logger
|
|||
func init() {
|
||||
config.Init()
|
||||
if config.BannerOnly {
|
||||
extra.PrintBanner()
|
||||
extra.Banner()
|
||||
os.Exit(0)
|
||||
}
|
||||
log = config.StartLogger()
|
||||
|
|
|
@ -1,27 +0,0 @@
|
|||
package extra
|
||||
|
||||
import (
|
||||
"os"
|
||||
"runtime"
|
||||
|
||||
"github.com/yunginnanet/HellPot/config"
|
||||
)
|
||||
|
||||
func bannerFail(errs ...error) {
|
||||
println("failed printing banner, consider using --nocolor")
|
||||
for _, err := range errs {
|
||||
if err != nil {
|
||||
println(err.Error())
|
||||
}
|
||||
}
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Banner prints out our banner (using spooky magic)
|
||||
func Banner() {
|
||||
if runtime.GOOS == "windows" || config.NoColor {
|
||||
println(config.Title + " " + config.Version)
|
||||
return
|
||||
}
|
||||
PrintBanner()
|
||||
}
|
|
@ -9,46 +9,34 @@ import (
|
|||
"io"
|
||||
"sync"
|
||||
|
||||
"github.com/yunginnanet/HellPot/config"
|
||||
"github.com/yunginnanet/HellPot/internal/config"
|
||||
)
|
||||
|
||||
var log = config.GetLogger()
|
||||
|
||||
// DefaultHeffalump represents a Heffalump type
|
||||
var DefaultHeffalump = NewHeffalump(DefaultMarkovMap, 100*1<<10)
|
||||
var DefaultHeffalump *Heffalump
|
||||
|
||||
// Heffalump represents our buffer pool and markov map from Heffalump
|
||||
// https://github.com/carlmjohnson/heffalump
|
||||
type Heffalump struct {
|
||||
pool sync.Pool
|
||||
pool *sync.Pool
|
||||
buffsize int
|
||||
mm MarkovMap
|
||||
}
|
||||
|
||||
// NewHeffalump instantiates a new Heffalump for markov generation and buffer/io operations
|
||||
// https://github.com/carlmjohnson/heffalump
|
||||
func NewHeffalump(mm MarkovMap, buffsize int) *Heffalump {
|
||||
return &Heffalump{
|
||||
pool: sync.Pool{},
|
||||
pool: &sync.Pool{New: func() interface{} {
|
||||
b := make([]byte, buffsize)
|
||||
return b
|
||||
}},
|
||||
buffsize: buffsize,
|
||||
mm: mm,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Heffalump) getBuffer() []byte {
|
||||
x := h.pool.Get()
|
||||
if buf, ok := x.([]byte); ok {
|
||||
return buf
|
||||
}
|
||||
return make([]byte, h.buffsize)
|
||||
}
|
||||
|
||||
func (h *Heffalump) putBuffer(buf []byte) {
|
||||
h.pool.Put(buf)
|
||||
}
|
||||
|
||||
// WriteHell writes markov chain heffalump hell to the provided io.Writer
|
||||
// https://github.com/carlmjohnson/heffalump
|
||||
func (h *Heffalump) WriteHell(bw *bufio.Writer) (int64, error) {
|
||||
var n int64
|
||||
var err error
|
||||
|
@ -59,13 +47,12 @@ func (h *Heffalump) WriteHell(bw *bufio.Writer) (int64, error) {
|
|||
}
|
||||
}()
|
||||
|
||||
buf := h.getBuffer()
|
||||
defer h.putBuffer(buf)
|
||||
buf := h.pool.Get().([]byte)
|
||||
defer h.pool.Put(buf)
|
||||
|
||||
if _, err = bw.WriteString("<HTML>\n<BODY>\n"); err != nil {
|
||||
if _, err = bw.WriteString("<html>\n<body>\n"); err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
if n, err = io.CopyBuffer(bw, h.mm, buf); err != nil {
|
||||
return n, nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
package heffalump
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/yunginnanet/HellPot/internal/util"
|
||||
)
|
||||
|
||||
func TestHeffalump_WriteHell(t *testing.T) {
|
||||
sp := util.NewCappedSpeedometer(io.Discard, 55555)
|
||||
var err error
|
||||
var count int64
|
||||
for err == nil {
|
||||
var cnt int64
|
||||
cnt, err = DefaultHeffalump.WriteHell(sp.BufioWriter)
|
||||
t.Logf("written: %d", cnt)
|
||||
count += cnt
|
||||
}
|
||||
if !errors.Is(err, util.ErrLimitReached) {
|
||||
t.Errorf("expected %v, got %v", util.ErrLimitReached, err)
|
||||
} else {
|
||||
t.Logf("err: %v", err)
|
||||
}
|
||||
t.Logf("count: %d", count)
|
||||
t.Logf("rate: %f per second", sp.Rate())
|
||||
}
|
|
@ -5,12 +5,29 @@ import (
|
|||
"io"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
|
||||
"git.tcp.direct/kayos/common/squish"
|
||||
)
|
||||
|
||||
var DefaultMarkovMap MarkovMap
|
||||
|
||||
func init() {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
// DefaultMarkovMap is a Markov chain based on src.
|
||||
src, err := squish.UnpackStr(SrcGz)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if len(src) < 1 {
|
||||
panic("failed to unpack source")
|
||||
}
|
||||
DefaultMarkovMap = MakeMarkovMap(strings.NewReader(src))
|
||||
DefaultHeffalump = NewHeffalump(DefaultMarkovMap, 100*1<<10)
|
||||
}
|
||||
|
||||
// ScanHTML is a basic split function for a Scanner that returns each
|
||||
// space-separated word of text or HTML tag, with surrounding spaces deleted.
|
||||
// It will never return an empty string. The definition of space is set by
|
||||
|
@ -25,14 +42,15 @@ func ScanHTML(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
|||
break
|
||||
}
|
||||
}
|
||||
if r == '<' {
|
||||
switch {
|
||||
case r == '<':
|
||||
// Scan until closing bracket
|
||||
for i := start; i < len(data); i++ {
|
||||
if data[i] == '>' {
|
||||
return i + 1, data[start : i+1], nil
|
||||
}
|
||||
}
|
||||
} else {
|
||||
default:
|
||||
// Scan until space, marking end of word.
|
||||
for width, i := 0, start; i < len(data); i += width {
|
||||
var r rune
|
||||
|
@ -55,14 +73,6 @@ func ScanHTML(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
|||
|
||||
type tokenPair [2]string
|
||||
|
||||
// DefaultMarkovMap is a Markov chain based on src.
|
||||
var DefaultMarkovMap MarkovMap
|
||||
|
||||
func init() {
|
||||
src, _ := squish.UnpackStr(srcGz)
|
||||
DefaultMarkovMap = MakeMarkovMap(strings.NewReader(src))
|
||||
}
|
||||
|
||||
// MarkovMap is a map that acts as a Markov chain generator.
|
||||
type MarkovMap map[tokenPair][]string
|
||||
|
||||
|
@ -80,7 +90,7 @@ func (mm MarkovMap) Fill(r io.Reader) {
|
|||
s := bufio.NewScanner(r)
|
||||
s.Split(ScanHTML)
|
||||
for s.Scan() {
|
||||
w3 := s.Text()
|
||||
w3 = s.Text()
|
||||
mm.Add(w1, w2, w3)
|
||||
w1, w2 = w2, w3
|
||||
}
|
||||
|
@ -101,7 +111,6 @@ func (mm MarkovMap) Get(w1, w2 string) string {
|
|||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
// We don't care about cryptographically sound entropy here, ignore gosec G404.
|
||||
/* #nosec */
|
||||
r := rand.Intn(len(suffix))
|
||||
|
@ -111,7 +120,6 @@ func (mm MarkovMap) Get(w1, w2 string) string {
|
|||
// Read fills p with data from calling Get on the MarkovMap.
|
||||
func (mm MarkovMap) Read(p []byte) (n int, err error) {
|
||||
var w1, w2, w3 string
|
||||
|
||||
for {
|
||||
w3 = mm.Get(w1, w2)
|
||||
if n+len(w3)+1 >= len(p) {
|
||||
|
@ -121,6 +129,5 @@ func (mm MarkovMap) Read(p []byte) (n int, err error) {
|
|||
n += copy(p[n:], "\n")
|
||||
w1, w2 = w2, w3
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
|
File diff suppressed because one or more lines are too long
|
@ -19,7 +19,7 @@ var (
|
|||
customconfig = false
|
||||
home string
|
||||
prefConfigLocation string
|
||||
snek *viper.Viper
|
||||
snek = viper.New()
|
||||
)
|
||||
|
||||
// exported generic vars
|
||||
|
@ -32,30 +32,7 @@ var (
|
|||
Filename string
|
||||
)
|
||||
|
||||
func init() {
|
||||
prefConfigLocation = home + "/.config/" + Title
|
||||
snek = viper.New()
|
||||
}
|
||||
|
||||
func windowsConfig() {
|
||||
newconfig := "hellpot-config"
|
||||
snek.SetConfigName(newconfig)
|
||||
if err := snek.MergeInConfig(); err == nil {
|
||||
return
|
||||
}
|
||||
if err := snek.SafeWriteConfigAs(newconfig + ".toml"); err != nil {
|
||||
fmt.Println(err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func writeConfig() {
|
||||
//goland:noinspection GoBoolExpressions
|
||||
if runtime.GOOS == "windows" {
|
||||
windowsConfig()
|
||||
return
|
||||
}
|
||||
if _, err := os.Stat(prefConfigLocation); os.IsNotExist(err) {
|
||||
if err = os.MkdirAll(prefConfigLocation, 0o750); err != nil {
|
||||
println("error writing new config: " + err.Error())
|
||||
|
@ -64,7 +41,7 @@ func writeConfig() {
|
|||
}
|
||||
Filename = prefConfigLocation + "/" + "config.toml"
|
||||
if err := snek.SafeWriteConfigAs(Filename); err != nil {
|
||||
fmt.Println("Failed to write new configuration file: " + err.Error())
|
||||
fmt.Println("Failed to write new configuration file to '" + Filename + "': " + err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
@ -119,9 +96,8 @@ func loadCustomConfig(path string) {
|
|||
os.Exit(1)
|
||||
}
|
||||
|
||||
Filename, _ = filepath.Abs(path)
|
||||
|
||||
if len(Filename) < 1 {
|
||||
Filename, err = filepath.Abs(path)
|
||||
if len(Filename) < 1 || err != nil {
|
||||
Filename = path
|
||||
}
|
||||
|
|
@ -3,6 +3,7 @@ package config
|
|||
import (
|
||||
"io"
|
||||
"os"
|
||||
"path"
|
||||
"runtime"
|
||||
|
||||
"github.com/spf13/afero"
|
||||
|
@ -13,8 +14,8 @@ func init() {
|
|||
if home, err = os.UserHomeDir(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defOpts["logger"]["directory"] = home + "/.local/share/" + Title + "/logs/"
|
||||
|
||||
defOpts["logger"]["directory"] = path.Join(home, ".local", "share", Title+"logs")
|
||||
prefConfigLocation = path.Join(home, ".config", Title)
|
||||
}
|
||||
|
||||
var (
|
||||
|
@ -81,7 +82,6 @@ func setDefaults() {
|
|||
memfs := afero.NewMemMapFs()
|
||||
//goland:noinspection GoBoolExpressions
|
||||
if runtime.GOOS == "windows" {
|
||||
snek.SetDefault("logger.directory", "./hellpot-logs/")
|
||||
defNoColor = true
|
||||
}
|
||||
for _, def := range configSections {
|
|
@ -2,6 +2,7 @@ package config
|
|||
|
||||
import (
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -36,7 +37,7 @@ func StartLogger() zerolog.Logger {
|
|||
logFileName = logFileName + "_" + tn
|
||||
}
|
||||
|
||||
CurrentLogFile = logDir + logFileName + ".log"
|
||||
CurrentLogFile = path.Join(logDir, logFileName+".log")
|
||||
|
||||
var err error
|
||||
|
|
@ -4,12 +4,14 @@ import (
|
|||
crip "crypto/rand"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.tcp.direct/kayos/common/squish"
|
||||
|
||||
"github.com/yunginnanet/HellPot/config"
|
||||
"github.com/yunginnanet/HellPot/internal/config"
|
||||
)
|
||||
|
||||
const hellpot = "H4sIAAAAAAACA8VXvW7bQAze9QpZOGQUZNXntBD6Ahm7Gx1cx0jdRnKRKAUCdPDgQavOgB/QTxLZ1P3oRJ5Obo0CtnE5feSR30fylOhmfjv9PEtzwIXIj4dds/xw2jsequNB2gizXd3Mxad2O81PX7AAe+UNGneuR8aUOuTsqQUDXAMv1cJE5Tfbn6GaKz45kpid+lQc3zoNY5zmEUEt+jCGNZUjeYr0StZYmbwtwNavuCaUFWA8MjxVIImjNas6TPQT9Tnq4MnYJF0zkhVU4rLvqflscU/ox0Lg45qKTjoSmiLQPA+ZuTT7BbrckpfWKMkUquTErIPEYbPoKjamy6SjR0feGssPUMYTCDWEnrR8c0m7hJ2B4jekK2KUsBfa7bpTD0ftnmKPE9nN2IzcLc99vxhIUbszlwqrJoklpQWlI6AeQh9nDHXj2ldOvyat/vZdDxVfzZdbSuspRUe/+IKZtxq2GWlbZzS6jnrnDEXGCkXUGnahuTgAA+DY9HU8FUoYH3ji/q84HetDWmT/Y3ml6oX21/eCNzB46+6UuVTSQHXgGmzUTJT/zeNQ3zCvysEBuH3hER9CbhNa6FoLHSBfT2gmK/rFKCj/K1nTfcBduKHVwgjo+Y+HilXBEAqhKg1X6lQzMaIF6ZK6ipVILR0Awh16SWy9KsxvZXWbL34oGpNmMcPNdYFmiE40+qV9cg4Logjm2uXjukzK5a/kYf28WpaTn4u3zcvkfvX09GVTnuFfEYzBNujvr9+S5SafvL0Wj+uiWBSrsov/I6axmMXiLhYf40zE2TTOZnF2F2fNn2n0DpcvBxhQEAAA"
|
||||
|
@ -56,9 +58,29 @@ func ru() uint32 {
|
|||
return binary.LittleEndian.Uint32(b)
|
||||
}
|
||||
|
||||
// PrintBanner prints our entropic banner
|
||||
func PrintBanner() {
|
||||
// printBanner prints our entropic banner
|
||||
func printBanner() {
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
defer time.Sleep(5 * time.Millisecond)
|
||||
println("\n" + process(hellpot) + "\n\n")
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
|
||||
func bannerFail(errs ...error) {
|
||||
println("failed printing banner, consider using --nocolor")
|
||||
for _, err := range errs {
|
||||
if err != nil {
|
||||
println(err.Error())
|
||||
}
|
||||
}
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Banner prints out our banner (using spooky magic)
|
||||
func Banner() {
|
||||
//goland:noinspection GoBoolExpressions
|
||||
if runtime.GOOS == "windows" || config.NoColor {
|
||||
println(config.Title + " " + config.Version)
|
||||
return
|
||||
}
|
||||
printBanner()
|
||||
}
|
|
@ -0,0 +1 @@
|
|||
package extra
|
|
@ -6,7 +6,7 @@ import (
|
|||
|
||||
"github.com/valyala/fasthttp"
|
||||
|
||||
"github.com/yunginnanet/HellPot/config"
|
||||
"github.com/yunginnanet/HellPot/internal/config"
|
||||
)
|
||||
|
||||
func robotsTXT(ctx *fasthttp.RequestCtx) {
|
|
@ -12,8 +12,8 @@ import (
|
|||
"github.com/rs/zerolog"
|
||||
"github.com/valyala/fasthttp"
|
||||
|
||||
"github.com/yunginnanet/HellPot/config"
|
||||
"github.com/yunginnanet/HellPot/heffalump"
|
||||
"github.com/yunginnanet/HellPot/internal/config"
|
||||
)
|
||||
|
||||
var log *zerolog.Logger
|
|
@ -1,5 +1,4 @@
|
|||
//go:build linux || darwin || freebsd
|
||||
// +build linux darwin freebsd
|
||||
|
||||
package http
|
||||
|
||||
|
@ -11,7 +10,7 @@ import (
|
|||
"github.com/fasthttp/router"
|
||||
"github.com/valyala/fasthttp"
|
||||
|
||||
"github.com/yunginnanet/HellPot/config"
|
||||
"github.com/yunginnanet/HellPot/internal/config"
|
||||
)
|
||||
|
||||
func listenOnUnixSocket(addr string, r *router.Router) error {
|
|
@ -0,0 +1,158 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
var ErrLimitReached = errors.New("limit reached")
|
||||
|
||||
type Speedometer struct {
|
||||
cap int64
|
||||
speedLimit SpeedLimit
|
||||
birth *time.Time
|
||||
duration time.Duration
|
||||
internal atomics
|
||||
w io.Writer
|
||||
BufioWriter *bufio.Writer
|
||||
}
|
||||
|
||||
type atomics struct {
|
||||
closed *atomic.Bool
|
||||
count *int64
|
||||
start *sync.Once
|
||||
stop *sync.Once
|
||||
}
|
||||
|
||||
func NewSpeedometer(w io.Writer) *Speedometer {
|
||||
z := int64(0)
|
||||
speedo := &Speedometer{
|
||||
w: w,
|
||||
cap: -1,
|
||||
internal: atomics{
|
||||
count: &z,
|
||||
closed: new(atomic.Bool),
|
||||
stop: new(sync.Once),
|
||||
start: new(sync.Once),
|
||||
},
|
||||
}
|
||||
speedo.internal.closed.Store(false)
|
||||
speedo.BufioWriter = bufio.NewWriter(speedo)
|
||||
return speedo
|
||||
}
|
||||
|
||||
type SpeedLimit struct {
|
||||
Bytes int
|
||||
Per time.Duration
|
||||
CheckEveryBytes int
|
||||
Delay time.Duration
|
||||
}
|
||||
|
||||
func NewLimitedSpeedometer(w io.Writer, speedLimit SpeedLimit) *Speedometer {
|
||||
z := int64(0)
|
||||
speedo := &Speedometer{
|
||||
w: w,
|
||||
cap: -1,
|
||||
speedLimit: speedLimit,
|
||||
internal: atomics{
|
||||
count: &z,
|
||||
closed: new(atomic.Bool),
|
||||
stop: new(sync.Once),
|
||||
start: new(sync.Once),
|
||||
},
|
||||
}
|
||||
speedo.internal.closed.Store(false)
|
||||
speedo.BufioWriter = bufio.NewWriterSize(speedo, speedLimit.CheckEveryBytes)
|
||||
return speedo
|
||||
}
|
||||
|
||||
func NewCappedSpeedometer(w io.Writer, cap int64) *Speedometer {
|
||||
z := int64(0)
|
||||
speedo := &Speedometer{
|
||||
w: w,
|
||||
cap: cap,
|
||||
internal: atomics{
|
||||
count: &z,
|
||||
closed: new(atomic.Bool),
|
||||
stop: new(sync.Once),
|
||||
start: new(sync.Once),
|
||||
},
|
||||
}
|
||||
speedo.internal.closed.Store(false)
|
||||
speedo.BufioWriter = bufio.NewWriterSize(speedo, int(cap)/10)
|
||||
return speedo
|
||||
}
|
||||
|
||||
func (s *Speedometer) increment(inc int64) (int, error) {
|
||||
if s.internal.closed.Load() {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
var err error
|
||||
if s.cap > 0 && s.Total()+inc > s.cap {
|
||||
_ = s.Close()
|
||||
err = ErrLimitReached
|
||||
inc = s.cap - s.Total()
|
||||
}
|
||||
atomic.AddInt64(s.internal.count, inc)
|
||||
return int(inc), err
|
||||
}
|
||||
|
||||
func (s *Speedometer) Running() bool {
|
||||
return !s.internal.closed.Load()
|
||||
}
|
||||
|
||||
func (s *Speedometer) Total() int64 {
|
||||
return atomic.LoadInt64(s.internal.count)
|
||||
}
|
||||
|
||||
func (s *Speedometer) Reset() {
|
||||
s.internal.count = new(int64)
|
||||
s.internal.closed = new(atomic.Bool)
|
||||
s.internal.start = new(sync.Once)
|
||||
s.internal.stop = new(sync.Once)
|
||||
s.BufioWriter = bufio.NewWriter(s)
|
||||
s.internal.closed.Store(false)
|
||||
}
|
||||
|
||||
func (s *Speedometer) Close() error {
|
||||
s.internal.stop.Do(func() {
|
||||
s.internal.closed.Store(true)
|
||||
stopped := time.Now()
|
||||
s.duration = stopped.Sub(*s.birth)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Speedometer) Rate() float64 {
|
||||
if s.internal.closed.Load() {
|
||||
return float64(s.Total()) / s.duration.Seconds()
|
||||
}
|
||||
return float64(s.Total()) / time.Since(*s.birth).Seconds()
|
||||
}
|
||||
|
||||
func (s *Speedometer) Write(p []byte) (n int, err error) {
|
||||
s.internal.start.Do(func() {
|
||||
tn := time.Now()
|
||||
s.birth = &tn
|
||||
})
|
||||
accepted, err := s.increment(int64(len(p)))
|
||||
if err != nil {
|
||||
wn, innerErr := s.w.Write(p[:accepted])
|
||||
if innerErr != nil {
|
||||
err = innerErr
|
||||
}
|
||||
if wn < accepted {
|
||||
err = io.ErrShortWrite
|
||||
}
|
||||
return wn, err
|
||||
}
|
||||
return s.w.Write(p)
|
||||
}
|
||||
|
||||
/*func BenchmarkHeffalump_WriteHell(b *testing.B) {
|
||||
|
||||
}*/
|
|
@ -0,0 +1,101 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func writeStuff(target *bufio.Writer, count int) error {
|
||||
write := func() error {
|
||||
_, err := target.WriteString("a")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if count < 0 {
|
||||
for {
|
||||
if err := write(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
for i := 0; i < count; i++ {
|
||||
if err := write(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func Test_Speedometer(t *testing.T) {
|
||||
type results struct {
|
||||
total int64
|
||||
written int
|
||||
rate float64
|
||||
err error
|
||||
}
|
||||
|
||||
isIt := func(want, have results) {
|
||||
if have.total != want.total {
|
||||
t.Errorf("total: want %d, have %d", want.total, have.total)
|
||||
}
|
||||
if have.written != want.written {
|
||||
t.Errorf("written: want %d, have %d", want.written, have.written)
|
||||
}
|
||||
if have.rate != want.rate {
|
||||
t.Errorf("rate: want %f, have %f", want.rate, have.rate)
|
||||
}
|
||||
if have.err != want.err {
|
||||
t.Errorf("wantErr: want %v, have %v", want.err, have.err)
|
||||
}
|
||||
}
|
||||
|
||||
sp := NewSpeedometer(io.Discard)
|
||||
errChan := make(chan error)
|
||||
go func() {
|
||||
errChan <- writeStuff(sp.BufioWriter, -1)
|
||||
}()
|
||||
time.Sleep(1 * time.Second)
|
||||
_ = sp.Close()
|
||||
err := <-errChan
|
||||
cnt, err := sp.Write([]byte("a"))
|
||||
isIt(results{err: io.ErrClosedPipe, written: 0}, results{err: err, written: cnt})
|
||||
sp.Reset()
|
||||
cnt, err = sp.Write([]byte("a"))
|
||||
isIt(results{err: nil, written: 1, total: 1}, results{err: err, written: cnt, total: sp.Total()})
|
||||
cnt, err = sp.Write([]byte("aa"))
|
||||
isIt(results{err: nil, written: 2, total: 3}, results{err: err, written: cnt, total: sp.Total()})
|
||||
cnt, err = sp.Write([]byte("a"))
|
||||
isIt(results{err: nil, written: 1, total: 4}, results{err: err, written: cnt, total: sp.Total()})
|
||||
cnt, err = sp.Write([]byte("a"))
|
||||
isIt(results{err: nil, written: 1, total: 5}, results{err: err, written: cnt, total: sp.Total()})
|
||||
go func() {
|
||||
errChan <- writeStuff(sp.BufioWriter, -1)
|
||||
}()
|
||||
var count = 0
|
||||
for sp.Running() {
|
||||
select {
|
||||
case err = <-errChan:
|
||||
if !errors.Is(err, io.ErrClosedPipe) {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
} else {
|
||||
if count < 5 {
|
||||
t.Errorf("too few iterations: %d", count)
|
||||
}
|
||||
t.Logf("final rate: %v per second", sp.Rate())
|
||||
}
|
||||
default:
|
||||
if count > 5 {
|
||||
_ = sp.Close()
|
||||
}
|
||||
t.Logf("rate: %v per second", sp.Rate())
|
||||
time.Sleep(1 * time.Second)
|
||||
count++
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue