Skip to content

Commit

Permalink
Add Reader/Writer constructors with custom buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
VirrageS authored and philhofer committed Jul 30, 2020
1 parent bb6d471 commit 414ae1b
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 8 deletions.
17 changes: 15 additions & 2 deletions reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,24 @@ func NewReader(r io.Reader) *Reader {
}

// NewReaderSize returns a new *Reader that
// reads from 'r' and has a buffer size 'n'
// reads from 'r' and has a buffer size 'n'.
func NewReaderSize(r io.Reader, n int) *Reader {
buf := make([]byte, 0, max(n, minReaderSize))
return NewReaderBuf(r, buf)
}

// NewReaderBuf returns a new *Reader that
// reads from 'r' and uses 'buf' as a buffer.
// 'buf' is not used when has smaller capacity than 16,
// custom buffer is allocated instead.
func NewReaderBuf(r io.Reader, buf []byte) *Reader {
if cap(buf) < minReaderSize {
buf = make([]byte, 0, minReaderSize)
}
buf = buf[:0]
rd := &Reader{
r: r,
data: make([]byte, 0, max(minReaderSize, n)),
data: buf,
}
if s, ok := r.(io.Seeker); ok {
rd.rs = s
Expand Down
25 changes: 25 additions & 0 deletions reader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -396,3 +396,28 @@ func TestReadFullPerf(t *testing.T) {

t.Logf("called Read() on the underlying reader %d times to fill %d buffers", c.count, size/r.BufferSize())
}

func TestReaderBufCreation(t *testing.T) {
tests := []struct {
name string
buffer []byte
size int
}{
{name: "nil", buffer: nil, size: minReaderSize},
{name: "empty", buffer: []byte{}, size: minReaderSize},
{name: "allocated", buffer: make([]byte, 0, 200), size: 200},
{name: "filled", buffer: make([]byte, 200), size: 200},
}

for _, test := range tests {
var b bytes.Buffer
r := NewReaderBuf(&b, test.buffer)

if r.BufferSize() != test.size {
t.Errorf("%s: unequal buffer size (got: %d, expected: %d)", test.name, r.BufferSize(), test.size)
}
if r.Buffered() != 0 {
t.Errorf("%s: unequal buffered bytes (got: %d, expected: 0)", test.name, r.Buffered())
}
}
}
24 changes: 18 additions & 6 deletions writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,28 @@ func NewWriter(w io.Writer) *Writer {
}
}

// NewWriterSize returns a new writer
// that writes to 'w' and has a buffer
// that is 'size' bytes.
func NewWriterSize(w io.Writer, size int) *Writer {
if wr, ok := w.(*Writer); ok && cap(wr.buf) >= size {
// NewWriterSize returns a new writer that
// writes to 'w' and has a buffer size 'n'.
func NewWriterSize(w io.Writer, n int) *Writer {
if wr, ok := w.(*Writer); ok && cap(wr.buf) >= n {
return wr
}
buf := make([]byte, 0, max(n, minWriterSize))
return NewWriterBuf(w, buf)
}

// NewWriterBuf returns a new writer
// that writes to 'w' and has 'buf' as a buffer.
// 'buf' is not used when has smaller capacity than 18,
// custom buffer is allocated instead.
func NewWriterBuf(w io.Writer, buf []byte) *Writer {
if cap(buf) < minWriterSize {
buf = make([]byte, 0, minWriterSize)
}
buf = buf[:0]
return &Writer{
w: w,
buf: make([]byte, 0, max(size, minWriterSize)),
buf: buf,
}
}

Expand Down
24 changes: 24 additions & 0 deletions writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,5 +235,29 @@ func TestReadFrom(t *testing.T) {
if !bytes.Equal(buf.Bytes(), bts) {
t.Fatal("buf.Bytes() and input are not equal")
}
}

func TestWriterBufCreation(t *testing.T) {
tests := []struct {
name string
buffer []byte
size int
}{
{name: "nil", buffer: nil, size: minWriterSize},
{name: "empty", buffer: []byte{}, size: minWriterSize},
{name: "allocated", buffer: make([]byte, 0, 200), size: 200},
{name: "filled", buffer: make([]byte, 200), size: 200},
}

for _, test := range tests {
var b bytes.Buffer
w := NewWriterBuf(&b, test.buffer)

if w.BufferSize() != test.size {
t.Errorf("%s: unequal buffer size (got: %d, expected: %d)", test.name, w.BufferSize(), test.size)
}
if w.Buffered() != 0 {
t.Errorf("%s: unequal buffered bytes (got: %d, expected: 0)", test.name, w.Buffered())
}
}
}

0 comments on commit 414ae1b

Please sign in to comment.