diff --git a/autopilot/agent.go b/autopilot/agent.go index d9c35a685f..0328d9a617 100644 --- a/autopilot/agent.go +++ b/autopilot/agent.go @@ -2,6 +2,7 @@ package autopilot import ( "bytes" + "context" "fmt" "math/rand" "net" @@ -39,7 +40,8 @@ type Config struct { // ConnectToPeer attempts to connect to the peer using one of its // advertised addresses. The boolean returned signals whether the peer // was already connected. - ConnectToPeer func(*btcec.PublicKey, []net.Addr) (bool, error) + ConnectToPeer func(context.Context, *btcec.PublicKey, []net.Addr) (bool, + error) // DisconnectPeer attempts to disconnect the peer with the given public // key. @@ -199,20 +201,20 @@ func New(cfg Config, initialState []LocalChannel) (*Agent, error) { // Start starts the agent along with any goroutines it needs to perform its // normal duties. -func (a *Agent) Start() error { +func (a *Agent) Start(ctx context.Context) error { var err error a.started.Do(func() { - err = a.start() + err = a.start(ctx) }) return err } -func (a *Agent) start() error { +func (a *Agent) start(ctx context.Context) error { rand.Seed(time.Now().Unix()) log.Infof("Autopilot Agent starting") a.wg.Add(1) - go a.controller() + go a.controller(ctx) return nil } @@ -401,7 +403,7 @@ func mergeChanState(pendingChans map[NodeID]LocalChannel, // and external state changes as a result of decisions it makes w.r.t channel // allocation, or attributes affecting its control loop being updated by the // backing Lightning Node. -func (a *Agent) controller() { +func (a *Agent) controller(ctx context.Context) { defer a.wg.Done() // We'll start off by assigning our starting balance, and injecting @@ -539,7 +541,7 @@ func (a *Agent) controller() { log.Infof("Triggering attachment directive dispatch, "+ "total_funds=%v", a.totalBalance) - err := a.openChans(availableFunds, numChans, totalChans) + err := a.openChans(ctx, availableFunds, numChans, totalChans) if err != nil { log.Errorf("Unable to open channels: %v", err) } @@ -548,8 +550,8 @@ func (a *Agent) controller() { // openChans queries the agent's heuristic for a set of channel candidates, and // attempts to open channels to them. -func (a *Agent) openChans(availableFunds btcutil.Amount, numChans uint32, - totalChans []LocalChannel) error { +func (a *Agent) openChans(ctx context.Context, availableFunds btcutil.Amount, + numChans uint32, totalChans []LocalChannel) error { // As channel size we'll use the maximum channel size available. chanSize := a.cfg.Constraints.MaxChanSize() @@ -716,7 +718,7 @@ func (a *Agent) openChans(availableFunds btcutil.Amount, numChans uint32, a.pendingConns[nodeID] = struct{}{} a.wg.Add(1) - go a.executeDirective(*chanCandidate) + go a.executeDirective(ctx, *chanCandidate) } return nil } @@ -725,7 +727,9 @@ func (a *Agent) openChans(availableFunds btcutil.Amount, numChans uint32, // the given attachment directive, and open a channel of the given size. // // NOTE: MUST be run as a goroutine. -func (a *Agent) executeDirective(directive AttachmentDirective) { +func (a *Agent) executeDirective(ctx context.Context, + directive AttachmentDirective) { + defer a.wg.Done() // We'll start out by attempting to connect to the peer in order to @@ -746,7 +750,7 @@ func (a *Agent) executeDirective(directive AttachmentDirective) { // TODO(halseth): use DialContext to cancel on transport level. go func() { alreadyConnected, err := a.cfg.ConnectToPeer( - pub, directive.Addrs, + ctx, pub, directive.Addrs, ) if err != nil { select { diff --git a/autopilot/manager.go b/autopilot/manager.go index dba4cc6cc5..5dd359464d 100644 --- a/autopilot/manager.go +++ b/autopilot/manager.go @@ -1,6 +1,7 @@ package autopilot import ( + "context" "fmt" "sync" @@ -96,7 +97,7 @@ func (m *Manager) IsActive() bool { // StartAgent creates and starts an autopilot agent from the Manager's // config. -func (m *Manager) StartAgent() error { +func (m *Manager) StartAgent(ctx context.Context) error { m.Lock() defer m.Unlock() @@ -119,7 +120,7 @@ func (m *Manager) StartAgent() error { return err } - if err := pilot.Start(); err != nil { + if err := pilot.Start(ctx); err != nil { return err } diff --git a/chanbackup/recover.go b/chanbackup/recover.go index 033bd695f2..daaad62487 100644 --- a/chanbackup/recover.go +++ b/chanbackup/recover.go @@ -1,6 +1,7 @@ package chanbackup import ( + "context" "net" "github.com/btcsuite/btcd/btcec/v2" @@ -29,7 +30,8 @@ type PeerConnector interface { // available addresses. Once this method returns with a non-nil error, // the connector should attempt to persistently connect to the target // peer in the background as a persistent attempt. - ConnectPeer(node *btcec.PublicKey, addrs []net.Addr) error + ConnectPeer(ctx context.Context, node *btcec.PublicKey, + addrs []net.Addr) error } // Recover attempts to recover the static channel state from a set of static @@ -41,7 +43,7 @@ type PeerConnector interface { // well, in order to expose the addressing information required to locate to // and connect to each peer in order to initiate the recovery protocol. // The number of channels that were successfully restored is returned. -func Recover(backups []Single, restorer ChannelRestorer, +func Recover(ctx context.Context, backups []Single, restorer ChannelRestorer, peerConnector PeerConnector) (int, error) { var numRestored int @@ -70,7 +72,7 @@ func Recover(backups []Single, restorer ChannelRestorer, backup.FundingOutpoint) err = peerConnector.ConnectPeer( - backup.RemoteNodePub, backup.Addresses, + ctx, backup.RemoteNodePub, backup.Addresses, ) if err != nil { return numRestored, err @@ -95,7 +97,7 @@ func Recover(backups []Single, restorer ChannelRestorer, // established, then the PeerConnector will continue to attempt to re-establish // a persistent connection in the background. The number of channels that were // successfully restored is returned. -func UnpackAndRecoverSingles(singles PackedSingles, +func UnpackAndRecoverSingles(ctx context.Context, singles PackedSingles, keyChain keychain.KeyRing, restorer ChannelRestorer, peerConnector PeerConnector) (int, error) { @@ -104,7 +106,7 @@ func UnpackAndRecoverSingles(singles PackedSingles, return 0, err } - return Recover(chanBackups, restorer, peerConnector) + return Recover(ctx, chanBackups, restorer, peerConnector) } // UnpackAndRecoverMulti is a one-shot method, that given a set of packed @@ -114,7 +116,7 @@ func UnpackAndRecoverSingles(singles PackedSingles, // established, then the PeerConnector will continue to attempt to re-establish // a persistent connection in the background. The number of channels that were // successfully restored is returned. -func UnpackAndRecoverMulti(packedMulti PackedMulti, +func UnpackAndRecoverMulti(ctx context.Context, packedMulti PackedMulti, keyChain keychain.KeyRing, restorer ChannelRestorer, peerConnector PeerConnector) (int, error) { @@ -123,5 +125,5 @@ func UnpackAndRecoverMulti(packedMulti PackedMulti, return 0, err } - return Recover(chanBackups.StaticBackups, restorer, peerConnector) + return Recover(ctx, chanBackups.StaticBackups, restorer, peerConnector) } diff --git a/chanrestore.go b/chanrestore.go index 5b221c105a..6daf3922c9 100644 --- a/chanrestore.go +++ b/chanrestore.go @@ -1,6 +1,7 @@ package lnd import ( + "context" "fmt" "math" "net" @@ -309,7 +310,9 @@ var _ chanbackup.ChannelRestorer = (*chanDBRestorer)(nil) // as a persistent attempt. // // NOTE: Part of the chanbackup.PeerConnector interface. -func (s *server) ConnectPeer(nodePub *btcec.PublicKey, addrs []net.Addr) error { +func (s *server) ConnectPeer(ctx context.Context, nodePub *btcec.PublicKey, + addrs []net.Addr) error { + // Before we connect to the remote peer, we'll remove any connections // to ensure the new connection is created after this new link/channel // is known. @@ -333,7 +336,9 @@ func (s *server) ConnectPeer(nodePub *btcec.PublicKey, addrs []net.Addr) error { // Attempt to connect to the peer using this full address. If // we're unable to connect to them, then we'll try the next // address in place of it. - err := s.ConnectToPeer(netAddr, true, s.cfg.ConnectionTimeout) + err := s.ConnectToPeer( + ctx, netAddr, true, s.cfg.ConnectionTimeout, + ) // If we're already connected to this peer, then we don't // consider this an error, so we'll exit here. diff --git a/discovery/gossiper.go b/discovery/gossiper.go index 284cc42212..9c2295587d 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -2,6 +2,7 @@ package discovery import ( "bytes" + "context" "errors" "fmt" "sync" @@ -436,8 +437,7 @@ type AuthenticatedGossiper struct { // held. bestHeight uint32 - quit chan struct{} - wg sync.WaitGroup + wg sync.WaitGroup // cfg is a copy of the configuration struct that the gossiper service // was initialized with. @@ -524,7 +524,6 @@ func New(cfg Config, selfKeyDesc *keychain.KeyDescriptor) *AuthenticatedGossiper cfg: &cfg, networkMsgs: make(chan *networkMsg), futureMsgs: newFutureMsgCache(maxFutureMessages), - quit: make(chan struct{}), chanPolicyUpdates: make(chan *chanPolicyUpdateRequest), prematureChannelUpdates: lru.NewCache[uint64, *cachedNetworkMsg]( //nolint: lll maxPrematureUpdates, @@ -575,7 +574,7 @@ type EdgeWithInfo struct { // sub-systems, then it signs and broadcasts new updates to the network. A // mapping between outpoints and updated channel policies is returned, which is // used to update the forwarding policies of the underlying links. -func (d *AuthenticatedGossiper) PropagateChanPolicyUpdate( +func (d *AuthenticatedGossiper) PropagateChanPolicyUpdate(ctx context.Context, edgesToUpdate []EdgeWithInfo) error { errChan := make(chan error, 1) @@ -588,23 +587,23 @@ func (d *AuthenticatedGossiper) PropagateChanPolicyUpdate( case d.chanPolicyUpdates <- policyUpdate: err := <-errChan return err - case <-d.quit: + case <-ctx.Done(): return fmt.Errorf("AuthenticatedGossiper shutting down") } } // Start spawns network messages handler goroutine and registers on new block // notifications in order to properly handle the premature announcements. -func (d *AuthenticatedGossiper) Start() error { +func (d *AuthenticatedGossiper) Start(ctx context.Context) error { var err error d.started.Do(func() { log.Info("Authenticated Gossiper starting") - err = d.start() + err = d.start(ctx) }) return err } -func (d *AuthenticatedGossiper) start() error { +func (d *AuthenticatedGossiper) start(ctx context.Context) error { // First we register for new notifications of newly discovered blocks. // We do this immediately so we'll later be able to consume any/all // blocks which were discovered. @@ -627,14 +626,14 @@ func (d *AuthenticatedGossiper) start() error { return err } - d.syncMgr.Start() + d.syncMgr.Start(ctx) d.banman.start() // Start receiving blocks in its dedicated goroutine. d.wg.Add(2) - go d.syncBlockHeight() - go d.networkHandler() + go d.syncBlockHeight(ctx) + go d.networkHandler(ctx) return nil } @@ -643,7 +642,7 @@ func (d *AuthenticatedGossiper) start() error { // blockEpochs. // // NOTE: must be run as a goroutine. -func (d *AuthenticatedGossiper) syncBlockHeight() { +func (d *AuthenticatedGossiper) syncBlockHeight(ctx context.Context) { defer d.wg.Done() for { @@ -668,9 +667,9 @@ func (d *AuthenticatedGossiper) syncBlockHeight() { newBlock.Hash) // Resend future messages, if any. - d.resendFutureMessages(blockHeight) + d.resendFutureMessages(ctx, blockHeight) - case <-d.quit: + case <-ctx.Done(): return } } @@ -719,7 +718,9 @@ func (c *cachedFutureMsg) Size() (uint64, error) { // resendFutureMessages takes a block height, resends all the future messages // found below and equal to that height and deletes those messages found in the // gossiper's futureMsgs. -func (d *AuthenticatedGossiper) resendFutureMessages(height uint32) { +func (d *AuthenticatedGossiper) resendFutureMessages(ctx context.Context, + height uint32) { + var ( // msgs are the target messages. msgs []*networkMsg @@ -757,7 +758,7 @@ func (d *AuthenticatedGossiper) resendFutureMessages(height uint32) { for _, msg := range msgs { select { case d.networkMsgs <- msg: - case <-d.quit: + case <-ctx.Done(): msg.err <- ErrGossiperShuttingDown } } @@ -789,7 +790,6 @@ func (d *AuthenticatedGossiper) stop() { d.banman.stop() - close(d.quit) d.wg.Wait() // We'll stop our reliable sender after all of the gossiper's goroutines @@ -805,8 +805,8 @@ func (d *AuthenticatedGossiper) stop() { // then added to a queue for batched trickled announcement to all connected // peers. Remote channel announcements should contain the announcement proof // and be fully validated. -func (d *AuthenticatedGossiper) ProcessRemoteAnnouncement(msg lnwire.Message, - peer lnpeer.Peer) chan error { +func (d *AuthenticatedGossiper) ProcessRemoteAnnouncement(ctx context.Context, + msg lnwire.Message, peer lnpeer.Peer) chan error { errChan := make(chan error, 1) @@ -892,7 +892,7 @@ func (d *AuthenticatedGossiper) ProcessRemoteAnnouncement(msg lnwire.Message, // to send back an error and can return immediately. case <-peer.QuitSignal(): return nil - case <-d.quit: + case <-ctx.Done(): nMsg.err <- ErrGossiperShuttingDown } @@ -906,8 +906,8 @@ func (d *AuthenticatedGossiper) ProcessRemoteAnnouncement(msg lnwire.Message, // will not be fully validated. Once the channel proofs are received, the // entire channel announcement and update messages will be re-constructed and // broadcast to the rest of the network. -func (d *AuthenticatedGossiper) ProcessLocalAnnouncement(msg lnwire.Message, - optionalFields ...OptionalMsgField) chan error { +func (d *AuthenticatedGossiper) ProcessLocalAnnouncement(ctx context.Context, + msg lnwire.Message, optionalFields ...OptionalMsgField) chan error { optionalMsgFields := &optionalMsgFields{} optionalMsgFields.apply(optionalFields...) @@ -922,7 +922,7 @@ func (d *AuthenticatedGossiper) ProcessLocalAnnouncement(msg lnwire.Message, select { case d.networkMsgs <- nMsg: - case <-d.quit: + case <-ctx.Done(): nMsg.err <- ErrGossiperShuttingDown } @@ -1285,7 +1285,7 @@ func (d *AuthenticatedGossiper) splitAnnouncementBatches( // split size, and then sends out all items to the set of target peers. Locally // generated announcements are always sent before remotely generated // announcements. -func (d *AuthenticatedGossiper) splitAndSendAnnBatch( +func (d *AuthenticatedGossiper) splitAndSendAnnBatch(ctx context.Context, annBatch msgsToBroadcast) { // delayNextBatch is a helper closure that blocks for `SubBatchDelay` @@ -1293,7 +1293,7 @@ func (d *AuthenticatedGossiper) splitAndSendAnnBatch( delayNextBatch := func() { select { case <-time.After(d.cfg.SubBatchDelay): - case <-d.quit: + case <-ctx.Done(): return } } @@ -1379,7 +1379,7 @@ func (d *AuthenticatedGossiper) sendRemoteBatch(annBatch []msgWithSenders) { // broadcasting our latest topology state to all connected peers. // // NOTE: This MUST be run as a goroutine. -func (d *AuthenticatedGossiper) networkHandler() { +func (d *AuthenticatedGossiper) networkHandler(ctx context.Context) { defer d.wg.Done() // Initialize empty deDupedAnnouncements to store announcement batch. @@ -1400,7 +1400,7 @@ func (d *AuthenticatedGossiper) networkHandler() { // We'll use this validation to ensure that we process jobs in their // dependency order during parallel validation. - validationBarrier := graph.NewValidationBarrier(1000, d.quit) + validationBarrier := graph.NewValidationBarrier(1000) for { select { @@ -1440,7 +1440,7 @@ func (d *AuthenticatedGossiper) networkHandler() { // messages that we'll process serially. case *lnwire.AnnounceSignatures1: emittedAnnouncements, _ := d.processNetworkAnnouncement( - announcement, + ctx, announcement, ) log.Debugf("Processed network message %s, "+ "returned len(announcements)=%v", @@ -1470,11 +1470,14 @@ func (d *AuthenticatedGossiper) networkHandler() { // We'll set up any dependent, and wait until a free // slot for this job opens up, this allow us to not // have thousands of goroutines active. - validationBarrier.InitJobDependencies(announcement.msg) + validationBarrier.InitJobDependencies( + ctx, announcement.msg, + ) d.wg.Add(1) go d.handleNetworkMessages( - announcement, &announcements, validationBarrier, + ctx, announcement, &announcements, + validationBarrier, ) // The trickle timer has ticked, which indicates we should @@ -1497,7 +1500,7 @@ func (d *AuthenticatedGossiper) networkHandler() { // announcements, we'll blast them out w/o regard for // our peer's policies so we ensure they propagate // properly. - d.splitAndSendAnnBatch(announcementBatch) + d.splitAndSendAnnBatch(ctx, announcementBatch) // The retransmission timer has ticked which indicates that we // should check if we need to prune or re-broadcast any of our @@ -1513,7 +1516,7 @@ func (d *AuthenticatedGossiper) networkHandler() { // The gossiper has been signalled to exit, to we exit our // main loop so the wait group can be decremented. - case <-d.quit: + case <-ctx.Done(): return } } @@ -1524,11 +1527,12 @@ func (d *AuthenticatedGossiper) networkHandler() { // signal its dependants and add the new announcements to the announce batch. // // NOTE: must be run as a goroutine. -func (d *AuthenticatedGossiper) handleNetworkMessages(nMsg *networkMsg, - deDuped *deDupedAnnouncements, vb *graph.ValidationBarrier) { +func (d *AuthenticatedGossiper) handleNetworkMessages(ctx context.Context, + nMsg *networkMsg, deDuped *deDupedAnnouncements, + vb *graph.ValidationBarrier) { defer d.wg.Done() - defer vb.CompleteJob() + defer vb.CompleteJob(ctx) // We should only broadcast this message forward if it originated from // us or it wasn't received as part of our initial historical sync. @@ -1536,7 +1540,7 @@ func (d *AuthenticatedGossiper) handleNetworkMessages(nMsg *networkMsg, // If this message has an existing dependency, then we'll wait until // that has been fully validated before we proceed. - err := vb.WaitForDependants(nMsg.msg) + err := vb.WaitForDependants(ctx, nMsg.msg) if err != nil { log.Debugf("Validating network message %s got err: %v", nMsg.msg.MsgType(), err) @@ -1558,7 +1562,7 @@ func (d *AuthenticatedGossiper) handleNetworkMessages(nMsg *networkMsg, // Process the network announcement to determine if this is either a // new announcement from our PoV or an edges to a prior vertex/edge we // previously proceeded. - newAnns, allow := d.processNetworkAnnouncement(nMsg) + newAnns, allow := d.processNetworkAnnouncement(ctx, nMsg) log.Tracef("Processed network message %s, returned "+ "len(announcements)=%v, allowDependents=%v", @@ -1586,15 +1590,19 @@ func (d *AuthenticatedGossiper) handleNetworkMessages(nMsg *networkMsg, // established to a new peer that understands how to perform channel range // queries. We'll allocate a new gossip syncer for it, and start any goroutines // needed to handle new queries. -func (d *AuthenticatedGossiper) InitSyncState(syncPeer lnpeer.Peer) { - d.syncMgr.InitSyncState(syncPeer) +func (d *AuthenticatedGossiper) InitSyncState(ctx context.Context, + syncPeer lnpeer.Peer) { + + d.syncMgr.InitSyncState(ctx, syncPeer) } // PruneSyncState is called by outside sub-systems once a peer that we were // previously connected to has been disconnected. In this case we can stop the // existing GossipSyncer assigned to the peer and free up resources. -func (d *AuthenticatedGossiper) PruneSyncState(peer route.Vertex) { - d.syncMgr.PruneSyncState(peer) +func (d *AuthenticatedGossiper) PruneSyncState(ctx context.Context, + peer route.Vertex) { + + d.syncMgr.PruneSyncState(ctx, peer) } // isRecentlyRejectedMsg returns true if we recently rejected a message, and @@ -1865,7 +1873,7 @@ func remotePubFromChanInfo(chanInfo *models.ChannelEdgeInfo, // situation in the case where we create a channel, but for some reason fail // to receive the remote peer's proof, while the remote peer is able to fully // assemble the proof and craft the ChannelAnnouncement. -func (d *AuthenticatedGossiper) processRejectedEdge( +func (d *AuthenticatedGossiper) processRejectedEdge(ctx context.Context, chanAnnMsg *lnwire.ChannelAnnouncement1, proof *models.ChannelAuthProof) ([]networkMsg, error) { @@ -1900,7 +1908,7 @@ func (d *AuthenticatedGossiper) processRejectedEdge( if err != nil { return nil, err } - err = netann.ValidateChannelAnn(chanAnn, d.fetchPKScript) + err = netann.ValidateChannelAnn(ctx, chanAnn, d.fetchPKScript) if err != nil { err := fmt.Errorf("assembled channel announcement proof "+ "for shortChanID=%v isn't valid: %v", @@ -1945,10 +1953,10 @@ func (d *AuthenticatedGossiper) processRejectedEdge( } // fetchPKScript fetches the output script for the given SCID. -func (d *AuthenticatedGossiper) fetchPKScript(chanID *lnwire.ShortChannelID) ( - []byte, error) { +func (d *AuthenticatedGossiper) fetchPKScript(ctx context.Context, + chanID *lnwire.ShortChannelID) ([]byte, error) { - return lnwallet.FetchPKScriptWithQuit(d.cfg.ChainIO, chanID, d.quit) + return lnwallet.FetchPKScript(ctx, d.cfg.ChainIO, chanID) } // addNode processes the given node announcement, and adds it to our channel @@ -2037,7 +2045,7 @@ func (d *AuthenticatedGossiper) isPremature(chanID lnwire.ShortChannelID, // be returned which should be broadcasted to the rest of the network. The // boolean returned indicates whether any dependents of the announcement should // attempt to be processed as well. -func (d *AuthenticatedGossiper) processNetworkAnnouncement( +func (d *AuthenticatedGossiper) processNetworkAnnouncement(ctx context.Context, nMsg *networkMsg) ([]networkMsg, bool) { // If this is a remote update, we set the scheduler option to lazily @@ -2059,7 +2067,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement( // the existence of a channel and not yet the routing policies in // either direction of the channel. case *lnwire.ChannelAnnouncement1: - return d.handleChanAnnouncement(nMsg, msg, schedulerOp) + return d.handleChanAnnouncement(ctx, nMsg, msg, schedulerOp) // A new authenticated channel edge update has arrived. This indicates // that the directional information for an already known channel has @@ -2071,7 +2079,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement( // willingness of nodes involved in the funding of a channel to // announce this new channel to the rest of the world. case *lnwire.AnnounceSignatures1: - return d.handleAnnSig(nMsg, msg) + return d.handleAnnSig(ctx, nMsg, msg) default: err := errors.New("wrong type of the announcement") @@ -2437,8 +2445,8 @@ func (d *AuthenticatedGossiper) handleNodeAnnouncement(nMsg *networkMsg, } // handleChanAnnouncement processes a new channel announcement. -func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, - ann *lnwire.ChannelAnnouncement1, +func (d *AuthenticatedGossiper) handleChanAnnouncement(ctx context.Context, + nMsg *networkMsg, ann *lnwire.ChannelAnnouncement1, ops []batch.SchedulerOption) ([]networkMsg, bool) { scid := ann.ShortChannelID @@ -2550,7 +2558,7 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, // the signatures within the proof as it should be well formed. var proof *models.ChannelAuthProof if nMsg.isRemote { - err := netann.ValidateChannelAnn(ann, d.fetchPKScript) + err := netann.ValidateChannelAnn(ctx, ann, d.fetchPKScript) if err != nil { err := fmt.Errorf("unable to validate announcement: "+ "%v", err) @@ -2638,7 +2646,7 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, case graph.IsError(err, graph.ErrIgnored): // Attempt to process the rejected message to see if we // get any new announcements. - anns, rErr := d.processRejectedEdge(ann, proof) + anns, rErr := d.processRejectedEdge(ctx, ann, proof) if rErr != nil { key := newRejectCacheKey( scid.ToUint64(), @@ -2792,7 +2800,7 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, select { case d.networkMsgs <- updMsg: - case <-d.quit: + case <-ctx.Done(): updMsg.err <- ErrGossiperShuttingDown } @@ -3232,7 +3240,7 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, } // handleAnnSig processes a new announcement signatures message. -func (d *AuthenticatedGossiper) handleAnnSig(nMsg *networkMsg, +func (d *AuthenticatedGossiper) handleAnnSig(ctx context.Context, nMsg *networkMsg, ann *lnwire.AnnounceSignatures1) ([]networkMsg, bool) { needBlockHeight := ann.ShortChannelID.BlockHeight + @@ -3448,7 +3456,7 @@ func (d *AuthenticatedGossiper) handleAnnSig(nMsg *networkMsg, // With all the necessary components assembled validate the full // channel announcement proof. - err = netann.ValidateChannelAnn(chanAnn, d.fetchPKScript) + err = netann.ValidateChannelAnn(ctx, chanAnn, d.fetchPKScript) if err != nil { err := fmt.Errorf("channel announcement proof for "+ "short_chan_id=%v isn't valid: %v", shortChanID, err) diff --git a/discovery/sync_manager.go b/discovery/sync_manager.go index 70d28784b8..69cfca8d7b 100644 --- a/discovery/sync_manager.go +++ b/discovery/sync_manager.go @@ -1,6 +1,7 @@ package discovery import ( + "context" "errors" "sync" "sync/atomic" @@ -168,13 +169,11 @@ type SyncManager struct { // queries. gossipFilterSema chan struct{} - wg sync.WaitGroup - quit chan struct{} + wg sync.WaitGroup } // newSyncManager constructs a new SyncManager backed by the given config. func newSyncManager(cfg *SyncManagerCfg) *SyncManager { - filterSema := make(chan struct{}, filterSemaSize) for i := 0; i < filterSemaSize; i++ { filterSema <- struct{}{} @@ -192,15 +191,14 @@ func newSyncManager(cfg *SyncManagerCfg) *SyncManager { map[route.Vertex]*GossipSyncer, len(cfg.PinnedSyncers), ), gossipFilterSema: filterSema, - quit: make(chan struct{}), } } // Start starts the SyncManager in order to properly carry out its duties. -func (m *SyncManager) Start() { +func (m *SyncManager) Start(ctx context.Context) { m.start.Do(func() { m.wg.Add(1) - go m.syncerHandler() + go m.syncerHandler(ctx) }) } @@ -210,7 +208,6 @@ func (m *SyncManager) Stop() { log.Debugf("SyncManager is stopping") defer log.Debugf("SyncManager stopped") - close(m.quit) m.wg.Wait() for _, syncer := range m.inactiveSyncers { @@ -233,7 +230,7 @@ func (m *SyncManager) Stop() { // much of the public network as possible. // // NOTE: This must be run as a goroutine. -func (m *SyncManager) syncerHandler() { +func (m *SyncManager) syncerHandler(ctx context.Context) { defer m.wg.Done() m.cfg.RotateTicker.Resume() @@ -481,7 +478,7 @@ func (m *SyncManager) syncerHandler() { // much of the graph as we should. setInitialHistoricalSyncer(s) - case <-m.quit: + case <-ctx.Done(): return } } @@ -706,7 +703,9 @@ func chooseRandomSyncer(syncers map[route.Vertex]*GossipSyncer, // public channel graph as possible. // // TODO(wilmer): Only mark as ActiveSync if this isn't a channel peer. -func (m *SyncManager) InitSyncState(peer lnpeer.Peer) error { +func (m *SyncManager) InitSyncState(ctx context.Context, + peer lnpeer.Peer) error { + done := make(chan struct{}) select { @@ -714,14 +713,14 @@ func (m *SyncManager) InitSyncState(peer lnpeer.Peer) error { peer: peer, doneChan: done, }: - case <-m.quit: + case <-ctx.Done(): return ErrSyncManagerExiting } select { case <-done: return nil - case <-m.quit: + case <-ctx.Done(): return ErrSyncManagerExiting } } @@ -729,7 +728,7 @@ func (m *SyncManager) InitSyncState(peer lnpeer.Peer) error { // PruneSyncState is called by outside sub-systems once a peer that we were // previously connected to has been disconnected. In this case we can stop the // existing GossipSyncer assigned to the peer and free up resources. -func (m *SyncManager) PruneSyncState(peer route.Vertex) { +func (m *SyncManager) PruneSyncState(ctx context.Context, peer route.Vertex) { done := make(chan struct{}) // We avoid returning an error when the SyncManager is stopped since the @@ -739,13 +738,13 @@ func (m *SyncManager) PruneSyncState(peer route.Vertex) { peer: peer, doneChan: done, }: - case <-m.quit: + case <-ctx.Done(): return } select { case <-done: - case <-m.quit: + case <-ctx.Done(): } } diff --git a/funding/manager.go b/funding/manager.go index 1fa90c6932..b90cfc5896 100644 --- a/funding/manager.go +++ b/funding/manager.go @@ -2,6 +2,7 @@ package funding import ( "bytes" + "context" "encoding/binary" "fmt" "io" @@ -405,7 +406,7 @@ type Config struct { // network. A set of optional message fields can be provided to populate // any information within the graph that is not included in the gossip // message. - SendAnnouncement func(msg lnwire.Message, + SendAnnouncement func(ctx context.Context, msg lnwire.Message, optionalFields ...discovery.OptionalMsgField) chan error // NotifyWhenOnline allows the FundingManager to register with a @@ -703,16 +704,16 @@ func NewFundingManager(cfg Config) (*Manager, error) { // Start launches all helper goroutines required for handling requests sent // to the funding manager. -func (f *Manager) Start() error { +func (f *Manager) Start(ctx context.Context) error { var err error f.started.Do(func() { log.Info("Funding manager starting") - err = f.start() + err = f.start(ctx) }) return err } -func (f *Manager) start() error { +func (f *Manager) start(ctx context.Context) error { // Upon restart, the Funding Manager will check the database to load any // channels that were waiting for their funding transactions to be // confirmed on the blockchain at the time when the daemon last went @@ -766,11 +767,11 @@ func (f *Manager) start() error { // confirmed on the blockchain, and transmit the messages // necessary for the channel to be operational. f.wg.Add(1) - go f.advanceFundingState(channel, chanID, nil) + go f.advanceFundingState(ctx, channel, chanID, nil) } f.wg.Add(1) // TODO(roasbeef): tune - go f.reservationCoordinator() + go f.reservationCoordinator(ctx) return nil } @@ -1017,7 +1018,7 @@ func (f *Manager) sendWarning(peer lnpeer.Peer, cid *chanIdentifier, // funding workflow between the wallet, and any outside peers or local callers. // // NOTE: This MUST be run as a goroutine. -func (f *Manager) reservationCoordinator() { +func (f *Manager) reservationCoordinator(ctx context.Context) { defer f.wg.Done() zombieSweepTicker := time.NewTicker(f.cfg.ZombieSweeperInterval) @@ -1034,10 +1035,10 @@ func (f *Manager) reservationCoordinator() { f.funderProcessAcceptChannel(fmsg.peer, msg) case *lnwire.FundingCreated: - f.fundeeProcessFundingCreated(fmsg.peer, msg) + f.fundeeProcessFundingCreated(ctx, fmsg.peer, msg) case *lnwire.FundingSigned: - f.funderProcessFundingSigned(fmsg.peer, msg) + f.funderProcessFundingSigned(ctx, fmsg.peer, msg) case *lnwire.ChannelReady: f.wg.Add(1) @@ -1069,7 +1070,8 @@ func (f *Manager) reservationCoordinator() { // OpenStatusUpdates. // // NOTE: This MUST be run as a goroutine. -func (f *Manager) advanceFundingState(channel *channeldb.OpenChannel, +func (f *Manager) advanceFundingState(ctx context.Context, + channel *channeldb.OpenChannel, pendingChanID PendingChanID, updateChan chan<- *lnrpc.OpenStatusUpdate) { @@ -1135,7 +1137,7 @@ func (f *Manager) advanceFundingState(channel *channeldb.OpenChannel, // are still steps left of the setup procedure. We continue the // procedure where we left off. err = f.stateStep( - channel, lnChannel, shortChanID, pendingChanID, + ctx, channel, lnChannel, shortChanID, pendingChanID, channelState, updateChan, ) if err != nil { @@ -1150,7 +1152,7 @@ func (f *Manager) advanceFundingState(channel *channeldb.OpenChannel, // machine. This method is synchronous and the new channel opening state will // have been written to the database when it successfully returns. The // updateChan can be set non-nil to get OpenStatusUpdates. -func (f *Manager) stateStep(channel *channeldb.OpenChannel, +func (f *Manager) stateStep(ctx context.Context, channel *channeldb.OpenChannel, lnChannel *lnwallet.LightningChannel, shortChanID *lnwire.ShortChannelID, pendingChanID PendingChanID, channelState channelOpeningState, @@ -1215,7 +1217,7 @@ func (f *Manager) stateStep(channel *channeldb.OpenChannel, } return f.handleChannelReadyReceived( - channel, shortChanID, pendingChanID, updateChan, + ctx, channel, shortChanID, pendingChanID, updateChan, ) // The channel was added to the Router's topology, but the channel @@ -1225,7 +1227,7 @@ func (f *Manager) stateStep(channel *channeldb.OpenChannel, // If this is a zero-conf channel, then we will wait // for it to be confirmed before announcing it to the // greater network. - err := f.waitForZeroConfChannel(channel) + err := f.waitForZeroConfChannel(ctx, channel) if err != nil { return fmt.Errorf("failed waiting for zero "+ "channel: %v", err) @@ -1237,7 +1239,7 @@ func (f *Manager) stateStep(channel *channeldb.OpenChannel, shortChanID = &confirmedScid } - err := f.annAfterSixConfs(channel, shortChanID) + err := f.annAfterSixConfs(ctx, channel, shortChanID) if err != nil { return fmt.Errorf("error sending channel "+ "announcement: %v", err) @@ -2442,8 +2444,8 @@ func (f *Manager) continueFundingAccept(resCtx *reservationWithCtx, // stage. // //nolint:funlen -func (f *Manager) fundeeProcessFundingCreated(peer lnpeer.Peer, - msg *lnwire.FundingCreated) { +func (f *Manager) fundeeProcessFundingCreated(ctx context.Context, + peer lnpeer.Peer, msg *lnwire.FundingCreated) { peerKey := peer.IdentityKey() pendingChanID := msg.PendingChannelID @@ -2672,7 +2674,7 @@ func (f *Manager) fundeeProcessFundingCreated(peer lnpeer.Peer, // transaction in 288 blocks (~ 48 hrs), by canceling the reservation // and canceling the wait for the funding confirmation. f.wg.Add(1) - go f.advanceFundingState(completeChan, pendingChanID, nil) + go f.advanceFundingState(ctx, completeChan, pendingChanID, nil) } // funderProcessFundingSigned processes the final message received in a single @@ -2680,8 +2682,8 @@ func (f *Manager) fundeeProcessFundingCreated(peer lnpeer.Peer, // broadcast. Once the funding transaction reaches a sufficient number of // confirmations, a message is sent to the responding peer along with a compact // encoding of the location of the channel within the blockchain. -func (f *Manager) funderProcessFundingSigned(peer lnpeer.Peer, - msg *lnwire.FundingSigned) { +func (f *Manager) funderProcessFundingSigned(ctx context.Context, + peer lnpeer.Peer, msg *lnwire.FundingSigned) { // As the funding signed message will reference the reservation by its // permanent channel ID, we'll need to perform an intermediate look up @@ -2881,7 +2883,7 @@ func (f *Manager) funderProcessFundingSigned(peer lnpeer.Peer, // At this point we have broadcast the funding transaction and done all // necessary processing. f.wg.Add(1) - go f.advanceFundingState(completeChan, pendingChanID, resCtx.updates) + go f.advanceFundingState(ctx, completeChan, pendingChanID, resCtx.updates) } // confirmedChannel wraps a confirmed funding transaction, as well as the short @@ -3538,7 +3540,8 @@ func (f *Manager) extractAnnounceParams(c *channeldb.OpenChannel) ( // The peerAlias is used for zero-conf channels to give the counter-party a // ChannelUpdate they understand. ourPolicy may be set for various // option-scid-alias channels to re-use the same policy. -func (f *Manager) addToGraph(completeChan *channeldb.OpenChannel, +func (f *Manager) addToGraph(ctx context.Context, + completeChan *channeldb.OpenChannel, shortChanID *lnwire.ShortChannelID, peerAlias *lnwire.ShortChannelID, ourPolicy *models.ChannelEdgePolicy) error { @@ -3562,7 +3565,7 @@ func (f *Manager) addToGraph(completeChan *channeldb.OpenChannel, // Send ChannelAnnouncement and ChannelUpdate to the gossiper to add // to the Router's topology. errChan := f.cfg.SendAnnouncement( - ann.chanAnn, discovery.ChannelCapacity(completeChan.Capacity), + ctx, ann.chanAnn, discovery.ChannelCapacity(completeChan.Capacity), discovery.ChannelPoint(completeChan.FundingOutpoint), discovery.TapscriptRoot(completeChan.TapscriptRoot), ) @@ -3584,7 +3587,7 @@ func (f *Manager) addToGraph(completeChan *channeldb.OpenChannel, } errChan = f.cfg.SendAnnouncement( - ann.chanUpdateAnn, discovery.RemoteAlias(peerAlias), + ctx, ann.chanUpdateAnn, discovery.RemoteAlias(peerAlias), ) select { case err := <-errChan: @@ -3612,7 +3615,8 @@ func (f *Manager) addToGraph(completeChan *channeldb.OpenChannel, // 'addedToGraph') and the channel is ready to be used. This is the last // step in the channel opening process, and the opening state will be deleted // from the database if successful. -func (f *Manager) annAfterSixConfs(completeChan *channeldb.OpenChannel, +func (f *Manager) annAfterSixConfs(ctx context.Context, + completeChan *channeldb.OpenChannel, shortChanID *lnwire.ShortChannelID) error { // If this channel is not meant to be announced to the greater network, @@ -3731,7 +3735,7 @@ func (f *Manager) annAfterSixConfs(completeChan *channeldb.OpenChannel, } err = f.addToGraph( - completeChan, &baseScid, nil, ourPolicy, + ctx, completeChan, &baseScid, nil, ourPolicy, ) if err != nil { return fmt.Errorf("failed to re-add to "+ @@ -3742,7 +3746,7 @@ func (f *Manager) annAfterSixConfs(completeChan *channeldb.OpenChannel, // Create and broadcast the proofs required to make this channel // public and usable for other nodes for routing. err = f.announceChannel( - f.cfg.IDKey, completeChan.IdentityPub, + ctx, f.cfg.IDKey, completeChan.IdentityPub, &completeChan.LocalChanCfg.MultiSigKey, completeChan.RemoteChanCfg.MultiSigKey.PubKey, *shortChanID, chanID, completeChan.ChanType, @@ -3762,7 +3766,8 @@ func (f *Manager) annAfterSixConfs(completeChan *channeldb.OpenChannel, // waitForZeroConfChannel is called when the state is addedToGraph with // a zero-conf channel. This will wait for the real confirmation, add the // confirmed SCID to the router graph, and then announce after six confs. -func (f *Manager) waitForZeroConfChannel(c *channeldb.OpenChannel) error { +func (f *Manager) waitForZeroConfChannel(ctx context.Context, + c *channeldb.OpenChannel) error { // First we'll check whether the channel is confirmed on-chain. If it // is already confirmed, the chainntnfs subsystem will return with the // confirmed tx. Otherwise, we'll wait here until confirmation occurs. @@ -3820,7 +3825,7 @@ func (f *Manager) waitForZeroConfChannel(c *channeldb.OpenChannel) error { // alias since we'll be using the confirmed SCID from now on // regardless if it's public or not. err = f.addToGraph( - c, &confChan.shortChanID, nil, ourPolicy, + ctx, c, &confChan.shortChanID, nil, ourPolicy, ) if err != nil { return fmt.Errorf("failed adding confirmed zero-conf "+ @@ -4148,7 +4153,8 @@ func (f *Manager) handleChannelReady(peer lnpeer.Peer, //nolint:funlen // channelReady message, once the remote's channelReady is processed, the // channel is now active, thus we change its state to `addedToGraph` to // let the channel start handling routing. -func (f *Manager) handleChannelReadyReceived(channel *channeldb.OpenChannel, +func (f *Manager) handleChannelReadyReceived(ctx context.Context, + channel *channeldb.OpenChannel, scid *lnwire.ShortChannelID, pendingChanID PendingChanID, updateChan chan<- *lnrpc.OpenStatusUpdate) error { @@ -4178,7 +4184,7 @@ func (f *Manager) handleChannelReadyReceived(channel *channeldb.OpenChannel, peerAlias = &foundAlias } - err := f.addToGraph(channel, scid, peerAlias, nil) + err := f.addToGraph(ctx, channel, scid, peerAlias, nil) if err != nil { return fmt.Errorf("failed adding to graph: %w", err) } @@ -4512,7 +4518,8 @@ func (f *Manager) newChanAnnouncement(localPubKey, // the network during its next trickle. // This method is synchronous and will return when all the network requests // finish, either successfully or with an error. -func (f *Manager) announceChannel(localIDKey, remoteIDKey *btcec.PublicKey, +func (f *Manager) announceChannel(ctx context.Context, localIDKey, + remoteIDKey *btcec.PublicKey, localFundingKey *keychain.KeyDescriptor, remoteFundingKey *btcec.PublicKey, shortChanID lnwire.ShortChannelID, chanID lnwire.ChannelID, chanType channeldb.ChannelType) error { @@ -4537,7 +4544,7 @@ func (f *Manager) announceChannel(localIDKey, remoteIDKey *btcec.PublicKey, // because addToGraph previously sent the ChannelAnnouncement and // the ChannelUpdate announcement messages. The channel proof and node // announcements are broadcast to the greater network. - errChan := f.cfg.SendAnnouncement(ann.chanProof) + errChan := f.cfg.SendAnnouncement(ctx, ann.chanProof) select { case err := <-errChan: if err != nil { @@ -4567,7 +4574,7 @@ func (f *Manager) announceChannel(localIDKey, remoteIDKey *btcec.PublicKey, return err } - errChan = f.cfg.SendAnnouncement(&nodeAnn) + errChan = f.cfg.SendAnnouncement(ctx, &nodeAnn) select { case err := <-errChan: if err != nil { diff --git a/graph/builder.go b/graph/builder.go index 6930f1a894..b5f88fb8ed 100644 --- a/graph/builder.go +++ b/graph/builder.go @@ -2,6 +2,7 @@ package graph import ( "bytes" + "context" "fmt" "runtime" "strings" @@ -185,7 +186,7 @@ func NewBuilder(cfg *Config) (*Builder, error) { // Start launches all the goroutines the Builder requires to carry out its // duties. If the builder has already been started, then this method is a noop. -func (b *Builder) Start() error { +func (b *Builder) Start(ctx context.Context) error { if !b.started.CompareAndSwap(false, true) { return nil } @@ -301,7 +302,7 @@ func (b *Builder) Start() error { } b.wg.Add(1) - go b.networkHandler() + go b.networkHandler(ctx) log.Debug("Builder started") @@ -669,15 +670,15 @@ func (b *Builder) pruneZombieChans() error { // notifies topology changes, if any. // // NOTE: must be run inside goroutine. -func (b *Builder) handleNetworkUpdate(vb *ValidationBarrier, +func (b *Builder) handleNetworkUpdate(ctx context.Context, vb *ValidationBarrier, update *routingMsg) { defer b.wg.Done() - defer vb.CompleteJob() + defer vb.CompleteJob(ctx) // If this message has an existing dependency, then we'll wait until // that has been fully validated before we proceed. - err := vb.WaitForDependants(update.msg) + err := vb.WaitForDependants(ctx, update.msg) if err != nil { switch { case IsError(err, ErrVBarrierShuttingDown): @@ -698,7 +699,7 @@ func (b *Builder) handleNetworkUpdate(vb *ValidationBarrier, // Process the routing update to determine if this is either a new // update from our PoV or an update to a prior vertex/edge we // previously accepted. - err = b.processUpdate(update.msg, update.op...) + err = b.processUpdate(ctx, update.msg, update.op...) update.err <- err // If this message had any dependencies, then we can now signal them to @@ -743,7 +744,7 @@ func (b *Builder) handleNetworkUpdate(vb *ValidationBarrier, // updates, and registering new topology clients. // // NOTE: This MUST be run as a goroutine. -func (b *Builder) networkHandler() { +func (b *Builder) networkHandler(ctx context.Context) { defer b.wg.Done() graphPruneTicker := time.NewTicker(b.cfg.GraphPruneInterval) @@ -771,11 +772,9 @@ func (b *Builder) networkHandler() { // See https://github.com/lightningnetwork/lnd/issues/4892. var validationBarrier *ValidationBarrier if b.cfg.AssumeChannelValid { - validationBarrier = NewValidationBarrier(1000, b.quit) + validationBarrier = NewValidationBarrier(1000) } else { - validationBarrier = NewValidationBarrier( - 4*runtime.NumCPU(), b.quit, - ) + validationBarrier = NewValidationBarrier(4 * runtime.NumCPU()) } for { @@ -792,10 +791,10 @@ func (b *Builder) networkHandler() { // We'll set up any dependants, and wait until a free // slot for this job opens up, this allows us to not // have thousands of goroutines active. - validationBarrier.InitJobDependencies(update.msg) + validationBarrier.InitJobDependencies(ctx, update.msg) b.wg.Add(1) - go b.handleNetworkUpdate(validationBarrier, update) + go b.handleNetworkUpdate(ctx, validationBarrier, update) // TODO(roasbeef): remove all unconnected vertexes // after N blocks pass with no corresponding @@ -1161,7 +1160,7 @@ func makeFundingScript(bitcoinKey1, bitcoinKey2 []byte, chanFeatures []byte, // then error is returned. // //nolint:funlen -func (b *Builder) processUpdate(msg interface{}, +func (b *Builder) processUpdate(ctx context.Context, msg interface{}, op ...batch.SchedulerOption) error { switch msg := msg.(type) { @@ -1233,7 +1232,7 @@ func (b *Builder) processUpdate(msg interface{}, // the channel ID. channelID := lnwire.NewShortChanIDFromInt(msg.ChannelID) fundingTx, err := lnwallet.FetchFundingTxWrapper( - b.cfg.Chain, &channelID, b.quit, + ctx, b.cfg.Chain, &channelID, ) if err != nil { //nolint:lll diff --git a/graph/validation_barrier.go b/graph/validation_barrier.go index 98d910d899..3ed1022079 100644 --- a/graph/validation_barrier.go +++ b/graph/validation_barrier.go @@ -1,6 +1,7 @@ package graph import ( + "context" "fmt" "sync" @@ -53,21 +54,17 @@ type ValidationBarrier struct { // ChannelAnnouncement before proceeding. nodeAnnDependencies map[route.Vertex]*validationSignals - quit chan struct{} sync.Mutex } // NewValidationBarrier creates a new instance of a validation barrier given // the total number of active requests, and a quit channel which will be used // to know when to kill pending, but unfilled jobs. -func NewValidationBarrier(numActiveReqs int, - quitChan chan struct{}) *ValidationBarrier { - +func NewValidationBarrier(numActiveReqs int) *ValidationBarrier { v := &ValidationBarrier{ chanAnnFinSignal: make(map[lnwire.ShortChannelID]*validationSignals), chanEdgeDependencies: make(map[lnwire.ShortChannelID]*validationSignals), nodeAnnDependencies: make(map[route.Vertex]*validationSignals), - quit: quitChan, } // We'll first initialize a set of semaphores to limit our concurrency @@ -82,12 +79,14 @@ func NewValidationBarrier(numActiveReqs int, // InitJobDependencies will wait for a new job slot to become open, and then // sets up any dependent signals/trigger for the new job -func (v *ValidationBarrier) InitJobDependencies(job interface{}) { +func (v *ValidationBarrier) InitJobDependencies(ctx context.Context, + job interface{}) { + // We'll wait for either a new slot to become open, or for the quit // channel to be closed. select { case <-v.validationSemaphore: - case <-v.quit: + case <-ctx.Done(): } v.Lock() @@ -163,10 +162,10 @@ func (v *ValidationBarrier) InitJobDependencies(job interface{}) { // should be called once a job has been fully completed. Otherwise, slots may // not be returned to the internal scheduling, causing a deadlock when a new // overflow job is attempted. -func (v *ValidationBarrier) CompleteJob() { +func (v *ValidationBarrier) CompleteJob(ctx context.Context) { select { case v.validationSemaphore <- struct{}{}: - case <-v.quit: + case <-ctx.Done(): } } @@ -174,7 +173,8 @@ func (v *ValidationBarrier) CompleteJob() { // finished executing. This allows us a graceful way to schedule goroutines // based on any pending uncompleted dependent jobs. If this job doesn't have an // active dependent, then this function will return immediately. -func (v *ValidationBarrier) WaitForDependants(job interface{}) error { +func (v *ValidationBarrier) WaitForDependants(ctx context.Context, + job interface{}) error { var ( signals *validationSignals @@ -237,7 +237,7 @@ func (v *ValidationBarrier) WaitForDependants(job interface{}) error { // If we do have an active job, then we'll wait until either the signal // is closed, or the set of jobs exits. select { - case <-v.quit: + case <-ctx.Done(): return NewErrf(ErrVBarrierShuttingDown, "validation barrier shutting down") diff --git a/lnd.go b/lnd.go index f511811950..d5e51b5b66 100644 --- a/lnd.go +++ b/lnd.go @@ -598,7 +598,7 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg, // Set up the core server which will listen for incoming peer // connections. server, err := newServer( - cfg, cfg.Listeners, dbs, activeChainControl, &idKeyDesc, + ctx, cfg, cfg.Listeners, dbs, activeChainControl, &idKeyDesc, activeChainControl.Cfg.WalletUnlockParams.ChansToRestore, multiAcceptor, torController, tlsManager, leaderElector, implCfg, @@ -731,15 +731,19 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg, // case the startup of the subservers do not behave as expected. errChan := make(chan error) go func() { - errChan <- server.Start() + errChan <- server.Start(ctx) }() defer func() { - err := server.Stop() + stopCtx, cancel := context.WithTimeout( + context.Background(), time.Minute, // Give LND a min to shutdown. + ) + err := server.Stop(stopCtx) if err != nil { ltndLog.Warnf("Stopping the server including all "+ "its subsystems failed with %v", err) } + cancel() }() select { @@ -761,7 +765,7 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg, // active, then we'll start the autopilot agent immediately. It will be // stopped together with the autopilot service. if cfg.Autopilot.Active { - if err := atplManager.StartAgent(); err != nil { + if err := atplManager.StartAgent(ctx); err != nil { return mkErr("unable to start autopilot agent: %v", err) } } diff --git a/lnrpc/autopilotrpc/autopilot_server.go b/lnrpc/autopilotrpc/autopilot_server.go index 761d5f0926..e41b48200a 100644 --- a/lnrpc/autopilotrpc/autopilot_server.go +++ b/lnrpc/autopilotrpc/autopilot_server.go @@ -205,7 +205,7 @@ func (s *Server) ModifyStatus(ctx context.Context, var err error if in.Enable { - err = s.manager.StartAgent() + err = s.manager.StartAgent(ctx) } else { err = s.manager.StopAgent() } diff --git a/lnwallet/interface.go b/lnwallet/interface.go index e2e491c735..58db485949 100644 --- a/lnwallet/interface.go +++ b/lnwallet/interface.go @@ -1,6 +1,7 @@ package lnwallet import ( + "context" "errors" "fmt" "sync" @@ -750,10 +751,9 @@ func SupportedWallets() []string { return supportedWallets } -// FetchFundingTxWrapper is a wrapper around FetchFundingTx, except that it will -// exit when the supplied quit channel is closed. -func FetchFundingTxWrapper(chain BlockChainIO, chanID *lnwire.ShortChannelID, - quit chan struct{}) (*wire.MsgTx, error) { +// FetchFundingTxWrapper is a wrapper around FetchFundingTx. +func FetchFundingTxWrapper(ctx context.Context, chain BlockChainIO, + chanID *lnwire.ShortChannelID) (*wire.MsgTx, error) { txChan := make(chan *wire.MsgTx, 1) errChan := make(chan error, 1) @@ -775,9 +775,8 @@ func FetchFundingTxWrapper(chain BlockChainIO, chanID *lnwire.ShortChannelID, case err := <-errChan: return nil, err - case <-quit: - return nil, fmt.Errorf("quit channel passed to " + - "lnwallet.FetchFundingTxWrapper has been closed") + case <-ctx.Done(): + return nil, ctx.Err() } } @@ -815,13 +814,11 @@ func FetchFundingTx(chain BlockChainIO, return fundingBlock.Transactions[chanID.TxIndex].Copy(), nil } -// FetchPKScriptWithQuit fetches the output script for the given SCID and exits -// early with an error if the provided quit channel is closed before -// completion. -func FetchPKScriptWithQuit(chain BlockChainIO, chanID *lnwire.ShortChannelID, - quit chan struct{}) ([]byte, error) { +// FetchPKScript fetches the output script for the given SCID. +func FetchPKScript(ctx context.Context, chain BlockChainIO, + chanID *lnwire.ShortChannelID) ([]byte, error) { - tx, err := FetchFundingTxWrapper(chain, chanID, quit) + tx, err := FetchFundingTxWrapper(ctx, chain, chanID) if err != nil { return nil, err } diff --git a/netann/chan_status_manager.go b/netann/chan_status_manager.go index c4db4009dc..364bc7352f 100644 --- a/netann/chan_status_manager.go +++ b/netann/chan_status_manager.go @@ -1,6 +1,7 @@ package netann import ( + "context" "errors" "sync" "time" @@ -60,8 +61,8 @@ type ChanStatusConfig struct { // ApplyChannelUpdate processes new ChannelUpdates signed by our node by // updating our local routing table and broadcasting the update to our // peers. - ApplyChannelUpdate func(*lnwire.ChannelUpdate1, *wire.OutPoint, - bool) error + ApplyChannelUpdate func(context.Context, *lnwire.ChannelUpdate1, + *wire.OutPoint, bool) error // DB stores the set of channels that are to be monitored. DB DB @@ -171,16 +172,16 @@ func NewChanStatusManager(cfg *ChanStatusConfig) (*ChanStatusManager, error) { } // Start safely starts the ChanStatusManager. -func (m *ChanStatusManager) Start() error { +func (m *ChanStatusManager) Start(ctx context.Context) error { var err error m.started.Do(func() { log.Info("Channel Status Manager starting") - err = m.start() + err = m.start(ctx) }) return err } -func (m *ChanStatusManager) start() error { +func (m *ChanStatusManager) start(ctx context.Context) error { channels, err := m.fetchChannels() if err != nil { return err @@ -217,7 +218,7 @@ func (m *ChanStatusManager) start() error { } m.wg.Add(1) - go m.statusManager() + go m.statusManager(ctx) return nil } @@ -331,7 +332,7 @@ func (m *ChanStatusManager) submitRequest(reqChan chan statusRequest, // should be scheduled or broadcast. // // NOTE: This method MUST be run as a goroutine. -func (m *ChanStatusManager) statusManager() { +func (m *ChanStatusManager) statusManager(ctx context.Context) { defer m.wg.Done() for { @@ -339,11 +340,15 @@ func (m *ChanStatusManager) statusManager() { // Process any requests to mark channel as enabled. case req := <-m.enableRequests: - req.errChan <- m.processEnableRequest(req.outpoint, req.manual) + req.errChan <- m.processEnableRequest( + ctx, req.outpoint, req.manual, + ) // Process any requests to mark channel as disabled. case req := <-m.disableRequests: - req.errChan <- m.processDisableRequest(req.outpoint, req.manual) + req.errChan <- m.processDisableRequest( + ctx, req.outpoint, req.manual, + ) // Process any requests to restore automatic channel state management. case req := <-m.autoRequests: @@ -361,7 +366,7 @@ func (m *ChanStatusManager) statusManager() { // Now, do another sweep to disable any channels that // were marked in a prior iteration as pending inactive // if the inactive chan timeout has elapsed. - m.disableInactiveChannels() + m.disableInactiveChannels(ctx) case <-m.quit: return @@ -381,8 +386,8 @@ func (m *ChanStatusManager) statusManager() { // // An update will be broadcast only if the channel is currently disabled, // otherwise no update will be sent on the network. -func (m *ChanStatusManager) processEnableRequest(outpoint wire.OutPoint, - manual bool) error { +func (m *ChanStatusManager) processEnableRequest(ctx context.Context, + outpoint wire.OutPoint, manual bool) error { curState, err := m.getOrInitChanStatus(outpoint) if err != nil { @@ -421,7 +426,7 @@ func (m *ChanStatusManager) processEnableRequest(outpoint wire.OutPoint, case ChanStatusDisabled: log.Infof("Announcing channel(%v) enabled", outpoint) - err := m.signAndSendNextUpdate(outpoint, false) + err := m.signAndSendNextUpdate(ctx, outpoint, false) if err != nil { return err } @@ -439,8 +444,8 @@ func (m *ChanStatusManager) processEnableRequest(outpoint wire.OutPoint, // // An update will only be sent if the channel has a status other than // ChanStatusEnabled, otherwise no update will be sent on the network. -func (m *ChanStatusManager) processDisableRequest(outpoint wire.OutPoint, - manual bool) error { +func (m *ChanStatusManager) processDisableRequest(ctx context.Context, + outpoint wire.OutPoint, manual bool) error { curState, err := m.getOrInitChanStatus(outpoint) if err != nil { @@ -452,7 +457,7 @@ func (m *ChanStatusManager) processDisableRequest(outpoint wire.OutPoint, log.Infof("Announcing channel(%v) disabled [requested]", outpoint) - err := m.signAndSendNextUpdate(outpoint, true) + err := m.signAndSendNextUpdate(ctx, outpoint, true) if err != nil { return err } @@ -552,7 +557,7 @@ func (m *ChanStatusManager) markPendingInactiveChannels() { // disableInactiveChannels scans through the set of monitored channels, and // broadcast a disable update for any pending inactive channels whose // SendDisableTime has been superseded by the current time. -func (m *ChanStatusManager) disableInactiveChannels() { +func (m *ChanStatusManager) disableInactiveChannels(ctx context.Context) { // Now, disable any channels whose inactive chan timeout has elapsed. now := time.Now() for outpoint, state := range m.chanStates { @@ -571,7 +576,7 @@ func (m *ChanStatusManager) disableInactiveChannels() { "[detected]", outpoint) // Sign an update disabling the channel. - err := m.signAndSendNextUpdate(outpoint, true) + err := m.signAndSendNextUpdate(ctx, outpoint, true) if err != nil { log.Errorf("Unable to sign update disabling "+ "channel(%v): %v", outpoint, err) @@ -624,8 +629,8 @@ func (m *ChanStatusManager) fetchChannels() ([]*channeldb.OpenChannel, error) { // use the current time as the update's timestamp, or increment the old // timestamp by 1 to ensure the update can propagate. If signing is successful, // the new update will be sent out on the network. -func (m *ChanStatusManager) signAndSendNextUpdate(outpoint wire.OutPoint, - disabled bool) error { +func (m *ChanStatusManager) signAndSendNextUpdate(ctx context.Context, + outpoint wire.OutPoint, disabled bool) error { // Retrieve the latest update for this channel. We'll use this // as our starting point to send the new update. @@ -642,7 +647,7 @@ func (m *ChanStatusManager) signAndSendNextUpdate(outpoint wire.OutPoint, return err } - return m.cfg.ApplyChannelUpdate(chanUpdate, &outpoint, private) + return m.cfg.ApplyChannelUpdate(ctx, chanUpdate, &outpoint, private) } // fetchLastChanUpdateByOutPoint fetches the latest policy for our direction of diff --git a/netann/channel_announcement.go b/netann/channel_announcement.go index 9644a523ff..df176daf96 100644 --- a/netann/channel_announcement.go +++ b/netann/channel_announcement.go @@ -2,6 +2,7 @@ package netann import ( "bytes" + "context" "errors" "fmt" @@ -108,17 +109,17 @@ func CreateChanAnnouncement(chanProof *models.ChannelAuthProof, // FetchPkScript defines a function that can be used to fetch the output script // for the transaction with the given SCID. -type FetchPkScript func(*lnwire.ShortChannelID) ([]byte, error) +type FetchPkScript func(context.Context, *lnwire.ShortChannelID) ([]byte, error) // ValidateChannelAnn validates the channel announcement. -func ValidateChannelAnn(a lnwire.ChannelAnnouncement, +func ValidateChannelAnn(ctx context.Context, a lnwire.ChannelAnnouncement, fetchPkScript FetchPkScript) error { switch ann := a.(type) { case *lnwire.ChannelAnnouncement1: return validateChannelAnn1(ann) case *lnwire.ChannelAnnouncement2: - return validateChannelAnn2(ann, fetchPkScript) + return validateChannelAnn2(ctx, ann, fetchPkScript) default: return fmt.Errorf("unhandled implementation of "+ "lnwire.ChannelAnnouncement: %T", a) @@ -199,7 +200,7 @@ func validateChannelAnn1(a *lnwire.ChannelAnnouncement1) error { // validateChannelAnn2 validates the channel announcement message and checks // that message signature covers the announcement message. -func validateChannelAnn2(a *lnwire.ChannelAnnouncement2, +func validateChannelAnn2(ctx context.Context, a *lnwire.ChannelAnnouncement2, fetchPkScript FetchPkScript) error { dataHash, err := ChanAnn2DigestToSign(a) @@ -253,7 +254,7 @@ func validateChannelAnn2(a *lnwire.ChannelAnnouncement2, // If bitcoin keys are not provided, then we need to get the // on-chain output key since this will be the 3rd key in the // 3-of-3 MuSig2 signature. - pkScript, err := fetchPkScript(&a.ShortChannelID.Val) + pkScript, err := fetchPkScript(ctx, &a.ShortChannelID.Val) if err != nil { return err } diff --git a/peer/brontide.go b/peer/brontide.go index 50d5111016..fd3b9916f8 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -3,6 +3,7 @@ package peer import ( "bytes" "container/list" + "context" "errors" "fmt" "math/rand" @@ -691,7 +692,7 @@ func NewBrontide(cfg Config) *Brontide { // Start starts all helper goroutines the peer needs for normal operations. In // the case this peer has already been started, then this function is a noop. -func (p *Brontide) Start() error { +func (p *Brontide) Start(ctx context.Context) error { if atomic.AddInt32(&p.started, 1) != 1 { return nil } @@ -837,7 +838,7 @@ func (p *Brontide) Start() error { go p.queueHandler() go p.writeHandler() go p.channelManager() - go p.readHandler() + go p.readHandler(ctx) // Signal to any external processes that the peer is now active. close(p.activeSignal) @@ -862,7 +863,7 @@ func (p *Brontide) Start() error { // initGossipSync initializes either a gossip syncer or an initial routing // dump, depending on the negotiated synchronization method. -func (p *Brontide) initGossipSync() { +func (p *Brontide) initGossipSync(ctx context.Context) { // If the remote peer knows of the new gossip queries feature, then // we'll create a new gossipSyncer in the AuthenticatedGossiper for it. if p.remoteFeatures.HasFeature(lnwire.GossipQueriesOptional) { @@ -884,7 +885,7 @@ func (p *Brontide) initGossipSync() { // requires an improved version of the current network // bootstrapper to ensure we can find and connect to non-channel // peers. - p.cfg.AuthGossiper.InitSyncState(p) + p.cfg.AuthGossiper.InitSyncState(ctx, p) } } @@ -1586,7 +1587,7 @@ type msgStream struct { peer *Brontide - apply func(lnwire.Message) + apply func(context.Context, lnwire.Message) startMsg string stopMsg string @@ -1608,7 +1609,7 @@ type msgStream struct { // sane value that avoids blocking unnecessarily, but doesn't allow an // unbounded amount of memory to be allocated to buffer incoming messages. func newMsgStream(p *Brontide, startMsg, stopMsg string, bufSize uint32, - apply func(lnwire.Message)) *msgStream { + apply func(context.Context, lnwire.Message)) *msgStream { stream := &msgStream{ peer: p, @@ -1632,9 +1633,9 @@ func newMsgStream(p *Brontide, startMsg, stopMsg string, bufSize uint32, } // Start starts the chanMsgStream. -func (ms *msgStream) Start() { +func (ms *msgStream) Start(ctx context.Context) { ms.wg.Add(1) - go ms.msgConsumer() + go ms.msgConsumer(ctx) } // Stop stops the chanMsgStream. @@ -1655,7 +1656,7 @@ func (ms *msgStream) Stop() { // msgConsumer is the main goroutine that streams messages from the peer's // readHandler directly to the target channel. -func (ms *msgStream) msgConsumer() { +func (ms *msgStream) msgConsumer(ctx context.Context) { defer ms.wg.Done() defer peerLog.Tracef(ms.stopMsg) defer atomic.StoreInt32(&ms.streamShutdown, 1) @@ -1692,7 +1693,7 @@ func (ms *msgStream) msgConsumer() { ms.msgCond.L.Unlock() - ms.apply(msg) + ms.apply(ctx, msg) // We've just successfully processed an item, so we'll signal // to the producer that a new slot in the buffer. We'll use @@ -1810,7 +1811,7 @@ func waitUntilLinkActive(p *Brontide, func newChanMsgStream(p *Brontide, cid lnwire.ChannelID) *msgStream { var chanLink htlcswitch.ChannelUpdateHandler - apply := func(msg lnwire.Message) { + apply := func(ctx context.Context, msg lnwire.Message) { // This check is fine because if the link no longer exists, it will // be removed from the activeChannels map and subsequent messages // shouldn't reach the chan msg stream. @@ -1829,7 +1830,7 @@ func newChanMsgStream(p *Brontide, cid lnwire.ChannelID) *msgStream { // as the peer is exiting, we'll check quickly to see // if we need to exit. select { - case <-p.quit: + case <-ctx.Done(): return default: } @@ -1849,10 +1850,10 @@ func newChanMsgStream(p *Brontide, cid lnwire.ChannelID) *msgStream { // authenticated gossiper. This stream should be used to forward all remote // channel announcements. func newDiscMsgStream(p *Brontide) *msgStream { - apply := func(msg lnwire.Message) { + apply := func(ctx context.Context, msg lnwire.Message) { // TODO(yy): `ProcessRemoteAnnouncement` returns an error chan // and we need to process it. - p.cfg.AuthGossiper.ProcessRemoteAnnouncement(msg, p) + p.cfg.AuthGossiper.ProcessRemoteAnnouncement(ctx, msg, p) } return newMsgStream( @@ -1868,7 +1869,7 @@ func newDiscMsgStream(p *Brontide) *msgStream { // properly dispatching the handling of the message to the proper subsystem. // // NOTE: This method MUST be run as a goroutine. -func (p *Brontide) readHandler() { +func (p *Brontide) readHandler(ctx context.Context) { defer p.wg.Done() // We'll stop the timer after a new messages is received, and also @@ -1885,10 +1886,10 @@ func (p *Brontide) readHandler() { // // TODO(conner): have peer store gossip syncer directly and bypass // gossiper? - p.initGossipSync() + p.initGossipSync(ctx) discStream := newDiscMsgStream(p) - discStream.Start() + discStream.Start(ctx) defer discStream.Stop() out: for atomic.LoadInt32(&p.disconnect) == 0 { @@ -2076,7 +2077,7 @@ out: if isLinkUpdate { // If this is a channel update, then we need to feed it // into the channel's in-order message stream. - p.sendLinkUpdateMsg(targetChan, nextMsg) + p.sendLinkUpdateMsg(ctx, targetChan, nextMsg) } idleTimer.Reset(idleTimeout) @@ -4425,7 +4426,9 @@ func (p *Brontide) handleRemovePendingChannel(req *newChannelMsg) { // sendLinkUpdateMsg sends a message that updates the channel to the // channel's message stream. -func (p *Brontide) sendLinkUpdateMsg(cid lnwire.ChannelID, msg lnwire.Message) { +func (p *Brontide) sendLinkUpdateMsg(ctx context.Context, cid lnwire.ChannelID, + msg lnwire.Message) { + p.log.Tracef("Sending link update msg=%v", msg.MsgType()) chanStream, ok := p.activeMsgStreams[cid] @@ -4434,7 +4437,7 @@ func (p *Brontide) sendLinkUpdateMsg(cid lnwire.ChannelID, msg lnwire.Message) { // it to the map, and finally start it. chanStream = newChanMsgStream(p, cid) p.activeMsgStreams[cid] = chanStream - chanStream.Start() + chanStream.Start(ctx) // Stop the stream when quit. go func() { diff --git a/peer/test_utils.go b/peer/test_utils.go index 1a62355908..06ca8c1962 100644 --- a/peer/test_utils.go +++ b/peer/test_utils.go @@ -2,6 +2,7 @@ package peer import ( "bytes" + "context" crand "crypto/rand" "encoding/binary" "io" @@ -87,7 +88,7 @@ func createTestPeerWithChannel(t *testing.T, updateChan func(a, chanStatusMgr = params.chanStatusMgr ) - err := chanStatusMgr.Start() + err := chanStatusMgr.Start(context.Background()) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, chanStatusMgr.Stop()) @@ -619,7 +620,8 @@ func createTestPeer(t *testing.T) *peerTestCtx { IsChannelActive: func(lnwire.ChannelID) bool { return true }, - ApplyChannelUpdate: func(*lnwire.ChannelUpdate1, + ApplyChannelUpdate: func(context.Context, + *lnwire.ChannelUpdate1, *wire.OutPoint, bool) error { return nil @@ -758,7 +760,7 @@ func startPeer(t *testing.T, mockConn *mockMessageConn, // indicates a successful startup. done := make(chan struct{}) go func() { - require.NoError(t, peer.Start()) + require.NoError(t, peer.Start(context.Background())) close(done) }() diff --git a/pilot.go b/pilot.go index 2a37b080d0..710be087d6 100644 --- a/pilot.go +++ b/pilot.go @@ -1,6 +1,7 @@ package lnd import ( + "context" "errors" "fmt" "net" @@ -187,7 +188,9 @@ func initAutoPilot(svr *server, cfg *lncfg.AutoPilot, }, Graph: autopilot.ChannelGraphFromDatabase(svr.graphDB), Constraints: atplConstraints, - ConnectToPeer: func(target *btcec.PublicKey, addrs []net.Addr) (bool, error) { + ConnectToPeer: func(ctx context.Context, + target *btcec.PublicKey, addrs []net.Addr) (bool, error) { + // First, we'll check if we're already connected to the // target peer. If we are, we can exit early. Otherwise, // we'll need to establish a connection. @@ -224,7 +227,8 @@ func initAutoPilot(svr *server, cfg *lncfg.AutoPilot, } err := svr.ConnectToPeer( - lnAddr, false, svr.cfg.ConnectionTimeout, + ctx, lnAddr, false, + svr.cfg.ConnectionTimeout, ) if err != nil { // If we weren't able to connect to the diff --git a/routing/localchans/manager.go b/routing/localchans/manager.go index f0f9b88de0..84f75d728c 100644 --- a/routing/localchans/manager.go +++ b/routing/localchans/manager.go @@ -1,6 +1,7 @@ package localchans import ( + "context" "errors" "fmt" "sync" @@ -26,7 +27,7 @@ type Manager struct { // PropagateChanPolicyUpdate is called to persist a new policy to disk // and broadcast it to the network. - PropagateChanPolicyUpdate func( + PropagateChanPolicyUpdate func(ctx context.Context, edgesToUpdate []discovery.EdgeWithInfo) error // ForAllOutgoingChannels is required to iterate over all our local @@ -50,7 +51,8 @@ type Manager struct { // UpdatePolicy updates the policy for the specified channels on disk and in // the active links. -func (r *Manager) UpdatePolicy(newSchema routing.ChannelPolicy, +func (r *Manager) UpdatePolicy(ctx context.Context, + newSchema routing.ChannelPolicy, chanPoints ...wire.OutPoint) ([]*lnrpc.FailedUpdate, error) { r.policyUpdateLock.Lock() @@ -169,7 +171,7 @@ func (r *Manager) UpdatePolicy(newSchema routing.ChannelPolicy, // this would happen because of a bug, the link policy will be // desynchronized. It is currently not possible to atomically commit // multiple edge updates. - err = r.PropagateChanPolicyUpdate(edgesToUpdate) + err = r.PropagateChanPolicyUpdate(ctx, edgesToUpdate) if err != nil { return nil, err } diff --git a/rpcserver.go b/rpcserver.go index 43d922c383..6a52296d58 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -1811,7 +1811,7 @@ func (r *rpcServer) ConnectPeer(ctx context.Context, } if err := r.server.ConnectToPeer( - peerAddr, in.Perm, timeout, + ctx, peerAddr, in.Perm, timeout, ); err != nil { rpcsLog.Errorf("[connectpeer]: error connecting to peer: %v", err) @@ -7761,8 +7761,9 @@ func (r *rpcServer) UpdateChannelPolicy(ctx context.Context, // With the scope resolved, we'll now send this to the local channel // manager so it can propagate the new policy for our target channel(s). - failedUpdates, err := r.server.localChanMgr.UpdatePolicy(chanPolicy, - targetChans...) + failedUpdates, err := r.server.localChanMgr.UpdatePolicy( + ctx, chanPolicy, targetChans..., + ) if err != nil { return nil, err } @@ -8187,7 +8188,7 @@ func (r *rpcServer) RestoreChannelBackups(ctx context.Context, // out to any peers that we know of which were our prior // channel peers. numRestored, err = chanbackup.UnpackAndRecoverSingles( - chanbackup.PackedSingles(packedBackups), + ctx, chanbackup.PackedSingles(packedBackups), r.server.cc.KeyRing, chanRestorer, r.server, ) if err != nil { @@ -8204,7 +8205,7 @@ func (r *rpcServer) RestoreChannelBackups(ctx context.Context, // channel peers. packedMulti := chanbackup.PackedMulti(packedMultiBackup) numRestored, err = chanbackup.UnpackAndRecoverMulti( - packedMulti, r.server.cc.KeyRing, chanRestorer, + ctx, packedMulti, r.server.cc.KeyRing, chanRestorer, r.server, ) if err != nil { diff --git a/server.go b/server.go index c46c1a5f9e..866a31c56e 100644 --- a/server.go +++ b/server.go @@ -494,7 +494,7 @@ func noiseDial(idKey keychain.SingleKeyECDH, // newServer creates a new instance of the server which is to listen using the // passed listener address. -func newServer(cfg *Config, listenAddrs []net.Addr, +func newServer(ctx context.Context, cfg *Config, listenAddrs []net.Addr, dbs *DatabaseInstances, cc *chainreg.ChainControl, nodeKeyDesc *keychain.KeyDescriptor, chansToRestore walletunlocker.ChannelsToRecover, @@ -1773,14 +1773,18 @@ func newServer(cfg *Config, listenAddrs []net.Addr, // maintaining persistent outbound connections and also accepting new // incoming connections cmgr, err := connmgr.New(&connmgr.Config{ - Listeners: listeners, - OnAccept: s.InboundPeerConnected, + Listeners: listeners, + OnAccept: func(conn net.Conn) { + s.InboundPeerConnected(ctx, conn) + }, RetryDuration: time.Second * 5, TargetOutbound: 100, Dial: noiseDial( nodeKeyECDH, s.cfg.net, s.cfg.ConnectionTimeout, ), - OnConnection: s.OutboundPeerConnected, + OnConnection: func(req *connmgr.ConnReq, conn net.Conn) { + s.OutboundPeerConnected(ctx, req, conn) + }, }) if err != nil { return nil, err @@ -2046,7 +2050,7 @@ func (c cleaner) run() { // NOTE: This function is safe for concurrent access. // //nolint:funlen -func (s *server) Start() error { +func (s *server) Start(ctx context.Context) error { var startErr error // If one sub system fails to start, the following code ensures that the @@ -2165,7 +2169,7 @@ func (s *server) Start() error { } cleanup = cleanup.add(s.fundingMgr.Stop) - if err := s.fundingMgr.Start(); err != nil { + if err := s.fundingMgr.Start(ctx); err != nil { startErr = err return } @@ -2198,7 +2202,7 @@ func (s *server) Start() error { } cleanup = cleanup.add(s.graphBuilder.Stop) - if err := s.graphBuilder.Start(); err != nil { + if err := s.graphBuilder.Start(ctx); err != nil { startErr = err return } @@ -2211,7 +2215,7 @@ func (s *server) Start() error { // The authGossiper depends on the chanRouter and therefore // should be started after it. cleanup = cleanup.add(s.authGossiper.Stop) - if err := s.authGossiper.Start(); err != nil { + if err := s.authGossiper.Start(ctx); err != nil { startErr = err return } @@ -2229,7 +2233,7 @@ func (s *server) Start() error { } cleanup = cleanup.add(s.chanStatusMgr.Stop) - if err := s.chanStatusMgr.Start(); err != nil { + if err := s.chanStatusMgr.Start(ctx); err != nil { startErr = err return } @@ -2257,7 +2261,7 @@ func (s *server) Start() error { } if len(s.chansToRestore.PackedSingleChanBackups) != 0 { _, err := chanbackup.UnpackAndRecoverSingles( - s.chansToRestore.PackedSingleChanBackups, + ctx, s.chansToRestore.PackedSingleChanBackups, s.cc.KeyRing, chanRestorer, s, ) if err != nil { @@ -2268,7 +2272,7 @@ func (s *server) Start() error { } if len(s.chansToRestore.PackedMultiChanBackup) != 0 { _, err := chanbackup.UnpackAndRecoverMulti( - s.chansToRestore.PackedMultiChanBackup, + ctx, s.chansToRestore.PackedMultiChanBackup, s.cc.KeyRing, chanRestorer, s, ) if err != nil { @@ -2333,7 +2337,7 @@ func (s *server) Start() error { } err = s.ConnectToPeer( - peerAddr, true, + ctx, peerAddr, true, s.cfg.ConnectionTimeout, ) if err != nil { @@ -2428,7 +2432,9 @@ func (s *server) Start() error { } s.wg.Add(1) - go s.peerBootstrapper(defaultMinPeers, bootstrappers) + go s.peerBootstrapper( + ctx, defaultMinPeers, bootstrappers, + ) } else { srvrLog.Infof("Auto peer bootstrapping is disabled") } @@ -2448,7 +2454,7 @@ func (s *server) Start() error { // any active goroutines, or helper objects to exit, then blocks until they've // all successfully exited. Additionally, any/all listeners are closed. // NOTE: This function is safe for concurrent access. -func (s *server) Stop() error { +func (s *server) Stop(ctx context.Context) error { s.stop.Do(func() { atomic.StoreInt32(&s.stopping, 1) @@ -2855,7 +2861,7 @@ func (s *server) createBootstrapIgnorePeers() map[autopilot.NodeID]struct{} { // invariant, we ensure that our node is connected to a diverse set of peers // and that nodes newly joining the network receive an up to date network view // as soon as possible. -func (s *server) peerBootstrapper(numTargetPeers uint32, +func (s *server) peerBootstrapper(ctx context.Context, numTargetPeers uint32, bootstrappers []discovery.NetworkPeerBootstrapper) { defer s.wg.Done() @@ -2865,7 +2871,7 @@ func (s *server) peerBootstrapper(numTargetPeers uint32, // We'll start off by aggressively attempting connections to peers in // order to be a part of the network as soon as possible. - s.initialPeerBootstrap(ignoreList, numTargetPeers, bootstrappers) + s.initialPeerBootstrap(ctx, ignoreList, numTargetPeers, bootstrappers) // Once done, we'll attempt to maintain our target minimum number of // peers. @@ -2961,7 +2967,7 @@ func (s *server) peerBootstrapper(numTargetPeers uint32, // country diversity, etc errChan := make(chan error, 1) s.connectToPeer( - a, errChan, + ctx, a, errChan, s.cfg.ConnectionTimeout, ) select { @@ -2992,8 +2998,8 @@ const bootstrapBackOffCeiling = time.Minute * 5 // initialPeerBootstrap attempts to continuously connect to peers on startup // until the target number of peers has been reached. This ensures that nodes // receive an up to date network view as soon as possible. -func (s *server) initialPeerBootstrap(ignore map[autopilot.NodeID]struct{}, - numTargetPeers uint32, +func (s *server) initialPeerBootstrap(ctx context.Context, + ignore map[autopilot.NodeID]struct{}, numTargetPeers uint32, bootstrappers []discovery.NetworkPeerBootstrapper) { srvrLog.Debugf("Init bootstrap with targetPeers=%v, bootstrappers=%v, "+ @@ -3070,7 +3076,8 @@ func (s *server) initialPeerBootstrap(ignore map[autopilot.NodeID]struct{}, errChan := make(chan error, 1) go s.connectToPeer( - addr, errChan, s.cfg.ConnectionTimeout, + ctx, addr, errChan, + s.cfg.ConnectionTimeout, ) // We'll only allow this connection attempt to @@ -3748,7 +3755,7 @@ func shouldDropLocalConnection(local, remote *btcec.PublicKey) bool { // connection. // // NOTE: This function is safe for concurrent access. -func (s *server) InboundPeerConnected(conn net.Conn) { +func (s *server) InboundPeerConnected(ctx context.Context, conn net.Conn) { // Exit early if we have already been instructed to shutdown, this // prevents any delayed callbacks from accidentally registering peers. if s.Stopped() { @@ -3818,7 +3825,7 @@ func (s *server) InboundPeerConnected(conn net.Conn) { // We were unable to locate an existing connection with the // target peer, proceed to connect. s.cancelConnReqs(pubStr, nil) - s.peerConnected(conn, nil, true) + s.peerConnected(ctx, conn, nil, true) case nil: // We already have a connection with the incoming peer. If the @@ -3850,7 +3857,7 @@ func (s *server) InboundPeerConnected(conn net.Conn) { s.removePeer(connectedPeer) s.ignorePeerTermination[connectedPeer] = struct{}{} s.scheduledPeerConnection[pubStr] = func() { - s.peerConnected(conn, nil, true) + s.peerConnected(ctx, conn, nil, true) } } } @@ -3858,7 +3865,9 @@ func (s *server) InboundPeerConnected(conn net.Conn) { // OutboundPeerConnected initializes a new peer in response to a new outbound // connection. // NOTE: This function is safe for concurrent access. -func (s *server) OutboundPeerConnected(connReq *connmgr.ConnReq, conn net.Conn) { +func (s *server) OutboundPeerConnected(ctx context.Context, + connReq *connmgr.ConnReq, conn net.Conn) { + // Exit early if we have already been instructed to shutdown, this // prevents any delayed callbacks from accidentally registering peers. if s.Stopped() { @@ -3956,7 +3965,7 @@ func (s *server) OutboundPeerConnected(connReq *connmgr.ConnReq, conn net.Conn) case ErrPeerNotConnected: // We were unable to locate an existing connection with the // target peer, proceed to connect. - s.peerConnected(conn, connReq, false) + s.peerConnected(ctx, conn, connReq, false) case nil: // We already have a connection with the incoming peer. If the @@ -3990,7 +3999,7 @@ func (s *server) OutboundPeerConnected(connReq *connmgr.ConnReq, conn net.Conn) s.removePeer(connectedPeer) s.ignorePeerTermination[connectedPeer] = struct{}{} s.scheduledPeerConnection[pubStr] = func() { - s.peerConnected(conn, connReq, false) + s.peerConnected(ctx, conn, connReq, false) } } } @@ -4068,8 +4077,8 @@ func (s *server) SubscribeCustomMessages() (*subscribe.Client, error) { // peer by adding it to the server's global list of all active peers, and // starting all the goroutines the peer needs to function properly. The inbound // boolean should be true if the peer initiated the connection to us. -func (s *server) peerConnected(conn net.Conn, connReq *connmgr.ConnReq, - inbound bool) { +func (s *server) peerConnected(ctx context.Context, conn net.Conn, + connReq *connmgr.ConnReq, inbound bool) { brontideConn := conn.(*brontide.Conn) addr := conn.RemoteAddr() @@ -4213,7 +4222,7 @@ func (s *server) peerConnected(conn net.Conn, connReq *connmgr.ConnReq, // includes sending and receiving Init messages, which would be a DOS // vector if we held the server's mutex throughout the procedure. s.wg.Add(1) - go s.peerInitializer(p) + go s.peerInitializer(ctx, p) } // addPeer adds the passed peer to the server's global state of all active @@ -4268,7 +4277,7 @@ func (s *server) addPeer(p *peer.Brontide) { // be signaled of the new peer once the method returns. // // NOTE: This MUST be launched as a goroutine. -func (s *server) peerInitializer(p *peer.Brontide) { +func (s *server) peerInitializer(ctx context.Context, p *peer.Brontide) { defer s.wg.Done() pubBytes := p.IdentityKey().SerializeCompressed() @@ -4292,11 +4301,11 @@ func (s *server) peerInitializer(p *peer.Brontide) { // the peer is ever added to the ignorePeerTermination map, indicating // that the server has already handled the removal of this peer. s.wg.Add(1) - go s.peerTerminationWatcher(p, ready) + go s.peerTerminationWatcher(ctx, p, ready) // Start the peer! If an error occurs, we Disconnect the peer, which // will unblock the peerTerminationWatcher. - if err := p.Start(); err != nil { + if err := p.Start(ctx); err != nil { srvrLog.Warnf("Starting peer=%x got error: %v", pubBytes, err) p.Disconnect(fmt.Errorf("unable to start peer: %w", err)) @@ -4337,7 +4346,9 @@ func (s *server) peerInitializer(p *peer.Brontide) { // successfully, otherwise the peer should be disconnected instead. // // NOTE: This MUST be launched as a goroutine. -func (s *server) peerTerminationWatcher(p *peer.Brontide, ready chan struct{}) { +func (s *server) peerTerminationWatcher(ctx context.Context, p *peer.Brontide, + ready chan struct{}) { + defer s.wg.Done() p.WaitForDisconnect(ready) @@ -4361,7 +4372,7 @@ func (s *server) peerTerminationWatcher(p *peer.Brontide, ready chan struct{}) { // We'll also inform the gossiper that this peer is no longer active, // so we don't need to maintain sync state for it any longer. - s.authGossiper.PruneSyncState(p.PubKey()) + s.authGossiper.PruneSyncState(ctx, p.PubKey()) // Tell the switch to remove all links associated with this peer. // Passing nil as the target link indicates that all links associated @@ -4675,7 +4686,7 @@ func (s *server) removePeer(p *peer.Brontide) { // connection is established, or the initial handshake process fails. // // NOTE: This function is safe for concurrent access. -func (s *server) ConnectToPeer(addr *lnwire.NetAddress, +func (s *server) ConnectToPeer(ctx context.Context, addr *lnwire.NetAddress, perm bool, timeout time.Duration) error { targetPub := string(addr.IdentityKey.SerializeCompressed()) @@ -4737,7 +4748,7 @@ func (s *server) ConnectToPeer(addr *lnwire.NetAddress, // the crypto negotiation breaks down, then return an error to the // caller. errChan := make(chan error, 1) - s.connectToPeer(addr, errChan, timeout) + s.connectToPeer(ctx, addr, errChan, timeout) select { case err := <-errChan: @@ -4750,7 +4761,7 @@ func (s *server) ConnectToPeer(addr *lnwire.NetAddress, // connectToPeer establishes a connection to a remote peer. errChan is used to // notify the caller if the connection attempt has failed. Otherwise, it will be // closed. -func (s *server) connectToPeer(addr *lnwire.NetAddress, +func (s *server) connectToPeer(ctx context.Context, addr *lnwire.NetAddress, errChan chan<- error, timeout time.Duration) { conn, err := brontide.Dial( @@ -4770,7 +4781,7 @@ func (s *server) connectToPeer(addr *lnwire.NetAddress, srvrLog.Tracef("Brontide dialer made local=%v, remote=%v", conn.LocalAddr(), conn.RemoteAddr()) - s.OutboundPeerConnected(nil, conn) + s.OutboundPeerConnected(ctx, nil, conn) } // DisconnectPeer sends the request to server to close the connection with peer @@ -4955,8 +4966,8 @@ func (s *server) fetchLastChanUpdate() func(lnwire.ShortChannelID) ( // applyChannelUpdate applies the channel update to the different sub-systems of // the server. The useAlias boolean denotes whether or not to send an alias in // place of the real SCID. -func (s *server) applyChannelUpdate(update *lnwire.ChannelUpdate1, - op *wire.OutPoint, useAlias bool) error { +func (s *server) applyChannelUpdate(ctx context.Context, + update *lnwire.ChannelUpdate1, op *wire.OutPoint, useAlias bool) error { var ( peerAlias *lnwire.ShortChannelID @@ -4975,7 +4986,7 @@ func (s *server) applyChannelUpdate(update *lnwire.ChannelUpdate1, } errChan := s.authGossiper.ProcessLocalAnnouncement( - update, discovery.RemoteAlias(peerAlias), + ctx, update, discovery.RemoteAlias(peerAlias), ) select { case err := <-errChan: