Feat: MustWriteString, MustPut

This commit is contained in:
kayos@tcp.direct 2022-10-02 22:06:42 -07:00
parent fb2b04efbc
commit f8d9462a66
Signed by: kayos
GPG Key ID: 4B841471B4BEE979
3 changed files with 223 additions and 100 deletions

5
pool/errors.go Normal file

@ -0,0 +1,5 @@
package pool
import "errors"
var ErrBufferReturned = errors.New("buffer already returned")

@ -1,16 +1,22 @@
package pool package pool
import ( import (
"errors"
"strings" "strings"
"sync" "sync"
) )
var ErrBufferReturned = errors.New("buffer already returned")
type String struct { type String struct {
*strings.Builder *strings.Builder
*sync.Once o *sync.Once
}
// NewStringFactory creates a new strings.Builder pool.
func NewStringFactory() StringFactory {
return StringFactory{
pool: &sync.Pool{
New: func() any { return new(strings.Builder) },
},
}
} }
func (s String) String() string { func (s String) String() string {
@ -35,6 +41,17 @@ func (s String) WriteString(str string) (int, error) {
return s.Builder.WriteString(str) return s.Builder.WriteString(str)
} }
// MustWstr means Must Write String, like WriteString but will panic on error.
func (s String) MustWstr(str string) {
if s.Builder == nil {
panic(ErrBufferReturned)
}
if str == "" {
panic("nil string")
}
_, _ = s.Builder.WriteString(str)
}
func (s String) Len() int { func (s String) Len() int {
if s.Builder == nil { if s.Builder == nil {
return 0 return 0
@ -82,19 +99,10 @@ type StringFactory struct {
pool *sync.Pool pool *sync.Pool
} }
// NewStringFactory creates a new strings.Builder pool.
func NewStringFactory() StringFactory {
return StringFactory{
pool: &sync.Pool{
New: func() any { return new(strings.Builder) },
},
}
}
// Put returns a strings.Builder back into to the pool after resetting it. // Put returns a strings.Builder back into to the pool after resetting it.
func (sf StringFactory) Put(buf *String) error { func (sf StringFactory) Put(buf *String) error {
var err = ErrBufferReturned var err = ErrBufferReturned
buf.Do(func() { buf.o.Do(func() {
_ = buf.Reset() _ = buf.Reset()
sf.pool.Put(buf.Builder) sf.pool.Put(buf.Builder)
buf.Builder = nil buf.Builder = nil
@ -103,6 +111,19 @@ func (sf StringFactory) Put(buf *String) error {
return err return err
} }
func (sf StringFactory) MustPut(buf *String) {
var err = ErrBufferReturned
buf.o.Do(func() {
_ = buf.Reset()
sf.pool.Put(buf.Builder)
buf.Builder = nil
err = nil
})
if err != nil {
panic(err)
}
}
// Get returns a strings.Builder from the pool. // Get returns a strings.Builder from the pool.
func (sf StringFactory) Get() *String { func (sf StringFactory) Get() *String {
return &String{ return &String{

@ -2,94 +2,191 @@ package pool
import ( import (
"testing" "testing"
"time"
) )
func assertPanic(t *testing.T, f func()) {
defer func() {
if r := recover(); r == nil {
t.Errorf("The code did not panic")
}
}()
f()
}
func TestStringFactoryPanic(t *testing.T) {
s := NewStringFactory()
t.Run("StringsMustWrite", func(t *testing.T) {
buf := s.Get()
buf.MustWstr("hello world")
if buf.Len() == 0 {
t.Fatalf("The buffer is empty after we wrote to it")
}
if buf.String() != "hello world" {
t.Fatalf("The buffer has the wrong content")
}
})
t.Run("StringsMustWritePanic", func(t *testing.T) {
var badString *string = nil
buf := s.Get()
assertPanic(t, func() {
buf.MustWstr(*badString)
})
assertPanic(t, func() {
buf.MustWstr("")
})
if err := s.Put(buf); err != nil {
t.Fatalf("The buffer was not returned: %v", err)
}
})
t.Run("StringsPanic", func(t *testing.T) {
buf := s.Get()
err := s.Put(buf)
if err != nil {
t.Fatalf("The buffer was not returned: %v", err)
}
assertPanic(t, func() {
s.MustPut(buf)
})
assertPanic(t, func() {
buf.MustWstr("hello")
})
})
}
func TestStringFactory(t *testing.T) { func TestStringFactory(t *testing.T) {
s := NewStringFactory() s := NewStringFactory()
buf := s.Get() t.Run("StringPoolHelloWorld", func(t *testing.T) {
if _, err := buf.WriteString("hello"); err != nil { t.Parallel()
t.Fatal(err) buf := s.Get()
} if _, err := buf.WriteString("hello"); err != nil {
if buf.String() != "hello" { t.Fatal(err)
t.Fatal("unexpected string") }
} if buf.String() != "hello" {
if err := buf.WriteByte(' '); err != nil { t.Fatal("unexpected string")
t.Fatal(err) }
} if err := buf.WriteByte(' '); err != nil {
if buf.String() != "hello " { t.Fatal(err)
t.Fatalf("unexpected string: %s", buf.String()) }
} if buf.String() != "hello " {
if _, err := buf.WriteRune('w'); err != nil { t.Fatalf("unexpected string: %s", buf.String())
t.Fatal(err) }
} if _, err := buf.WriteRune('w'); err != nil {
if buf.String() != "hello w" { t.Fatal(err)
t.Fatalf("unexpected string: %s", buf.String()) }
} if buf.String() != "hello w" {
if _, err := buf.Write([]byte("orld")); err != nil { t.Fatalf("unexpected string: %s", buf.String())
t.Fatal(err) }
} if _, err := buf.Write([]byte("orld")); err != nil {
if err := buf.Grow(1); err != nil { t.Fatal(err)
t.Fatal(err) }
} if err := buf.Grow(1); err != nil {
if buf.Cap() == 0 { t.Fatal(err)
t.Fatal("expected capacity, got 0") }
} if buf.Cap() == 0 {
if err := buf.Reset(); err != nil { t.Fatal("expected capacity, got 0")
t.Fatal(err) }
} if err := buf.Reset(); err != nil {
if buf.String() != "" { t.Fatal(err)
t.Fatalf("unexpected string: %s", buf.String()) }
} if buf.String() != "" {
if err := s.Put(buf); err != nil { t.Fatalf("unexpected string: %s", buf.String())
t.Fatal(err) }
} if err := s.Put(buf); err != nil {
if err := s.Put(buf); err == nil { t.Fatal(err)
t.Fatal("expected error") }
} if err := s.Put(buf); err == nil {
if s.Get().Len() > 0 { t.Fatal("expected error")
t.Fatalf("StringFactory.Put() did not reset the buffer") }
} })
if err := s.Put(buf); err == nil { t.Run("StringPoolCheckGetLength", func(t *testing.T) {
t.Fatalf("StringFactory.Put() should have returned an error after already returning the buffer") t.Parallel()
} buf := s.Get()
if err := buf.Grow(10); err == nil { if buf.Len() > 0 {
t.Fatalf("StringFactory.Grow() should not work after returning the buffer") t.Fatalf("StringFactory.Put() did not reset the buffer")
} }
if buf.Cap() != 0 { if err := s.Put(buf); err != nil {
t.Fatalf("StringFactory.Cap() should return 0 after returning the buffer") t.Fatal(err)
} }
got := s.Get() if err := s.Put(buf); err == nil {
if got.String() != "" { t.Fatalf("StringFactory.Put() should have returned an error after already returning the buffer")
t.Fatalf("should have gotten a clean buffer") }
} })
if err := s.Put(got); err != nil { t.Run("StringPoolGrowBuffer", func(t *testing.T) {
t.Fatalf("unexpected error: %v", err) t.Parallel()
} buf := s.Get()
if _, err := got.WriteString("a"); err == nil { if err := buf.Grow(1); err != nil {
t.Fatalf("should not be able to write to a returned buffer") t.Fatal(err)
} }
if _, err := got.WriteRune('a'); err == nil { if buf.Cap() != 1 {
t.Fatalf("should not be able to write to a returned buffer") t.Fatalf("expected capacity of 1, got %d", buf.Cap())
} }
if err := got.WriteByte('a'); err == nil { if err := s.Put(buf); err != nil {
t.Fatalf("should not be able to write to a returned buffer") t.Fatal(err)
} }
if _, err := got.Write([]byte("a")); err == nil { if err := buf.Grow(10); err == nil {
t.Fatalf("should not be able to write to a returned buffer") t.Fatalf("StringFactory.Grow() should not work after returning the buffer")
} }
if err := got.Reset(); err == nil { if buf.Cap() != 0 {
t.Fatalf("should not be able to reset a returned buffer") t.Fatalf("StringFactory.Cap() should return 0 after returning the buffer")
} }
if str := got.String(); str != "" { })
t.Fatalf("should not be able to get string from a returned buffer") t.Run("StringPoolCleanBuffer", func(t *testing.T) {
} t.Parallel()
if got.Len() != 0 { time.Sleep(100 * time.Millisecond)
t.Fatalf("should not be able to write to a returned buffer") got := s.Get()
} if got.String() != "" {
if got = s.Get(); got.Len() > 0 { t.Fatalf("should have gotten a clean buffer")
t.Fatalf("StringFactory.Put() did not reset the buffer") }
} if err := s.Put(got); err != nil {
if got.Cap() != 0 { t.Fatalf("unexpected error: %v", err)
t.Fatalf("StringFactory.Put() did not reset the buffer") }
} })
t.Run("StringPoolWriteStringToReturnedBuffer", func(t *testing.T) {
t.Parallel()
got := s.Get()
s.MustPut(got)
if _, err := got.WriteString("a"); err == nil {
t.Fatalf("should not be able to write to a returned buffer")
}
})
t.Run("StringPoolWriteRuneToReturnedBuffer", func(t *testing.T) {
t.Parallel()
got := s.Get()
s.MustPut(got)
if _, err := got.WriteRune('a'); err == nil {
t.Fatalf("should not be able to write to a returned buffer")
}
})
t.Run("StringPoolWriteByteToReturnedBuffer", func(t *testing.T) {
t.Parallel()
got := s.Get()
s.MustPut(got)
if err := got.WriteByte('a'); err == nil {
t.Fatalf("should not be able to write to a returned buffer")
}
})
t.Run("StringPoolWriteToReturnedBuffer", func(t *testing.T) {
t.Parallel()
got := s.Get()
s.MustPut(got)
if _, err := got.Write([]byte("a")); err == nil {
t.Fatalf("should not be able to write to a returned buffer")
}
})
t.Run("StringPoolResetReturnedBuffer", func(t *testing.T) {
t.Parallel()
got := s.Get()
s.MustPut(got)
if err := got.Reset(); err == nil {
t.Fatalf("should not be able to reset a returned buffer")
}
if str := got.String(); str != "" {
t.Fatalf("should not be able to get string from a returned buffer")
}
if got.Len() != 0 {
t.Fatalf("should not be able to write to a returned buffer")
}
})
} }