From 91d9678f9b1c1852aba406112e94fc6af830fdc9 Mon Sep 17 00:00:00 2001 From: Andrew Gaffney Date: Mon, 24 Apr 2023 20:46:43 -0500 Subject: [PATCH] feat: method to stop chainsync process Fixes #236 --- protocol/chainsync/client.go | 33 ++++++++++++++++++++++++--------- protocol/chainsync/error.go | 8 ++++++++ 2 files changed, 32 insertions(+), 9 deletions(-) diff --git a/protocol/chainsync/client.go b/protocol/chainsync/client.go index fc9890c3..d73310e9 100644 --- a/protocol/chainsync/client.go +++ b/protocol/chainsync/client.go @@ -2,10 +2,11 @@ package chainsync import ( "fmt" + "sync" + "github.com/blinklabs-io/gouroboros/ledger" "github.com/blinklabs-io/gouroboros/protocol" "github.com/blinklabs-io/gouroboros/protocol/common" - "sync" ) // Client implements the ChainSync client @@ -146,9 +147,12 @@ func (c *Client) Sync(intersectPoints []common.Point) error { func (c *Client) syncLoop() { for { // Wait for a block to be received - if _, ok := <-c.readyForNextBlockChan; !ok { + if ready, ok := <-c.readyForNextBlockChan; !ok { // Channel is closed, which means we're shutting down return + } else if !ready { + // Sync was cancelled + return } c.busyMutex.Lock() // Request the next block @@ -171,10 +175,7 @@ func (c *Client) handleRollForward(msgGeneric protocol.Message) error { if c.config.RollForwardFunc == nil { return fmt.Errorf("received chain-sync RollForward message but no callback function is defined") } - // Signal that we're ready for the next block after we finish handling this one - defer func() { - c.readyForNextBlockChan <- true - }() + var callbackErr error if c.Mode() == protocol.ProtocolModeNodeToNode { msg := msgGeneric.(*MsgRollForwardNtN) var blockHeader interface{} @@ -205,7 +206,7 @@ func (c *Client) handleRollForward(msgGeneric protocol.Message) error { } } // Call the user callback function - return c.config.RollForwardFunc(blockType, blockHeader, msg.Tip) + callbackErr = c.config.RollForwardFunc(blockType, blockHeader, msg.Tip) } else { msg := msgGeneric.(*MsgRollForwardNtC) blk, err := ledger.NewBlockFromCbor(msg.BlockType(), msg.BlockCbor()) @@ -213,8 +214,15 @@ func (c *Client) handleRollForward(msgGeneric protocol.Message) error { return err } // Call the user callback function - return c.config.RollForwardFunc(msg.BlockType(), blk, msg.Tip) + callbackErr = c.config.RollForwardFunc(msg.BlockType(), blk, msg.Tip) } + if callbackErr == StopSyncProcessError { + // Signal that we're cancelling the sync + c.readyForNextBlockChan <- false + } + // Signal that we're ready for the next block + c.readyForNextBlockChan <- true + return nil } func (c *Client) handleRollBackward(msgGeneric protocol.Message) error { @@ -223,7 +231,14 @@ func (c *Client) handleRollBackward(msgGeneric protocol.Message) error { } msg := msgGeneric.(*MsgRollBackward) // Call the user callback function - return c.config.RollBackwardFunc(msg.Point, msg.Tip) + callbackErr := c.config.RollBackwardFunc(msg.Point, msg.Tip) + if callbackErr == StopSyncProcessError { + // Signal that we're cancelling the sync + c.readyForNextBlockChan <- false + } + // Signal that we're ready for the next block + c.readyForNextBlockChan <- true + return nil } func (c *Client) handleIntersectFound(msgGeneric protocol.Message) error { diff --git a/protocol/chainsync/error.go b/protocol/chainsync/error.go index 407c079b..7c18eb1a 100644 --- a/protocol/chainsync/error.go +++ b/protocol/chainsync/error.go @@ -1,5 +1,9 @@ package chainsync +import ( + "fmt" +) + // IntersectNotFoundError represents a failure to find a chain intersection type IntersectNotFoundError struct { } @@ -7,3 +11,7 @@ type IntersectNotFoundError struct { func (e IntersectNotFoundError) Error() string { return "chain intersection not found" } + +// StopChainSync is used as a special return value from a RollForward or RollBackward handler function +// to signify that the sync process should be stopped +var StopSyncProcessError = fmt.Errorf("stop sync process")