From a9ce9387acb6827cbb8028806ece2e266567521c Mon Sep 17 00:00:00 2001 From: "kayos@tcp.direct" Date: Sun, 2 Oct 2022 23:46:32 -0700 Subject: [PATCH] Feat[pool]: safer bytes.Buffer+sync.Pool --- pool/bytes.go | 255 +++++++++++++++++++++++ pool/bytes_test.go | 499 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 754 insertions(+) create mode 100644 pool/bytes.go create mode 100644 pool/bytes_test.go diff --git a/pool/bytes.go b/pool/bytes.go new file mode 100644 index 0000000..0f04ce6 --- /dev/null +++ b/pool/bytes.go @@ -0,0 +1,255 @@ +package pool + +import ( + "bytes" + "io" + "sync" +) + +type BufferFactory struct { + pool *sync.Pool +} + +func NewBufferFactory() BufferFactory { + return BufferFactory{ + pool: &sync.Pool{ + New: func() any { return new(bytes.Buffer) }, + }, + } +} + +func (cf BufferFactory) Put(buf *Buffer) error { + var err = ErrBufferReturned + buf.o.Do(func() { + _ = buf.Reset() + cf.pool.Put(buf.Buffer) + buf.Buffer = nil + err = nil + }) + return err +} + +func (cf BufferFactory) MustPut(buf *Buffer) { + if err := cf.Put(buf); err != nil { + panic(err) + } +} + +func (cf BufferFactory) Get() *Buffer { + return &Buffer{ + cf.pool.Get().(*bytes.Buffer), + &sync.Once{}, + } +} + +type Buffer struct { + *bytes.Buffer + o *sync.Once +} + +func (c Buffer) Bytes() []byte { + if c.Buffer == nil { + return nil + } + return c.Buffer.Bytes() +} + +func (c Buffer) MustBytes() []byte { + if c.Buffer == nil { + panic(ErrBufferReturned) + } + return c.Buffer.Bytes() +} + +func (c Buffer) String() string { + if c.Buffer == nil { + return "" + } + return c.Buffer.String() +} + +func (c Buffer) MustString() string { + if c.Buffer == nil { + panic(ErrBufferReturned) + } + return c.Buffer.String() +} + +func (c Buffer) Reset() error { + if c.Buffer == nil { + return ErrBufferReturned + } + c.Buffer.Reset() + return nil +} + +func (c Buffer) MustReset() { + if err := c.Reset(); err != nil { + panic(err) + } + c.Buffer.Reset() +} + +func (c Buffer) Len() int { + if c.Buffer == nil { + return 0 + } + return c.Buffer.Len() +} + +func (c Buffer) MustLen() int { + if c.Buffer == nil { + panic(ErrBufferReturned) + } + return c.Buffer.Len() +} + +func (c Buffer) Write(p []byte) (int, error) { + if c.Buffer == nil { + return 0, ErrBufferReturned + } + return c.Buffer.Write(p) +} + +func (c Buffer) MustWrite(p []byte) { + if _, err := c.Write(p); err != nil { + panic(err) + } +} + +func (c Buffer) WriteRune(r rune) (int, error) { + if c.Buffer == nil { + return 0, ErrBufferReturned + } + return c.Buffer.WriteRune(r) +} + +func (c Buffer) MustWriteRune(r rune) { + if _, err := c.WriteRune(r); err != nil { + panic(err) + } +} + +func (c Buffer) WriteByte(cyte byte) error { + if c.Buffer == nil { + return ErrBufferReturned + } + return c.Buffer.WriteByte(cyte) +} + +func (c Buffer) MustWriteByte(cyte byte) { + if err := c.WriteByte(cyte); err != nil { + panic(err) + } +} + +func (c Buffer) WriteString(str string) (int, error) { + if c.Buffer == nil { + return 0, ErrBufferReturned + } + return c.Buffer.WriteString(str) +} + +func (c Buffer) Grow(n int) error { + if c.Buffer == nil { + return ErrBufferReturned + } + c.Buffer.Grow(n) + return nil +} + +func (c Buffer) Cap() int { + if c.Buffer == nil { + return 0 + } + return c.Buffer.Cap() +} + +func (c Buffer) Truncate(n int) error { + if c.Buffer == nil { + return ErrBufferReturned + } + c.Buffer.Truncate(n) + return nil +} + +func (c Buffer) MustTruncate(n int) { + if err := c.Truncate(n); err != nil { + panic(err) + } +} + +func (c Buffer) ReadFrom(r io.Reader) (int64, error) { + if c.Buffer == nil { + return 0, ErrBufferReturned + } + return c.Buffer.ReadFrom(r) +} + +func (c Buffer) MustReadFrom(r io.Reader) { + if _, err := c.ReadFrom(r); err != nil { + panic(err) + } +} + +func (c Buffer) WriteTo(w io.Writer) (int64, error) { + if c.Buffer == nil { + return 0, ErrBufferReturned + } + return c.Buffer.WriteTo(w) +} + +func (c Buffer) MustWriteTo(w io.Writer) { + if _, err := c.WriteTo(w); err != nil { + panic(err) + } +} + +func (c Buffer) Read(p []byte) (int, error) { + if c.Buffer == nil { + return 0, ErrBufferReturned + } + return c.Buffer.Read(p) +} + +func (c Buffer) ReadByte() (byte, error) { + if c.Buffer == nil { + return 0, ErrBufferReturned + } + return c.Buffer.ReadByte() +} + +func (c Buffer) ReadRune() (rune, int, error) { + if c.Buffer == nil { + return 0, 0, ErrBufferReturned + } + return c.Buffer.ReadRune() +} + +func (c Buffer) UnreadByte() error { + if c.Buffer == nil { + return ErrBufferReturned + } + return c.Buffer.UnreadByte() +} + +func (c Buffer) UnreadRune() error { + if c.Buffer == nil { + return ErrBufferReturned + } + return c.Buffer.UnreadRune() +} + +func (c Buffer) ReadBytes(delim byte) ([]byte, error) { + if c.Buffer == nil { + return nil, ErrBufferReturned + } + return c.Buffer.ReadBytes(delim) +} + +func (c Buffer) Next(n int) []byte { + if c.Buffer == nil { + return nil + } + return c.Buffer.Next(n) +} diff --git a/pool/bytes_test.go b/pool/bytes_test.go new file mode 100644 index 0000000..db4beb5 --- /dev/null +++ b/pool/bytes_test.go @@ -0,0 +1,499 @@ +package pool + +import ( + "bytes" + "io" + "strings" + "testing" +) + +func TestNewBufferFactory(t *testing.T) { + bf := NewBufferFactory() + if bf.pool == nil { + t.Fatalf("The pool is nil") + } +} + +func TestBufferFactory(t *testing.T) { + bf := NewBufferFactory() + t.Run("BufferPut", func(t *testing.T) { + t.Parallel() + buf := bf.Get() + if err := bf.Put(buf); err != nil { + t.Fatalf("The buffer was not returned: %v", err) + } + if err := bf.Put(buf); err == nil { + t.Fatalf("The buffer was returned twice") + } + }) + t.Run("BufferMustPut", func(t *testing.T) { + t.Parallel() + buf := bf.Get() + bf.MustPut(buf) + assertPanic(t, func() { + bf.MustPut(buf) + }) + }) + t.Run("BufferFactoryGet", func(t *testing.T) { + t.Parallel() + buf := bf.Get() + if buf.Buffer == nil { + t.Fatalf("The buffer is nil") + } + if buf.o == nil { + t.Fatalf("The once is nil") + } + }) + t.Run("BufferBytes", func(t *testing.T) { + t.Parallel() + buf := bf.Get() + if len(buf.Bytes()) != 0 { + t.Fatalf("The bytes are not nil: %v", buf.Bytes()) + } + buf.MustWrite([]byte("hello world")) + if !bytes.Equal(buf.MustBytes(), []byte("hello world")) { + t.Fatalf("The bytes are wrong") + } + bf.MustPut(buf) + if buf.Bytes() != nil { + t.Fatalf("The bytes are not nil") + } + }) + t.Run("BufferMustBytes", func(t *testing.T) { + t.Parallel() + buf := bf.Get() + _, err := buf.Write([]byte("hello")) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(buf.MustBytes(), []byte("hello")) { + t.Fatalf("The bytes are not equal") + } + bf.MustPut(buf) + assertPanic(t, func() { + buf.MustBytes() + }) + }) + t.Run("BufferString", func(t *testing.T) { + t.Parallel() + buf := bf.Get() + if buf.String() != "" { + t.Fatalf("The string is not empty") + } + bf.MustPut(buf) + if buf.String() != "" { + t.Fatalf("The string is not empty") + } + }) + t.Run("BufferMustString", func(t *testing.T) { + t.Parallel() + buf := bf.Get() + _ = buf.MustString() + bf.MustPut(buf) + assertPanic(t, func() { + buf.MustString() + }) + }) + t.Run("BufferLen", func(t *testing.T) { + t.Parallel() + buf := bf.Get() + if buf.Len() != 0 { + t.Fatalf("The length is not zero") + } + bf.MustPut(buf) + if buf.Len() != 0 { + t.Fatalf("The length is not zero") + } + }) + t.Run("BufferMustLen", func(t *testing.T) { + t.Parallel() + buf := bf.Get() + _ = buf.MustLen() + bf.MustPut(buf) + assertPanic(t, func() { + buf.MustLen() + }) + }) + t.Run("BufferCap", func(t *testing.T) { + t.Parallel() + buf := bf.Get() + _ = buf.Cap() + bf.MustPut(buf) + if buf.Cap() != 0 { + t.Fatalf("The capacity is not zero") + } + }) + t.Run("BufferReset", func(t *testing.T) { + t.Parallel() + buf := bf.Get() + buf.MustWrite([]byte("hello")) + err := buf.Reset() + if err != nil { + t.Fatal(err) + } + if buf.Len() != 0 { + t.Fatalf("The length is not zero") + } + bf.MustPut(buf) + }) + t.Run("BufferMustReset", func(t *testing.T) { + t.Parallel() + buf := bf.Get() + buf.MustWrite([]byte("hello")) + buf.MustReset() + if buf.Len() != 0 { + t.Fatalf("The length is not zero") + } + bf.MustPut(buf) + assertPanic(t, func() { + buf.MustReset() + }) + }) + t.Run("BufferWrite", func(t *testing.T) { + t.Parallel() + buf := bf.Get() + _, err := buf.Write([]byte("hello")) + if err != nil { + t.Fatal(err) + } + if buf.Len() != 5 { + t.Fatalf("The length is not five") + } + bf.MustPut(buf) + written, werr := buf.Write([]byte("hello")) + if written != 0 { + t.Fatalf("The written is not zero") + } + if werr == nil { + t.Fatalf("The error is nil") + } + }) + t.Run("BufferMustWrite", func(t *testing.T) { + t.Parallel() + buf := bf.Get() + buf.MustWrite([]byte("hello")) + if buf.Len() != 5 { + t.Fatalf("The length is not five") + } + bf.MustPut(buf) + assertPanic(t, func() { + buf.MustWrite([]byte("hello")) + }) + }) + t.Run("BufferWriteByte", func(t *testing.T) { + t.Parallel() + buf := bf.Get() + err := buf.WriteByte('h') + if err != nil { + t.Fatal(err) + } + if buf.Len() != 1 { + t.Fatalf("The length is not one") + } + bf.MustPut(buf) + werr := buf.WriteByte('h') + if werr == nil { + t.Fatalf("The error is nil") + } + }) + t.Run("BufferMustWriteByte", func(t *testing.T) { + t.Parallel() + buf := bf.Get() + buf.MustWriteByte('h') + if buf.Len() != 1 { + t.Fatalf("The length is not one") + } + bf.MustPut(buf) + assertPanic(t, func() { + buf.MustWriteByte('h') + }) + }) + t.Run("BufferWriteRune", func(t *testing.T) { + t.Parallel() + buf := bf.Get() + _, err := buf.WriteRune('h') + if err != nil { + t.Fatal(err) + } + if buf.Len() != 1 { + t.Fatalf("The length is not one") + } + bf.MustPut(buf) + written, werr := buf.WriteRune('h') + if written != 0 { + t.Fatalf("The written is not zero") + } + if werr == nil { + t.Fatalf("The error is nil") + } + }) + t.Run("BufferMustWriteRune", func(t *testing.T) { + t.Parallel() + buf := bf.Get() + buf.MustWriteRune('h') + if buf.Len() != 1 { + t.Fatalf("The length is not one") + } + bf.MustPut(buf) + assertPanic(t, func() { + buf.MustWriteRune('h') + }) + }) + t.Run("BufferWriteString", func(t *testing.T) { + t.Parallel() + buf := bf.Get() + _, err := buf.WriteString("hello") + if err != nil { + t.Fatal(err) + } + if buf.Len() != 5 { + t.Fatalf("The length is not five") + } + bf.MustPut(buf) + written, werr := buf.WriteString("hello") + if written != 0 { + t.Fatalf("The written is not zero") + } + if werr == nil { + t.Fatalf("The error is nil") + } + }) + t.Run("BufferGrow", func(t *testing.T) { + t.Parallel() + buf := bf.Get() + err := buf.Grow(5) + if buf.Cap() < 5 { + t.Fatalf("The capacity is less than five: %d", buf.Cap()) + } + if err != nil { + t.Fatal(err) + } + bf.MustPut(buf) + if buf.Cap() != 0 { + t.Fatalf("The capacity is not zero") + } + if err = buf.Grow(1); err == nil { + t.Fatal("The error is nil") + } + }) + t.Run("BufferTruncate", func(t *testing.T) { + t.Parallel() + buf := bf.Get() + buf.MustWrite([]byte("hello")) + err := buf.Truncate(3) + if err != nil { + t.Fatal(err) + } + if buf.Len() != 3 { + t.Fatalf("The length is not three") + } + if buf.String() != "hel" { + t.Fatalf("The string is not hel") + } + bf.MustPut(buf) + }) + t.Run("BufferMustTruncate", func(t *testing.T) { + t.Parallel() + buf := bf.Get() + buf.MustWrite([]byte("hello")) + buf.MustTruncate(3) + if buf.Len() != 3 { + t.Fatalf("The length is not three") + } + if buf.String() != "hel" { + t.Fatalf("The string is not hel") + } + bf.MustPut(buf) + assertPanic(t, func() { + buf.MustTruncate(3) + }) + }) + t.Run("BufferRead", func(t *testing.T) { + t.Parallel() + buf := bf.Get() + buf.MustWrite([]byte("hello")) + p := make([]byte, 5) + n, err := buf.Read(p) + if err != nil { + t.Fatal(err) + } + if n != 5 { + t.Fatalf("The n is not five") + } + if string(p) != "hello" { + t.Fatalf("The string is not hello") + } + bf.MustPut(buf) + if _, err = buf.Read(p); err == nil { + t.Fatal("The error is nil after returning the buffer") + } + }) + t.Run("BufferReadByte", func(t *testing.T) { + t.Parallel() + buf := bf.Get() + buf.MustWrite([]byte("hello")) + b, err := buf.ReadByte() + if err != nil { + t.Fatal(err) + } + if b != 'h' { + t.Fatalf("The byte is not h") + } + bf.MustPut(buf) + if _, err = buf.ReadByte(); err == nil { + t.Fatal("The error is nil after returning the buffer") + } + }) + t.Run("BufferReadRune", func(t *testing.T) { + t.Parallel() + buf := bf.Get() + buf.MustWrite([]byte("hello")) + r, size, err := buf.ReadRune() + if err != nil { + t.Fatal(err) + } + if r != 'h' { + t.Fatalf("The rune is not h") + } + if size != 1 { + t.Fatalf("The size is not one") + } + bf.MustPut(buf) + if _, _, err = buf.ReadRune(); err == nil { + t.Fatal("The error is nil after returning the buffer") + } + }) + t.Run("BufferUnreadByte", func(t *testing.T) { + t.Parallel() + buf := bf.Get() + buf.MustWrite([]byte("hello")) + b, err := buf.ReadByte() + if err != nil { + t.Fatal(err) + } + if b != 'h' { + t.Fatalf("The byte is not h") + } + err = buf.UnreadByte() + if err != nil { + t.Fatal(err) + } + b, err = buf.ReadByte() + if err != nil { + t.Fatal(err) + } + if b != 'h' { + t.Fatalf("The byte is not h") + } + bf.MustPut(buf) + if err = buf.UnreadByte(); err == nil { + t.Fatal("The error is nil after returning the buffer") + } + }) + t.Run("BufferUnreadRune", func(t *testing.T) { + t.Parallel() + buf := bf.Get() + buf.MustWrite([]byte("hello")) + r, size, err := buf.ReadRune() + if err != nil { + t.Fatal(err) + } + if r != 'h' { + t.Fatalf("The rune is not h") + } + if size != 1 { + t.Fatalf("The size is not one") + } + err = buf.UnreadRune() + if err != nil { + t.Fatal(err) + } + r, size, err = buf.ReadRune() + if err != nil { + t.Fatal(err) + } + if r != 'h' { + t.Fatalf("The rune is not h") + } + if size != 1 { + t.Fatalf("The size is not one") + } + bf.MustPut(buf) + if err = buf.UnreadRune(); err == nil { + t.Fatal("The error is nil after returning the buffer") + } + }) + t.Run("BufferReadBytes", func(t *testing.T) { + t.Parallel() + buf := bf.Get() + buf.MustWrite([]byte("hello!")) + p, err := buf.ReadBytes('o') + if err != nil { + t.Fatal(err) + } + if string(p) != "hello" { + t.Fatalf("The string is not hello: %v", string(p)) + } + bf.MustPut(buf) + if _, err = buf.ReadBytes('l'); err == nil { + t.Fatal("The error is nil after returning the buffer") + } + }) + t.Run("BufferReadFrom", func(t *testing.T) { + t.Parallel() + buf := bf.Get() + _, err := buf.ReadFrom(strings.NewReader("hello")) + if err != nil { + t.Fatal(err) + } + if buf.Len() != 5 { + t.Fatalf("The length is not five") + } + bf.MustPut(buf) + if _, err = buf.ReadFrom(strings.NewReader("hello")); err == nil { + t.Fatal("The error is nil trying to use a returned buffer") + } + buf = bf.Get() + buf.MustReadFrom(strings.NewReader("hello")) + buf.MustTruncate(5) + if buf.Len() != 5 { + t.Fatalf("The length is not five") + } + bf.MustPut(buf) + assertPanic(t, func() { + buf.MustReadFrom(strings.NewReader("hello")) + }) + }) + t.Run("BufferWriteTo", func(t *testing.T) { + t.Parallel() + buf := bf.Get() + buf.MustWrite([]byte("hello")) + n, err := buf.WriteTo(io.Discard) + if err != nil { + t.Fatal(err) + } + if n != 5 { + t.Fatalf("The number of bytes is not five: %d", n) + } + bf.MustPut(buf) + if _, err = buf.WriteTo(io.Discard); err == nil { + t.Fatal("The error is nil trying to use a returned buffer") + } + assertPanic(t, func() { + buf.MustWriteTo(io.Discard) + }) + }) + t.Run("BufferNext", func(t *testing.T) { + t.Parallel() + buf := bf.Get() + buf.MustWrite([]byte("hello")) + p := buf.Next(5) + if string(p) != "hello" { + t.Fatalf("The string is not hello") + } + bf.MustPut(buf) + if p = buf.Next(5); p != nil { + t.Fatalf("The slice is not nil") + } + }) +}