Skip to content

Commit

Permalink
Add cancel request and wait function (#185)
Browse files Browse the repository at this point in the history
  • Loading branch information
hannahhoward authored Aug 4, 2021
1 parent 78ef068 commit a81371b
Show file tree
Hide file tree
Showing 9 changed files with 154 additions and 32 deletions.
12 changes: 8 additions & 4 deletions graphsync.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,12 @@ const (
RequestCancelled = ResponseStatusCode(35)
)

// RequestContextCancelledErr is an error message received on the error channel when the request context given by the user is cancelled/times out
type RequestContextCancelledErr struct{}
// RequestClientCancelledErr is an error message received on the error channel when the request is cancelled on by the client code,
// either by closing the passed request context or calling CancelRequest
type RequestClientCancelledErr struct{}

func (e RequestContextCancelledErr) Error() string {
return "Request Context Cancelled"
func (e RequestClientCancelledErr) Error() string {
return "Request Cancelled By Client"
}

// RequestFailedBusyErr is an error message received on the error channel when the peer is busy
Expand Down Expand Up @@ -369,4 +370,7 @@ type GraphExchange interface {

// CancelResponse cancels an in progress response
CancelResponse(peer.ID, RequestID) error

// CancelRequest cancels an in progress request
CancelRequest(context.Context, RequestID) error
}
7 changes: 6 additions & 1 deletion impl/graphsync.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ func (gs *GraphSync) RegisterIncomingRequestHook(hook graphsync.OnIncomingReques
return gs.incomingRequestHooks.Register(hook)
}

// RegisterIncomingRequestHook adds a hook that runs when a new incoming request is added
// RegisterIncomingRequestQueuedHook adds a hook that runs when a new incoming request is added
// to the responder's task queue.
func (gs *GraphSync) RegisterIncomingRequestQueuedHook(hook graphsync.OnIncomingRequestQueuedHook) graphsync.UnregisterHookFunc {
return gs.incomingRequestQueuedHooks.Register(hook)
Expand Down Expand Up @@ -296,6 +296,11 @@ func (gs *GraphSync) CancelResponse(p peer.ID, requestID graphsync.RequestID) er
return gs.responseManager.CancelResponse(p, requestID)
}

// CancelRequest cancels an in progress request
func (gs *GraphSync) CancelRequest(ctx context.Context, requestID graphsync.RequestID) error {
return gs.requestManager.CancelRequest(ctx, requestID)
}

type graphSyncReceiver GraphSync

func (gsr *graphSyncReceiver) graphSync() *GraphSync {
Expand Down
4 changes: 2 additions & 2 deletions impl/graphsync_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,7 @@ func TestNetworkDisconnect(t *testing.T) {

testutil.AssertReceive(ctx, t, networkError, &err, "should receive network error")
testutil.AssertReceive(ctx, t, errChan, &err, "should receive an error")
require.EqualError(t, err, graphsync.RequestContextCancelledErr{}.Error())
require.EqualError(t, err, graphsync.RequestClientCancelledErr{}.Error())
testutil.AssertReceive(ctx, t, receiverError, &err, "should receive an error on receiver side")
}

Expand Down Expand Up @@ -653,7 +653,7 @@ func TestConnectFail(t *testing.T) {
var err error
testutil.AssertReceive(ctx, t, reqNetworkError, &err, "should receive network error")
testutil.AssertReceive(ctx, t, errChan, &err, "should receive an error")
require.EqualError(t, err, graphsync.RequestContextCancelledErr{}.Error())
require.EqualError(t, err, graphsync.RequestClientCancelledErr{}.Error())
}

func TestGraphsyncRoundTripAlternatePersistenceAndNodes(t *testing.T) {
Expand Down
7 changes: 7 additions & 0 deletions ipldutil/traverser.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ import (
"github.com/ipld/go-ipld-prime/traversal/selector"
)

/* TODO: This traverser creates an extra go-routine and is quite complicated, in order to give calling code control of
a selector traversal. If it were implemented inside of go-ipld-primes traversal library, with access to private functions,
it could be done without an extra go-routine, avoiding the possibility of races and simplifying implementation. This has
been documented here: https://github.com/ipld/go-ipld-prime/issues/213 -- and when this issue is implemented, this traverser
can go away */

var defaultVisitor traversal.AdvVisitFn = func(traversal.Progress, ipld.Node, traversal.VisitReason) error { return nil }

// ContextCancelError is a sentinel that indicates the passed in context
Expand Down Expand Up @@ -137,6 +143,7 @@ func (t *traverser) writeDone(err error) {
func (t *traverser) start() {
select {
case <-t.ctx.Done():
close(t.stopped)
return
case t.awaitRequest <- struct{}{}:
}
Expand Down
17 changes: 17 additions & 0 deletions ipldutil/traverser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"testing"
"time"

blocks "github.com/ipfs/go-block-format"
ipld "github.com/ipld/go-ipld-prime"
Expand All @@ -21,6 +22,22 @@ import (
func TestTraverser(t *testing.T) {
ctx := context.Background()

t.Run("started with shutdown context, then shutdown", func(t *testing.T) {
cancelledCtx, cancel := context.WithCancel(ctx)
cancel()
testdata := testutil.NewTestIPLDTree()
ssb := builder.NewSelectorSpecBuilder(basicnode.Prototype.Any)
sel := ssb.ExploreRecursive(selector.RecursionLimitNone(), ssb.ExploreAll(ssb.ExploreRecursiveEdge())).Node()
traverser := TraversalBuilder{
Root: testdata.RootNodeLnk,
Selector: sel,
}.Start(cancelledCtx)
timeoutCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
defer cancel()
traverser.Shutdown(timeoutCtx)
require.NoError(t, timeoutCtx.Err())
})

t.Run("traverses correctly, simple struct", func(t *testing.T) {
testdata := testutil.NewTestIPLDTree()
ssb := builder.NewSelectorSpecBuilder(basicnode.Prototype.Any)
Expand Down
10 changes: 5 additions & 5 deletions requestmanager/executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ type ExecutionEnv struct {
type RequestExecution struct {
Ctx context.Context
P peer.ID
NetworkError chan error
TerminalError chan error
Request gsmsg.GraphSyncRequest
LastResponse *atomic.Value
DoNotSendCids *cid.Set
Expand All @@ -54,7 +54,7 @@ func (ee ExecutionEnv) Start(re RequestExecution) (chan graphsync.ResponseProgre
inProgressErr: make(chan error),
ctx: re.Ctx,
p: re.P,
networkError: re.NetworkError,
terminalError: re.TerminalError,
request: re.Request,
lastResponse: re.LastResponse,
doNotSendCids: re.DoNotSendCids,
Expand All @@ -73,7 +73,7 @@ type requestExecutor struct {
inProgressErr chan error
ctx context.Context
p peer.ID
networkError chan error
terminalError chan error
request gsmsg.GraphSyncRequest
lastResponse *atomic.Value
nodeStyleChooser traversal.LinkTargetNodePrototypeChooser
Expand Down Expand Up @@ -153,9 +153,9 @@ func (re *requestExecutor) run() {
}
}
select {
case networkError := <-re.networkError:
case terminalError := <-re.terminalError:
select {
case re.inProgressErr <- networkError:
case re.inProgressErr <- terminalError:
case <-re.env.Ctx.Done():
}
default:
Expand Down
67 changes: 52 additions & 15 deletions requestmanager/requestmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,12 @@ type inProgressRequestStatus struct {
startTime time.Time
cancelFn func()
p peer.ID
networkError chan error
terminalError chan error
resumeMessages chan []graphsync.ExtensionData
pauseMessages chan struct{}
paused bool
lastResponse atomic.Value
onTerminated []chan error
}

// PeerHandler is an interface that can send requests to peers
Expand Down Expand Up @@ -234,8 +235,10 @@ func (rm *RequestManager) singleErrorResponse(err error) (chan graphsync.Respons
}

type cancelRequestMessage struct {
requestID graphsync.RequestID
isPause bool
requestID graphsync.RequestID
isPause bool
onTerminated chan error
terminalError error
}

func (rm *RequestManager) cancelRequest(requestID graphsync.RequestID,
Expand All @@ -244,7 +247,7 @@ func (rm *RequestManager) cancelRequest(requestID graphsync.RequestID,
cancelMessageChannel := rm.messages
for cancelMessageChannel != nil || incomingResponses != nil || incomingErrors != nil {
select {
case cancelMessageChannel <- &cancelRequestMessage{requestID, false}:
case cancelMessageChannel <- &cancelRequestMessage{requestID, false, nil, nil}:
cancelMessageChannel = nil
// clear out any remaining responses, in case and "incoming reponse"
// messages get processed before our cancel message
Expand All @@ -262,6 +265,12 @@ func (rm *RequestManager) cancelRequest(requestID graphsync.RequestID,
}
}

// CancelRequest cancels the given request ID and waits for the request to terminate
func (rm *RequestManager) CancelRequest(ctx context.Context, requestID graphsync.RequestID) error {
terminated := make(chan error, 1)
return rm.sendSyncMessage(&cancelRequestMessage{requestID, false, terminated, graphsync.RequestClientCancelledErr{}}, terminated, ctx.Done())
}

type processResponseMessage struct {
p peer.ID
responses []gsmsg.GraphSyncResponse
Expand All @@ -288,7 +297,7 @@ type unpauseRequestMessage struct {
// Can also send extensions with unpause
func (rm *RequestManager) UnpauseRequest(requestID graphsync.RequestID, extensions ...graphsync.ExtensionData) error {
response := make(chan error, 1)
return rm.sendSyncMessage(&unpauseRequestMessage{requestID, extensions, response}, response)
return rm.sendSyncMessage(&unpauseRequestMessage{requestID, extensions, response}, response, nil)
}

type pauseRequestMessage struct {
Expand All @@ -299,18 +308,22 @@ type pauseRequestMessage struct {
// PauseRequest pauses an in progress request (may take 1 or more blocks to process)
func (rm *RequestManager) PauseRequest(requestID graphsync.RequestID) error {
response := make(chan error, 1)
return rm.sendSyncMessage(&pauseRequestMessage{requestID, response}, response)
return rm.sendSyncMessage(&pauseRequestMessage{requestID, response}, response, nil)
}

func (rm *RequestManager) sendSyncMessage(message requestManagerMessage, response chan error) error {
func (rm *RequestManager) sendSyncMessage(message requestManagerMessage, response chan error, done <-chan struct{}) error {
select {
case <-rm.ctx.Done():
return errors.New("Context Cancelled")
case <-done:
return errors.New("Context Cancelled")
case rm.messages <- message:
}
select {
case <-rm.ctx.Done():
return errors.New("Context Cancelled")
case <-done:
return errors.New("Context Cancelled")
case err := <-response:
return err
}
Expand Down Expand Up @@ -374,9 +387,9 @@ func (nrm *newRequestMessage) setupRequest(requestID graphsync.RequestID, rm *Re
p := nrm.p
resumeMessages := make(chan []graphsync.ExtensionData, 1)
pauseMessages := make(chan struct{}, 1)
networkError := make(chan error, 1)
terminalError := make(chan error, 1)
requestStatus := &inProgressRequestStatus{
ctx: ctx, startTime: time.Now(), cancelFn: cancel, p: p, resumeMessages: resumeMessages, pauseMessages: pauseMessages, networkError: networkError,
ctx: ctx, startTime: time.Now(), cancelFn: cancel, p: p, resumeMessages: resumeMessages, pauseMessages: pauseMessages, terminalError: terminalError,
}
lastResponse := &requestStatus.lastResponse
lastResponse.Store(gsmsg.NewResponse(request.ID(), graphsync.RequestAcknowledged))
Expand All @@ -392,7 +405,7 @@ func (nrm *newRequestMessage) setupRequest(requestID graphsync.RequestID, rm *Re
Ctx: ctx,
P: p,
Request: request,
NetworkError: networkError,
TerminalError: terminalError,
LastResponse: lastResponse,
DoNotSendCids: doNotSendCids,
NodePrototypeChooser: hooksResult.CustomChooser,
Expand Down Expand Up @@ -421,14 +434,38 @@ func (trm *terminateRequestMessage) handle(rm *RequestManager) {
}
delete(rm.inProgressRequestStatuses, trm.requestID)
rm.asyncLoader.CleanupRequest(trm.requestID)
if ok {
for _, onTerminated := range ipr.onTerminated {
select {
case <-rm.ctx.Done():
case onTerminated <- nil:
}
}
}
}

func (crm *cancelRequestMessage) handle(rm *RequestManager) {
inProgressRequestStatus, ok := rm.inProgressRequestStatuses[crm.requestID]
if !ok {
if crm.onTerminated != nil {
select {
case crm.onTerminated <- errors.New("request not found"):
case <-rm.ctx.Done():
}
}
return
}

if crm.onTerminated != nil {
inProgressRequestStatus.onTerminated = append(inProgressRequestStatus.onTerminated, crm.onTerminated)
}
if crm.terminalError != nil {
select {
case inProgressRequestStatus.terminalError <- crm.terminalError:
default:
}
}

rm.sendRequest(inProgressRequestStatus.p, gsmsg.CancelRequest(crm.requestID))
if crm.isPause {
inProgressRequestStatus.paused = true
Expand Down Expand Up @@ -488,8 +525,8 @@ func (rm *RequestManager) processExtensionsForResponse(p peer.ID, response gsmsg
}
responseError := rm.generateResponseErrorFromStatus(graphsync.RequestFailedUnknown)
select {
case requestStatus.networkError <- responseError:
case <-requestStatus.ctx.Done():
case requestStatus.terminalError <- responseError:
default:
}
rm.sendRequest(p, gsmsg.CancelRequest(response.RequestID()))
requestStatus.cancelFn()
Expand All @@ -505,8 +542,8 @@ func (rm *RequestManager) processTerminations(responses []gsmsg.GraphSyncRespons
requestStatus := rm.inProgressRequestStatuses[response.RequestID()]
responseError := rm.generateResponseErrorFromStatus(response.Status())
select {
case requestStatus.networkError <- responseError:
case <-requestStatus.ctx.Done():
case requestStatus.terminalError <- responseError:
default:
}
requestStatus.cancelFn()
}
Expand Down Expand Up @@ -542,7 +579,7 @@ func (rm *RequestManager) processBlockHooks(p peer.ID, response graphsync.Respon
_, isPause := result.Err.(hooks.ErrPaused)
select {
case <-rm.ctx.Done():
case rm.messages <- &cancelRequestMessage{response.RequestID(), isPause}:
case rm.messages <- &cancelRequestMessage{response.RequestID(), isPause, nil, nil}:
}
}
return result.Err
Expand Down
Loading

0 comments on commit a81371b

Please sign in to comment.