diff --git a/unixfs/io/dagreader_test.go b/unixfs/io/dagreader_test.go index e3d3d042b23..7cbe35bb5f4 100644 --- a/unixfs/io/dagreader_test.go +++ b/unixfs/io/dagreader_test.go @@ -122,6 +122,41 @@ func TestSeekAndReadLarge(t *testing.T) { } } +func TestReadAndCancel(t *testing.T) { + dserv := testu.GetDAGServ() + inbuf := make([]byte, 20000) + rand.Read(inbuf) + + node := testu.GetNode(t, dserv, inbuf, testu.UseProtoBufLeaves) + ctx, closer := context.WithCancel(context.Background()) + defer closer() + + reader, err := NewDagReader(ctx, node, dserv) + if err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithCancel(context.Background()) + buf := make([]byte, 100) + _, err = reader.CtxReadFull(ctx, buf) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(buf, inbuf[0:100]) { + t.Fatal("read failed") + } + cancel() + + b, err := ioutil.ReadAll(reader) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(inbuf[100:], b) { + t.Fatal("buffers not equal") + } +} + func TestRelativeSeek(t *testing.T) { dserv := testu.GetDAGServ() ctx, closer := context.WithCancel(context.Background()) diff --git a/unixfs/io/pbdagreader.go b/unixfs/io/pbdagreader.go index ce53d67118d..42d903aac26 100644 --- a/unixfs/io/pbdagreader.go +++ b/unixfs/io/pbdagreader.go @@ -68,16 +68,13 @@ func NewPBFileReader(ctx context.Context, n *mdag.ProtoNode, pb *ftpb.Data, serv const preloadSize = 10 -func (dr *PBDagReader) preloadNextNodes(ctx context.Context) { - beg := dr.linkPosition +func (dr *PBDagReader) preload(ctx context.Context, beg int) { end := beg + preloadSize if end >= len(dr.links) { end = len(dr.links) } - for i, p := range ipld.GetNodes(ctx, dr.serv, dr.links[beg:end]) { - dr.promises[beg+i] = p - } + copy(dr.promises[beg:], ipld.GetNodes(ctx, dr.serv, dr.links[beg:end])) } // precalcNextBuf follows the next link in line and loads it from the @@ -92,15 +89,42 @@ func (dr *PBDagReader) precalcNextBuf(ctx context.Context) error { return io.EOF } - if dr.promises[dr.linkPosition] == nil { - dr.preloadNextNodes(ctx) + // If we drop to <= preloadSize/2 preloading nodes, preload the next 10. + for i := dr.linkPosition; i < dr.linkPosition+preloadSize/2 && i < len(dr.promises); i++ { + // TODO: check if canceled. + if dr.promises[i] == nil { + dr.preload(ctx, i) + break + } } nxt, err := dr.promises[dr.linkPosition].Get(ctx) - if err != nil { + dr.promises[dr.linkPosition] = nil + switch err { + case nil: + case context.DeadlineExceeded, context.Canceled: + err = ctx.Err() + if err != nil { + return ctx.Err() + } + // In this case, the context used to *preload* the node has been canceled. + // We need to retry the load with our context and we might as + // well preload some extra nodes while we're at it. + // + // Note: When using `Read`, this code will never execute as + // `Read` will use the global context. It only runs if the user + // explicitly reads with a custom context (e.g., by calling + // `CtxReadFull`). + dr.preload(ctx, dr.linkPosition) + nxt, err = dr.promises[dr.linkPosition].Get(ctx) + dr.promises[dr.linkPosition] = nil + if err != nil { + return err + } + default: return err } - dr.promises[dr.linkPosition] = nil + dr.linkPosition++ switch nxt := nxt.(type) {