From 414ae1bb9ed152f441b372db076776b3a937ebf5 Mon Sep 17 00:00:00 2001 From: Janusz Marcinkiewicz Date: Thu, 30 Jul 2020 09:22:05 +0200 Subject: [PATCH] Add Reader/Writer constructors with custom buffer --- reader.go | 17 +++++++++++++++-- reader_test.go | 25 +++++++++++++++++++++++++ writer.go | 24 ++++++++++++++++++------ writer_test.go | 24 ++++++++++++++++++++++++ 4 files changed, 82 insertions(+), 8 deletions(-) diff --git a/reader.go b/reader.go index 75be62a..6918d3e 100644 --- a/reader.go +++ b/reader.go @@ -50,11 +50,24 @@ func NewReader(r io.Reader) *Reader { } // NewReaderSize returns a new *Reader that -// reads from 'r' and has a buffer size 'n' +// reads from 'r' and has a buffer size 'n'. func NewReaderSize(r io.Reader, n int) *Reader { + buf := make([]byte, 0, max(n, minReaderSize)) + return NewReaderBuf(r, buf) +} + +// NewReaderBuf returns a new *Reader that +// reads from 'r' and uses 'buf' as a buffer. +// 'buf' is not used when has smaller capacity than 16, +// custom buffer is allocated instead. +func NewReaderBuf(r io.Reader, buf []byte) *Reader { + if cap(buf) < minReaderSize { + buf = make([]byte, 0, minReaderSize) + } + buf = buf[:0] rd := &Reader{ r: r, - data: make([]byte, 0, max(minReaderSize, n)), + data: buf, } if s, ok := r.(io.Seeker); ok { rd.rs = s diff --git a/reader_test.go b/reader_test.go index e96303a..c9a2d0a 100644 --- a/reader_test.go +++ b/reader_test.go @@ -396,3 +396,28 @@ func TestReadFullPerf(t *testing.T) { t.Logf("called Read() on the underlying reader %d times to fill %d buffers", c.count, size/r.BufferSize()) } + +func TestReaderBufCreation(t *testing.T) { + tests := []struct { + name string + buffer []byte + size int + }{ + {name: "nil", buffer: nil, size: minReaderSize}, + {name: "empty", buffer: []byte{}, size: minReaderSize}, + {name: "allocated", buffer: make([]byte, 0, 200), size: 200}, + {name: "filled", buffer: make([]byte, 200), size: 200}, + } + + for _, test := range tests { + var b bytes.Buffer + r := NewReaderBuf(&b, test.buffer) + + if r.BufferSize() != test.size { + t.Errorf("%s: unequal buffer size (got: %d, expected: %d)", test.name, r.BufferSize(), test.size) + } + if r.Buffered() != 0 { + t.Errorf("%s: unequal buffered bytes (got: %d, expected: 0)", test.name, r.Buffered()) + } + } +} diff --git a/writer.go b/writer.go index 2dc392a..4d6ea15 100644 --- a/writer.go +++ b/writer.go @@ -29,16 +29,28 @@ func NewWriter(w io.Writer) *Writer { } } -// NewWriterSize returns a new writer -// that writes to 'w' and has a buffer -// that is 'size' bytes. -func NewWriterSize(w io.Writer, size int) *Writer { - if wr, ok := w.(*Writer); ok && cap(wr.buf) >= size { +// NewWriterSize returns a new writer that +// writes to 'w' and has a buffer size 'n'. +func NewWriterSize(w io.Writer, n int) *Writer { + if wr, ok := w.(*Writer); ok && cap(wr.buf) >= n { return wr } + buf := make([]byte, 0, max(n, minWriterSize)) + return NewWriterBuf(w, buf) +} + +// NewWriterBuf returns a new writer +// that writes to 'w' and has 'buf' as a buffer. +// 'buf' is not used when has smaller capacity than 18, +// custom buffer is allocated instead. +func NewWriterBuf(w io.Writer, buf []byte) *Writer { + if cap(buf) < minWriterSize { + buf = make([]byte, 0, minWriterSize) + } + buf = buf[:0] return &Writer{ w: w, - buf: make([]byte, 0, max(size, minWriterSize)), + buf: buf, } } diff --git a/writer_test.go b/writer_test.go index 3dcf3a5..5868d29 100644 --- a/writer_test.go +++ b/writer_test.go @@ -235,5 +235,29 @@ func TestReadFrom(t *testing.T) { if !bytes.Equal(buf.Bytes(), bts) { t.Fatal("buf.Bytes() and input are not equal") } +} + +func TestWriterBufCreation(t *testing.T) { + tests := []struct { + name string + buffer []byte + size int + }{ + {name: "nil", buffer: nil, size: minWriterSize}, + {name: "empty", buffer: []byte{}, size: minWriterSize}, + {name: "allocated", buffer: make([]byte, 0, 200), size: 200}, + {name: "filled", buffer: make([]byte, 200), size: 200}, + } + + for _, test := range tests { + var b bytes.Buffer + w := NewWriterBuf(&b, test.buffer) + if w.BufferSize() != test.size { + t.Errorf("%s: unequal buffer size (got: %d, expected: %d)", test.name, w.BufferSize(), test.size) + } + if w.Buffered() != 0 { + t.Errorf("%s: unequal buffered bytes (got: %d, expected: 0)", test.name, w.Buffered()) + } + } }