diff --git a/chain/neutrino.go b/chain/neutrino.go index 28987f7a32..6f095e3ebe 100644 --- a/chain/neutrino.go +++ b/chain/neutrino.go @@ -41,13 +41,16 @@ type NeutrinoClient struct { rescanErr <-chan error wg sync.WaitGroup started bool - scanning bool finished bool isRescan bool clientMtx sync.Mutex } +func (s *NeutrinoClient) isScanning() bool { + return s.rescanQuit != nil +} + // NewNeutrinoClient creates a new NeutrinoClient struct with a backing // ChainService. func NewNeutrinoClient(chainParams *chaincfg.Params, @@ -335,45 +338,63 @@ func (s *NeutrinoClient) pollCFilter(hash *chainhash.Hash) (*gcs.Filter, error) func (s *NeutrinoClient) Rescan(startHash *chainhash.Hash, addrs []btcutil.Address, outPoints map[wire.OutPoint]btcutil.Address) error { + bestBlock, err := s.CS.BestBlock() + if err != nil { + return fmt.Errorf("Can't get chain service's best block: %s", err) + } + header, err := s.CS.GetBlockHeader(&bestBlock.Hash) + if err != nil { + return fmt.Errorf("Can't get block header for hash %v: %s", + bestBlock.Hash, err) + } + + var inputsToWatch []neutrino.InputWithScript + for op, addr := range outPoints { + addrScript, err := txscript.PayToAddrScript(addr) + if err != nil { + return err + } + + inputsToWatch = append(inputsToWatch, neutrino.InputWithScript{ + OutPoint: op, + PkScript: addrScript, + }) + } + s.clientMtx.Lock() if !s.started { s.clientMtx.Unlock() return fmt.Errorf("can't do a rescan when the chain client " + "is not started") } - if s.scanning { + for s.isScanning() { // Restart the rescan by killing the existing rescan. close(s.rescanQuit) rescan := s.rescan - s.clientMtx.Unlock() - rescan.WaitForShutdown() - s.clientMtx.Lock() + if s.rescan != nil { + s.clientMtx.Unlock() + rescan.WaitForShutdown() + s.clientMtx.Lock() + } + // If the rescan has changed since unlocking, shut down the new + // one as well. + if s.rescan != rescan { + continue + } s.rescan = nil s.rescanErr = nil + break } s.rescanQuit = make(chan struct{}) - s.scanning = true s.finished = false s.lastProgressSent = false s.lastFilteredBlockHeader = nil s.isRescan = true - s.clientMtx.Unlock() - - bestBlock, err := s.CS.BestBlock() - if err != nil { - return fmt.Errorf("Can't get chain service's best block: %s", err) - } - header, err := s.CS.GetBlockHeader(&bestBlock.Hash) - if err != nil { - return fmt.Errorf("Can't get block header for hash %v: %s", - bestBlock.Hash, err) - } // If the wallet is already fully caught up, or the rescan has started // with state that indicates a "fresh" wallet, we'll send a // notification indicating the rescan has "finished". if header.BlockHash() == *startHash { - s.clientMtx.Lock() s.finished = true rescanQuit := s.rescanQuit s.clientMtx.Unlock() @@ -381,12 +402,14 @@ func (s *NeutrinoClient) Rescan(startHash *chainhash.Hash, addrs []btcutil.Addre // Release the lock while dispatching the notification since // it's possible for the notificationHandler to be waiting to // acquire it before receiving the notification. - select { - case s.enqueueNotification <- &RescanFinished{ + ntfn := &RescanFinished{ Hash: startHash, Height: int32(bestBlock.Height), Time: header.Timestamp, - }: + } + select { + case s.enqueueNotification <- ntfn: + s.clientMtx.Lock() case <-s.quit: return nil case <-rescanQuit: @@ -394,20 +417,6 @@ func (s *NeutrinoClient) Rescan(startHash *chainhash.Hash, addrs []btcutil.Addre } } - var inputsToWatch []neutrino.InputWithScript - for op, addr := range outPoints { - addrScript, err := txscript.PayToAddrScript(addr) - if err != nil { - return err - } - - inputsToWatch = append(inputsToWatch, neutrino.InputWithScript{ - OutPoint: op, - PkScript: addrScript, - }) - } - - s.clientMtx.Lock() newRescan := neutrino.NewRescan( &neutrino.RescanChainSource{ ChainService: s.CS, @@ -433,29 +442,35 @@ func (s *NeutrinoClient) Rescan(startHash *chainhash.Hash, addrs []btcutil.Addre // NotifyBlocks replicates the RPC client's NotifyBlocks command. func (s *NeutrinoClient) NotifyBlocks() error { s.clientMtx.Lock() + defer s.clientMtx.Unlock() + // If we're scanning, we're already notifying on blocks. Otherwise, // start a rescan without watching any addresses. - if !s.scanning { - s.clientMtx.Unlock() - return s.NotifyReceived([]btcutil.Address{}) + if !s.isScanning() { + return s.notifyReceived([]btcutil.Address{}) } - s.clientMtx.Unlock() return nil } // NotifyReceived replicates the RPC client's NotifyReceived command. func (s *NeutrinoClient) NotifyReceived(addrs []btcutil.Address) error { s.clientMtx.Lock() + defer s.clientMtx.Unlock() + return s.notifyReceived(addrs) +} + +// notifyReceived replicates the RPC client's NotifyReceived command. +// +// NOTE: The clienMtx MUST be held when invoking. +func (s *NeutrinoClient) notifyReceived(addrs []btcutil.Address) error { // If we have a rescan running, we just need to add the appropriate // addresses to the watch list. - if s.scanning { - s.clientMtx.Unlock() + if s.isScanning() { return s.rescan.Update(neutrino.AddAddrs(addrs...)) } s.rescanQuit = make(chan struct{}) - s.scanning = true // Don't need RescanFinished or RescanProgress notifications. s.finished = true @@ -478,7 +493,6 @@ func (s *NeutrinoClient) NotifyReceived(addrs []btcutil.Address) error { ) s.rescan = newRescan s.rescanErr = s.rescan.Start() - s.clientMtx.Unlock() return nil }