diff --git a/pool/bytes.go b/pool/bytes.go index e0554fc..23cb054 100644 --- a/pool/bytes.go +++ b/pool/bytes.go @@ -61,14 +61,16 @@ func (cf BufferFactory) Get() *Buffer { // Buffer is a wrapper around bytes.Buffer that can only be returned to a pool once. type Buffer struct { *bytes.Buffer - o *sync.Once - p *BufferFactory + o *sync.Once + co *sync.Once + p *BufferFactory } // WithParent sets the parent of the buffer. This is useful for chaining factories, and for facilitating // in-line buffer return with functions like Buffer.Close(). Be mindful, however, that this adds a bit of overhead. func (c Buffer) WithParent(p *BufferFactory) *Buffer { c.p = p + c.co = &sync.Once{} return &c } @@ -423,6 +425,18 @@ func (c Buffer) Next(n int) []byte { return c.Buffer.Next(n) } +// IsClosed returns true if the buffer has been returned to the pool. +func (c Buffer) IsClosed() bool { + var closed = true + if c.co == nil { + c.co = &sync.Once{} + } + c.co.Do(func() { + closed = false + }) + return closed +} + // Close implements io.Closer. It returns the buffer to the pool. This func (c Buffer) Close() error { if c.Buffer == nil { @@ -430,8 +444,12 @@ func (c Buffer) Close() error { } if c.p == nil { return errors.New( - "buffer does not know it's parent pool and therefore cannot return itself, use Buffer.WithParent to set the parent pool", + "buffer does not know it's parent pool and therefore cannot return itself, use Buffer.WithParent", ) } - return c.p.Put(&c) + var err = ErrBufferReturned + c.co.Do(func() { + err = c.p.Put(&c) + }) + return err } diff --git a/pool/bytes_test.go b/pool/bytes_test.go index 35e780f..6f07b0e 100644 --- a/pool/bytes_test.go +++ b/pool/bytes_test.go @@ -532,9 +532,15 @@ func TestBufferFactory(t *testing.T) { buf.MustWrite([]byte("hello")) buf.MustWrite([]byte("world")) buf.MustWrite([]byte("!")) + if buf.IsClosed() { + t.Fatalf("The buffer is closed before closing") + } if err := buf.Close(); err == nil { t.Fatal("The error is nil after closing the buffer with no parent") } + if buf.IsClosed() { + t.Fatalf("The buffer is closed after failing to close") + } if buf.String() != "helloworld!" { t.Fatalf("The string is not 'helloworld!' after unsuccessful close: %v", buf.String()) } @@ -542,6 +548,12 @@ func TestBufferFactory(t *testing.T) { if err := buf.Close(); err != nil { t.Fatal(err) } + if !buf.IsClosed() { + t.Fatalf("The buffer is not closed after closing") + } + if err := buf.Close(); err == nil { + t.Fatal("The error is nil after closing an already closed buffer") + } }) t.Run("BufferCannotClose", func(t *testing.T) { t.Parallel()