Skip to content

Commit

Permalink
reader: fix Skip(n) where n >= buffer size
Browse files Browse the repository at this point in the history
The (*Reader).Skip() code did not handle skip sizes
above the buffer size correctly, because the inner
reader loop only looped while 'r.buffered() < n', and
that condition is obviously violated when the reader
successfully buffers more bytes than are to be skipped.

Refactor this method to be a little shorter and hopefully
much clearer (and more correct) and add test coverage that
exercises this code better.
  • Loading branch information
philhofer committed Nov 24, 2020
1 parent 414ae1b commit 9bcb9ca
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 50 deletions.
75 changes: 31 additions & 44 deletions reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@
//
package fwd

import "io"
import (
"io"
"os"
)

const (
// DefaultReaderSize is the default size of the read buffer
Expand Down Expand Up @@ -187,6 +190,19 @@ func (r *Reader) Peek(n int) ([]byte, error) {
return r.data[r.n : r.n+n], nil
}

// discard(n) discards up to 'n' buffered bytes, and
// and returns the number of bytes discarded
func (r *Reader) discard(n int) int {
inbuf := r.buffered()
if inbuf <= n {
r.n = 0
r.data = r.data[:0]
return inbuf
}
r.n += n
return n
}

// Skip moves the reader forward 'n' bytes.
// Returns the number of bytes skipped and any
// errors encountered. It is analogous to Seek(n, 1).
Expand All @@ -201,33 +217,25 @@ func (r *Reader) Peek(n int) ([]byte, error) {
// will not return `io.EOF` until the next call
// to Read.)
func (r *Reader) Skip(n int) (int, error) {

// fast path
if r.buffered() >= n {
r.n += n
return n, nil
if n < 0 {
return 0, os.ErrInvalid
}

// use seeker implementation
// if we can
if r.rs != nil {
return r.skipSeek(n)
}
// discard some or all of the current buffer
skipped := r.discard(n)

// loop on filling
// and then erasing
o := n
for r.buffered() < n && r.state == nil {
// if we can Seek() through the remaining bytes, do that
if n > skipped && r.rs != nil {
nn, err := r.rs.Seek(int64(n-skipped), 1)
return int(nn) + skipped, err
}
// otherwise, keep filling the buffer
// and discarding it up to 'n'
for skipped < n && r.state == nil {
r.more()
// we can skip forward
// up to r.buffered() bytes
step := min(r.buffered(), n)
r.n += step
n -= step
skipped += r.discard(n - skipped)
}
// at this point, n should be
// 0 if everything went smoothly
return o - n, r.noEOF()
return skipped, r.noEOF()
}

// Next returns the next 'n' bytes in the stream.
Expand Down Expand Up @@ -262,20 +270,6 @@ func (r *Reader) Next(n int) ([]byte, error) {
return out, nil
}

// skipSeek uses the io.Seeker to seek forward.
// only call this function when n > r.buffered()
func (r *Reader) skipSeek(n int) (int, error) {
o := r.buffered()
// first, clear buffer
n -= o
r.n = 0
r.data = r.data[:0]

// then seek forward remaning bytes
i, err := r.rs.Seek(int64(n), 1)
return int(i) + o, err
}

// Read implements `io.Reader`
func (r *Reader) Read(b []byte) (int, error) {
// if we have data in the buffer, just
Expand Down Expand Up @@ -381,13 +375,6 @@ func (r *Reader) WriteTo(w io.Writer) (int64, error) {
return i, nil
}

func min(a int, b int) int {
if a < b {
return a
}
return b
}

func max(a int, b int) int {
if a < b {
return b
Expand Down
42 changes: 36 additions & 6 deletions reader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ func TestReadByte(t *testing.T) {
}
}

func remaining(r *Reader) int {
return r.Buffered() + r.r.(partialReader).r.(*bytes.Reader).Len()
}

func TestSkipNoSeek(t *testing.T) {
bts := randomBts(1024)
rd := NewReaderSize(partialReader{bytes.NewReader(bts)}, 200)
Expand All @@ -145,14 +149,18 @@ func TestSkipNoSeek(t *testing.T) {
t.Fatalf("Skip() returned a nil error, but skipped %d bytes instead of %d", n, 512)
}

if remaining(rd) != 512 {
t.Errorf("expected 512 remaining; got %d", remaining(rd))
}

var b byte
b, err = rd.ReadByte()
if err != nil {
t.Fatal(err)
}

if b != bts[512] {
t.Fatalf("at index %d: %d in; %d out", 512, bts[512], b)
t.Errorf("at index %d: %d in; %d out", 512, bts[512], b)
}

n, err = rd.Skip(10)
Expand All @@ -162,16 +170,38 @@ func TestSkipNoSeek(t *testing.T) {
if n != 10 {
t.Fatalf("Skip() returned a nil error, but skipped %d bytes instead of %d", n, 10)
}
// the number of bytes remaining in the buffer needs
// to comport with the number of bytes we expect to have skipped
if want := 1024 - 512 - 10 - 1; remaining(rd) != want {
t.Errorf("only %d bytes remaining (want %d)?", remaining(rd), want)
}
n, err = rd.Skip(10)
if err != nil {
t.Fatal(err)
}
if n != 10 {
t.Fatalf("Skip(10) a second time returned %d", n)
}
if want := 1024 - 512 - 10 - 10 - 1; remaining(rd) != want {
t.Errorf("only %d bytes remaining (want %d)?", remaining(rd), want)
}
b, err = rd.ReadByte()
if err != nil {
t.Fatalf("second ReadByte(): %s", err)
}
if b != bts[512+10+10+1] {
t.Errorf("expected %d but got %d", bts[512+10+10], b)
}

// now try to skip past the end
rd = NewReaderSize(partialReader{bytes.NewReader(bts)}, 200)

// now try to skip past the end; we expect
// only to skip the number of bytes remaining
want := remaining(rd)
n, err = rd.Skip(2000)
if err != io.ErrUnexpectedEOF {
t.Fatalf("expected error %q; got %q", io.EOF, err)
}
if n != 1024 {
t.Fatalf("expected to skip only 1024 bytes; skipped %d", n)
if n != want {
t.Fatalf("expected to skip only %d bytes; skipped %d", want, n)
}
}

Expand Down

0 comments on commit 9bcb9ca

Please sign in to comment.