From f8d9462a6683b3456bd300cfdf0efc915895d33a Mon Sep 17 00:00:00 2001 From: "kayos@tcp.direct" Date: Sun, 2 Oct 2022 22:06:42 -0700 Subject: [PATCH] Feat: MustWriteString, MustPut --- pool/errors.go | 5 + pool/strings.go | 49 +++++--- pool/strings_test.go | 269 +++++++++++++++++++++++++++++-------------- 3 files changed, 223 insertions(+), 100 deletions(-) create mode 100644 pool/errors.go diff --git a/pool/errors.go b/pool/errors.go new file mode 100644 index 0000000..eefb157 --- /dev/null +++ b/pool/errors.go @@ -0,0 +1,5 @@ +package pool + +import "errors" + +var ErrBufferReturned = errors.New("buffer already returned") diff --git a/pool/strings.go b/pool/strings.go index 866c7e3..4d51390 100644 --- a/pool/strings.go +++ b/pool/strings.go @@ -1,16 +1,22 @@ package pool import ( - "errors" "strings" "sync" ) -var ErrBufferReturned = errors.New("buffer already returned") - type String struct { *strings.Builder - *sync.Once + o *sync.Once +} + +// NewStringFactory creates a new strings.Builder pool. +func NewStringFactory() StringFactory { + return StringFactory{ + pool: &sync.Pool{ + New: func() any { return new(strings.Builder) }, + }, + } } func (s String) String() string { @@ -35,6 +41,17 @@ func (s String) WriteString(str string) (int, error) { return s.Builder.WriteString(str) } +// MustWstr means Must Write String, like WriteString but will panic on error. +func (s String) MustWstr(str string) { + if s.Builder == nil { + panic(ErrBufferReturned) + } + if str == "" { + panic("nil string") + } + _, _ = s.Builder.WriteString(str) +} + func (s String) Len() int { if s.Builder == nil { return 0 @@ -82,19 +99,10 @@ type StringFactory struct { pool *sync.Pool } -// NewStringFactory creates a new strings.Builder pool. -func NewStringFactory() StringFactory { - return StringFactory{ - pool: &sync.Pool{ - New: func() any { return new(strings.Builder) }, - }, - } -} - // Put returns a strings.Builder back into to the pool after resetting it. func (sf StringFactory) Put(buf *String) error { var err = ErrBufferReturned - buf.Do(func() { + buf.o.Do(func() { _ = buf.Reset() sf.pool.Put(buf.Builder) buf.Builder = nil @@ -103,6 +111,19 @@ func (sf StringFactory) Put(buf *String) error { return err } +func (sf StringFactory) MustPut(buf *String) { + var err = ErrBufferReturned + buf.o.Do(func() { + _ = buf.Reset() + sf.pool.Put(buf.Builder) + buf.Builder = nil + err = nil + }) + if err != nil { + panic(err) + } +} + // Get returns a strings.Builder from the pool. func (sf StringFactory) Get() *String { return &String{ diff --git a/pool/strings_test.go b/pool/strings_test.go index c70c138..5225be2 100644 --- a/pool/strings_test.go +++ b/pool/strings_test.go @@ -2,94 +2,191 @@ package pool import ( "testing" + "time" ) +func assertPanic(t *testing.T, f func()) { + defer func() { + if r := recover(); r == nil { + t.Errorf("The code did not panic") + } + }() + f() +} + +func TestStringFactoryPanic(t *testing.T) { + s := NewStringFactory() + t.Run("StringsMustWrite", func(t *testing.T) { + buf := s.Get() + buf.MustWstr("hello world") + if buf.Len() == 0 { + t.Fatalf("The buffer is empty after we wrote to it") + } + if buf.String() != "hello world" { + t.Fatalf("The buffer has the wrong content") + } + }) + t.Run("StringsMustWritePanic", func(t *testing.T) { + var badString *string = nil + buf := s.Get() + assertPanic(t, func() { + buf.MustWstr(*badString) + }) + assertPanic(t, func() { + buf.MustWstr("") + }) + if err := s.Put(buf); err != nil { + t.Fatalf("The buffer was not returned: %v", err) + } + }) + t.Run("StringsPanic", func(t *testing.T) { + buf := s.Get() + err := s.Put(buf) + if err != nil { + t.Fatalf("The buffer was not returned: %v", err) + } + assertPanic(t, func() { + s.MustPut(buf) + }) + assertPanic(t, func() { + buf.MustWstr("hello") + }) + }) +} + func TestStringFactory(t *testing.T) { s := NewStringFactory() - buf := s.Get() - if _, err := buf.WriteString("hello"); err != nil { - t.Fatal(err) - } - if buf.String() != "hello" { - t.Fatal("unexpected string") - } - if err := buf.WriteByte(' '); err != nil { - t.Fatal(err) - } - if buf.String() != "hello " { - t.Fatalf("unexpected string: %s", buf.String()) - } - if _, err := buf.WriteRune('w'); err != nil { - t.Fatal(err) - } - if buf.String() != "hello w" { - t.Fatalf("unexpected string: %s", buf.String()) - } - if _, err := buf.Write([]byte("orld")); err != nil { - t.Fatal(err) - } - if err := buf.Grow(1); err != nil { - t.Fatal(err) - } - if buf.Cap() == 0 { - t.Fatal("expected capacity, got 0") - } - if err := buf.Reset(); err != nil { - t.Fatal(err) - } - if buf.String() != "" { - t.Fatalf("unexpected string: %s", buf.String()) - } - if err := s.Put(buf); err != nil { - t.Fatal(err) - } - if err := s.Put(buf); err == nil { - t.Fatal("expected error") - } - if s.Get().Len() > 0 { - t.Fatalf("StringFactory.Put() did not reset the buffer") - } - if err := s.Put(buf); err == nil { - t.Fatalf("StringFactory.Put() should have returned an error after already returning the buffer") - } - if err := buf.Grow(10); err == nil { - t.Fatalf("StringFactory.Grow() should not work after returning the buffer") - } - if buf.Cap() != 0 { - t.Fatalf("StringFactory.Cap() should return 0 after returning the buffer") - } - got := s.Get() - if got.String() != "" { - t.Fatalf("should have gotten a clean buffer") - } - if err := s.Put(got); err != nil { - t.Fatalf("unexpected error: %v", err) - } - if _, err := got.WriteString("a"); err == nil { - t.Fatalf("should not be able to write to a returned buffer") - } - if _, err := got.WriteRune('a'); err == nil { - t.Fatalf("should not be able to write to a returned buffer") - } - if err := got.WriteByte('a'); err == nil { - t.Fatalf("should not be able to write to a returned buffer") - } - if _, err := got.Write([]byte("a")); err == nil { - t.Fatalf("should not be able to write to a returned buffer") - } - if err := got.Reset(); err == nil { - t.Fatalf("should not be able to reset a returned buffer") - } - if str := got.String(); str != "" { - t.Fatalf("should not be able to get string from a returned buffer") - } - if got.Len() != 0 { - t.Fatalf("should not be able to write to a returned buffer") - } - if got = s.Get(); got.Len() > 0 { - t.Fatalf("StringFactory.Put() did not reset the buffer") - } - if got.Cap() != 0 { - t.Fatalf("StringFactory.Put() did not reset the buffer") - } + t.Run("StringPoolHelloWorld", func(t *testing.T) { + t.Parallel() + buf := s.Get() + if _, err := buf.WriteString("hello"); err != nil { + t.Fatal(err) + } + if buf.String() != "hello" { + t.Fatal("unexpected string") + } + if err := buf.WriteByte(' '); err != nil { + t.Fatal(err) + } + if buf.String() != "hello " { + t.Fatalf("unexpected string: %s", buf.String()) + } + if _, err := buf.WriteRune('w'); err != nil { + t.Fatal(err) + } + if buf.String() != "hello w" { + t.Fatalf("unexpected string: %s", buf.String()) + } + if _, err := buf.Write([]byte("orld")); err != nil { + t.Fatal(err) + } + if err := buf.Grow(1); err != nil { + t.Fatal(err) + } + if buf.Cap() == 0 { + t.Fatal("expected capacity, got 0") + } + if err := buf.Reset(); err != nil { + t.Fatal(err) + } + if buf.String() != "" { + t.Fatalf("unexpected string: %s", buf.String()) + } + if err := s.Put(buf); err != nil { + t.Fatal(err) + } + if err := s.Put(buf); err == nil { + t.Fatal("expected error") + } + }) + t.Run("StringPoolCheckGetLength", func(t *testing.T) { + t.Parallel() + buf := s.Get() + if buf.Len() > 0 { + t.Fatalf("StringFactory.Put() did not reset the buffer") + } + if err := s.Put(buf); err != nil { + t.Fatal(err) + } + if err := s.Put(buf); err == nil { + t.Fatalf("StringFactory.Put() should have returned an error after already returning the buffer") + } + }) + t.Run("StringPoolGrowBuffer", func(t *testing.T) { + t.Parallel() + buf := s.Get() + if err := buf.Grow(1); err != nil { + t.Fatal(err) + } + if buf.Cap() != 1 { + t.Fatalf("expected capacity of 1, got %d", buf.Cap()) + } + if err := s.Put(buf); err != nil { + t.Fatal(err) + } + if err := buf.Grow(10); err == nil { + t.Fatalf("StringFactory.Grow() should not work after returning the buffer") + } + if buf.Cap() != 0 { + t.Fatalf("StringFactory.Cap() should return 0 after returning the buffer") + } + }) + t.Run("StringPoolCleanBuffer", func(t *testing.T) { + t.Parallel() + time.Sleep(100 * time.Millisecond) + got := s.Get() + if got.String() != "" { + t.Fatalf("should have gotten a clean buffer") + } + if err := s.Put(got); err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + + t.Run("StringPoolWriteStringToReturnedBuffer", func(t *testing.T) { + t.Parallel() + got := s.Get() + s.MustPut(got) + if _, err := got.WriteString("a"); err == nil { + t.Fatalf("should not be able to write to a returned buffer") + } + }) + t.Run("StringPoolWriteRuneToReturnedBuffer", func(t *testing.T) { + t.Parallel() + got := s.Get() + s.MustPut(got) + if _, err := got.WriteRune('a'); err == nil { + t.Fatalf("should not be able to write to a returned buffer") + } + }) + t.Run("StringPoolWriteByteToReturnedBuffer", func(t *testing.T) { + t.Parallel() + got := s.Get() + s.MustPut(got) + if err := got.WriteByte('a'); err == nil { + t.Fatalf("should not be able to write to a returned buffer") + } + }) + t.Run("StringPoolWriteToReturnedBuffer", func(t *testing.T) { + t.Parallel() + got := s.Get() + s.MustPut(got) + if _, err := got.Write([]byte("a")); err == nil { + t.Fatalf("should not be able to write to a returned buffer") + } + }) + t.Run("StringPoolResetReturnedBuffer", func(t *testing.T) { + t.Parallel() + got := s.Get() + s.MustPut(got) + if err := got.Reset(); err == nil { + t.Fatalf("should not be able to reset a returned buffer") + } + if str := got.String(); str != "" { + t.Fatalf("should not be able to get string from a returned buffer") + } + if got.Len() != 0 { + t.Fatalf("should not be able to write to a returned buffer") + } + }) }