diff --git a/chanbackup/backup.go b/chanbackup/backup.go index 5853b37e45..8f53185134 100644 --- a/chanbackup/backup.go +++ b/chanbackup/backup.go @@ -1,6 +1,7 @@ package chanbackup import ( + "context" "fmt" "github.com/btcsuite/btcd/wire" @@ -24,7 +25,7 @@ type LiveChannelSource interface { // passed open channel. The backup includes all information required to restore // the channel, as well as addressing information so we can find the peer and // reconnect to them to initiate the protocol. -func assembleChanBackup(addrSource channeldb.AddrSource, +func assembleChanBackup(ctx context.Context, addrSource channeldb.AddrSource, openChan *channeldb.OpenChannel) (*Single, error) { log.Debugf("Crafting backup for ChannelPoint(%v)", @@ -32,7 +33,9 @@ func assembleChanBackup(addrSource channeldb.AddrSource, // First, we'll query the channel source to obtain all the addresses // that are associated with the peer for this channel. - known, nodeAddrs, err := addrSource.AddrsForNode(openChan.IdentityPub) + known, nodeAddrs, err := addrSource.AddrsForNode( + ctx, openChan.IdentityPub, + ) if err != nil { return nil, err } @@ -90,7 +93,8 @@ func buildCloseTxInputs( // FetchBackupForChan attempts to create a plaintext static channel backup for // the target channel identified by its channel point. If we're unable to find // the target channel, then an error will be returned. -func FetchBackupForChan(chanPoint wire.OutPoint, chanSource LiveChannelSource, +func FetchBackupForChan(ctx context.Context, chanPoint wire.OutPoint, + chanSource LiveChannelSource, addrSource channeldb.AddrSource) (*Single, error) { // First, we'll query the channel source to see if the channel is known @@ -104,7 +108,7 @@ func FetchBackupForChan(chanPoint wire.OutPoint, chanSource LiveChannelSource, // Once we have the target channel, we can assemble the backup using // the source to obtain any extra information that we may need. 
- staticChanBackup, err := assembleChanBackup(addrSource, targetChan) + staticChanBackup, err := assembleChanBackup(ctx, addrSource, targetChan) if err != nil { return nil, fmt.Errorf("unable to create chan backup: %w", err) } @@ -114,7 +118,7 @@ func FetchBackupForChan(chanPoint wire.OutPoint, chanSource LiveChannelSource, // FetchStaticChanBackups will return a plaintext static channel back up for // all known active/open channels within the passed channel source. -func FetchStaticChanBackups(chanSource LiveChannelSource, +func FetchStaticChanBackups(ctx context.Context, chanSource LiveChannelSource, addrSource channeldb.AddrSource) ([]Single, error) { // First, we'll query the backup source for information concerning all @@ -129,7 +133,7 @@ func FetchStaticChanBackups(chanSource LiveChannelSource, // channel. staticChanBackups := make([]Single, 0, len(openChans)) for _, openChan := range openChans { - chanBackup, err := assembleChanBackup(addrSource, openChan) + chanBackup, err := assembleChanBackup(ctx, addrSource, openChan) if err != nil { return nil, err } diff --git a/chanbackup/backup_test.go b/chanbackup/backup_test.go index 46ccf4c244..2da50b7e6f 100644 --- a/chanbackup/backup_test.go +++ b/chanbackup/backup_test.go @@ -1,6 +1,7 @@ package chanbackup import ( + "context" "fmt" "net" "testing" @@ -61,8 +62,8 @@ func (m *mockChannelSource) addAddrsForNode(nodePub *btcec.PublicKey, addrs []ne m.addrs[nodeKey] = addrs } -func (m *mockChannelSource) AddrsForNode(nodePub *btcec.PublicKey) (bool, - []net.Addr, error) { +func (m *mockChannelSource) AddrsForNode(_ context.Context, + nodePub *btcec.PublicKey) (bool, []net.Addr, error) { if m.failQuery { return false, nil, fmt.Errorf("fail") @@ -81,6 +82,7 @@ func (m *mockChannelSource) AddrsForNode(nodePub *btcec.PublicKey) (bool, // can find addresses for and otherwise. 
func TestFetchBackupForChan(t *testing.T) { t.Parallel() + ctx := context.Background() // First, we'll make two channels, only one of them will have all the // information we need to construct set of backups for them. @@ -120,7 +122,7 @@ func TestFetchBackupForChan(t *testing.T) { } for i, testCase := range testCases { _, err := FetchBackupForChan( - testCase.chanPoint, chanSource, chanSource, + ctx, testCase.chanPoint, chanSource, chanSource, ) switch { // If this is a valid test case, and we failed, then we'll @@ -141,6 +143,7 @@ func TestFetchBackupForChan(t *testing.T) { // channel source for all channels and construct a Single for each channel. func TestFetchStaticChanBackups(t *testing.T) { t.Parallel() + ctx := context.Background() // First, we'll make the set of channels that we want to seed the // channel source with. Both channels will be fully populated in the @@ -160,7 +163,7 @@ func TestFetchStaticChanBackups(t *testing.T) { // With the channel source populated, we'll now attempt to create a set // of backups for all the channels. This should succeed, as all items // are populated within the channel source. - backups, err := FetchStaticChanBackups(chanSource, chanSource) + backups, err := FetchStaticChanBackups(ctx, chanSource, chanSource) require.NoError(t, err, "unable to create chan back ups") if len(backups) != numChans { @@ -175,7 +178,7 @@ func TestFetchStaticChanBackups(t *testing.T) { copy(n[:], randomChan2.IdentityPub.SerializeCompressed()) delete(chanSource.addrs, n) - _, err = FetchStaticChanBackups(chanSource, chanSource) + _, err = FetchStaticChanBackups(ctx, chanSource, chanSource) if err == nil { t.Fatalf("query with incomplete information should fail") } @@ -184,7 +187,7 @@ func TestFetchStaticChanBackups(t *testing.T) { // source at all, then we'll fail as well. 
chanSource = newMockChannelSource() chanSource.failQuery = true - _, err = FetchStaticChanBackups(chanSource, chanSource) + _, err = FetchStaticChanBackups(ctx, chanSource, chanSource) if err == nil { t.Fatalf("query should fail") } diff --git a/chanbackup/pubsub.go b/chanbackup/pubsub.go index 8fa1d5f348..2b5872f187 100644 --- a/chanbackup/pubsub.go +++ b/chanbackup/pubsub.go @@ -2,6 +2,7 @@ package chanbackup import ( "bytes" + "context" "fmt" "net" "os" @@ -81,7 +82,8 @@ type ChannelNotifier interface { // synchronization point to ensure that the chanbackup.SubSwapper does // not miss any channel open or close events in the period between when // it's created, and when it requests the channel subscription. - SubscribeChans(map[wire.OutPoint]struct{}) (*ChannelSubscription, error) + SubscribeChans(context.Context, map[wire.OutPoint]struct{}) ( + *ChannelSubscription, error) } // SubSwapper subscribes to new updates to the open channel state, and then @@ -119,8 +121,9 @@ type SubSwapper struct { // set of channels, and the required interfaces to be notified of new channel // updates, pack a multi backup, and swap the current best backup from its // storage location. -func NewSubSwapper(startingChans []Single, chanNotifier ChannelNotifier, - keyRing keychain.KeyRing, backupSwapper Swapper) (*SubSwapper, error) { +func NewSubSwapper(ctx context.Context, startingChans []Single, + chanNotifier ChannelNotifier, keyRing keychain.KeyRing, + backupSwapper Swapper) (*SubSwapper, error) { // First, we'll subscribe to the latest set of channel updates given // the set of channels we already know of. 
@@ -128,7 +131,7 @@ func NewSubSwapper(startingChans []Single, chanNotifier ChannelNotifier, for _, chanBackup := range startingChans { knownChans[chanBackup.FundingOutpoint] = struct{}{} } - chanEvents, err := chanNotifier.SubscribeChans(knownChans) + chanEvents, err := chanNotifier.SubscribeChans(ctx, knownChans) if err != nil { return nil, err } diff --git a/chanbackup/pubsub_test.go b/chanbackup/pubsub_test.go index 32694e5a75..c134b91fc0 100644 --- a/chanbackup/pubsub_test.go +++ b/chanbackup/pubsub_test.go @@ -1,6 +1,7 @@ package chanbackup import ( + "context" "fmt" "testing" "time" @@ -62,8 +63,8 @@ func newMockChannelNotifier() *mockChannelNotifier { } } -func (m *mockChannelNotifier) SubscribeChans(chans map[wire.OutPoint]struct{}) ( - *ChannelSubscription, error) { +func (m *mockChannelNotifier) SubscribeChans(_ context.Context, + _ map[wire.OutPoint]struct{}) (*ChannelSubscription, error) { if m.fail { return nil, fmt.Errorf("fail") @@ -80,6 +81,7 @@ func (m *mockChannelNotifier) SubscribeChans(chans map[wire.OutPoint]struct{}) ( // channel subscription, then the entire sub-swapper will fail to start. func TestNewSubSwapperSubscribeFail(t *testing.T) { t.Parallel() + ctx := context.Background() keyRing := &lnencrypt.MockKeyRing{} @@ -88,7 +90,7 @@ func TestNewSubSwapperSubscribeFail(t *testing.T) { fail: true, } - _, err := NewSubSwapper(nil, &chanNotifier, keyRing, &swapper) + _, err := NewSubSwapper(ctx, nil, &chanNotifier, keyRing, &swapper) if err == nil { t.Fatalf("expected fail due to lack of subscription") } @@ -152,13 +154,16 @@ func assertExpectedBackupSwap(t *testing.T, swapper *mockSwapper, // multiple time is permitted. 
func TestSubSwapperIdempotentStartStop(t *testing.T) { t.Parallel() + ctx := context.Background() keyRing := &lnencrypt.MockKeyRing{} var chanNotifier mockChannelNotifier swapper := newMockSwapper(keyRing) - subSwapper, err := NewSubSwapper(nil, &chanNotifier, keyRing, swapper) + subSwapper, err := NewSubSwapper( + ctx, nil, &chanNotifier, keyRing, swapper, + ) require.NoError(t, err, "unable to init subSwapper") if err := subSwapper.Start(); err != nil { @@ -181,6 +186,7 @@ func TestSubSwapperIdempotentStartStop(t *testing.T) { // the master multi file backup. func TestSubSwapperUpdater(t *testing.T) { t.Parallel() + ctx := context.Background() keyRing := &lnencrypt.MockKeyRing{} chanNotifier := newMockChannelNotifier() @@ -224,7 +230,7 @@ func TestSubSwapperUpdater(t *testing.T) { // With our channel set created, we'll make a fresh sub swapper // instance to begin our test. subSwapper, err := NewSubSwapper( - initialChanSet, chanNotifier, keyRing, swapper, + ctx, initialChanSet, chanNotifier, keyRing, swapper, ) require.NoError(t, err, "unable to make swapper") if err := subSwapper.Start(); err != nil { diff --git a/chanbackup/recover.go b/chanbackup/recover.go index 033bd695f2..daaad62487 100644 --- a/chanbackup/recover.go +++ b/chanbackup/recover.go @@ -1,6 +1,7 @@ package chanbackup import ( + "context" "net" "github.com/btcsuite/btcd/btcec/v2" @@ -29,7 +30,8 @@ type PeerConnector interface { // available addresses. Once this method returns with a non-nil error, // the connector should attempt to persistently connect to the target // peer in the background as a persistent attempt. 
- ConnectPeer(node *btcec.PublicKey, addrs []net.Addr) error + ConnectPeer(ctx context.Context, node *btcec.PublicKey, + addrs []net.Addr) error } // Recover attempts to recover the static channel state from a set of static @@ -41,7 +43,7 @@ type PeerConnector interface { // well, in order to expose the addressing information required to locate to // and connect to each peer in order to initiate the recovery protocol. // The number of channels that were successfully restored is returned. -func Recover(backups []Single, restorer ChannelRestorer, +func Recover(ctx context.Context, backups []Single, restorer ChannelRestorer, peerConnector PeerConnector) (int, error) { var numRestored int @@ -70,7 +72,7 @@ func Recover(backups []Single, restorer ChannelRestorer, backup.FundingOutpoint) err = peerConnector.ConnectPeer( - backup.RemoteNodePub, backup.Addresses, + ctx, backup.RemoteNodePub, backup.Addresses, ) if err != nil { return numRestored, err @@ -95,7 +97,7 @@ func Recover(backups []Single, restorer ChannelRestorer, // established, then the PeerConnector will continue to attempt to re-establish // a persistent connection in the background. The number of channels that were // successfully restored is returned. -func UnpackAndRecoverSingles(singles PackedSingles, +func UnpackAndRecoverSingles(ctx context.Context, singles PackedSingles, keyChain keychain.KeyRing, restorer ChannelRestorer, peerConnector PeerConnector) (int, error) { @@ -104,7 +106,7 @@ func UnpackAndRecoverSingles(singles PackedSingles, return 0, err } - return Recover(chanBackups, restorer, peerConnector) + return Recover(ctx, chanBackups, restorer, peerConnector) } // UnpackAndRecoverMulti is a one-shot method, that given a set of packed @@ -114,7 +116,7 @@ func UnpackAndRecoverSingles(singles PackedSingles, // established, then the PeerConnector will continue to attempt to re-establish // a persistent connection in the background. 
The number of channels that were // successfully restored is returned. -func UnpackAndRecoverMulti(packedMulti PackedMulti, +func UnpackAndRecoverMulti(ctx context.Context, packedMulti PackedMulti, keyChain keychain.KeyRing, restorer ChannelRestorer, peerConnector PeerConnector) (int, error) { @@ -123,5 +125,5 @@ func UnpackAndRecoverMulti(packedMulti PackedMulti, return 0, err } - return Recover(chanBackups.StaticBackups, restorer, peerConnector) + return Recover(ctx, chanBackups.StaticBackups, restorer, peerConnector) } diff --git a/chanbackup/recover_test.go b/chanbackup/recover_test.go index c8719cb3fe..c3a7f45e16 100644 --- a/chanbackup/recover_test.go +++ b/chanbackup/recover_test.go @@ -2,6 +2,7 @@ package chanbackup import ( "bytes" + "context" "errors" "net" "testing" @@ -39,7 +40,7 @@ type mockPeerConnector struct { callCount int } -func (m *mockPeerConnector) ConnectPeer(_ *btcec.PublicKey, +func (m *mockPeerConnector) ConnectPeer(_ context.Context, _ *btcec.PublicKey, _ []net.Addr) error { if m.fail { @@ -55,6 +56,7 @@ func (m *mockPeerConnector) ConnectPeer(_ *btcec.PublicKey, // recover a set of packed singles. 
func TestUnpackAndRecoverSingles(t *testing.T) { t.Parallel() + ctx := context.Background() keyRing := &lnencrypt.MockKeyRing{} @@ -87,7 +89,7 @@ func TestUnpackAndRecoverSingles(t *testing.T) { // as well chanRestorer.fail = true _, err := UnpackAndRecoverSingles( - packedBackups, keyRing, &chanRestorer, &peerConnector, + ctx, packedBackups, keyRing, &chanRestorer, &peerConnector, ) require.ErrorIs(t, err, errRestoreFail) @@ -97,7 +99,7 @@ func TestUnpackAndRecoverSingles(t *testing.T) { // well peerConnector.fail = true _, err = UnpackAndRecoverSingles( - packedBackups, keyRing, &chanRestorer, &peerConnector, + ctx, packedBackups, keyRing, &chanRestorer, &peerConnector, ) require.ErrorIs(t, err, errConnectFail) @@ -107,7 +109,7 @@ func TestUnpackAndRecoverSingles(t *testing.T) { // Next, we'll ensure that if all the interfaces function as expected, // then the channels will properly be unpacked and restored. numRestored, err := UnpackAndRecoverSingles( - packedBackups, keyRing, &chanRestorer, &peerConnector, + ctx, packedBackups, keyRing, &chanRestorer, &peerConnector, ) require.NoError(t, err) require.EqualValues(t, numSingles, numRestored) @@ -124,7 +126,7 @@ func TestUnpackAndRecoverSingles(t *testing.T) { // If we modify the keyRing, then unpacking should fail. keyRing.Fail = true _, err = UnpackAndRecoverSingles( - packedBackups, keyRing, &chanRestorer, &peerConnector, + ctx, packedBackups, keyRing, &chanRestorer, &peerConnector, ) require.ErrorContains(t, err, "fail") @@ -135,7 +137,7 @@ func TestUnpackAndRecoverSingles(t *testing.T) { // recover a packed multi. 
func TestUnpackAndRecoverMulti(t *testing.T) { t.Parallel() - + ctx := context.Background() keyRing := &lnencrypt.MockKeyRing{} // First, we'll create a number of single chan backups that we'll @@ -171,7 +173,7 @@ func TestUnpackAndRecoverMulti(t *testing.T) { // as well chanRestorer.fail = true _, err = UnpackAndRecoverMulti( - packedMulti, keyRing, &chanRestorer, &peerConnector, + ctx, packedMulti, keyRing, &chanRestorer, &peerConnector, ) require.ErrorIs(t, err, errRestoreFail) @@ -181,7 +183,7 @@ func TestUnpackAndRecoverMulti(t *testing.T) { // well peerConnector.fail = true _, err = UnpackAndRecoverMulti( - packedMulti, keyRing, &chanRestorer, &peerConnector, + ctx, packedMulti, keyRing, &chanRestorer, &peerConnector, ) require.ErrorIs(t, err, errConnectFail) @@ -191,7 +193,7 @@ func TestUnpackAndRecoverMulti(t *testing.T) { // Next, we'll ensure that if all the interfaces function as expected, // then the channels will properly be unpacked and restored. numRestored, err := UnpackAndRecoverMulti( - packedMulti, keyRing, &chanRestorer, &peerConnector, + ctx, packedMulti, keyRing, &chanRestorer, &peerConnector, ) require.NoError(t, err) require.EqualValues(t, numSingles, numRestored) @@ -208,7 +210,7 @@ func TestUnpackAndRecoverMulti(t *testing.T) { // If we modify the keyRing, then unpacking should fail. keyRing.Fail = true _, err = UnpackAndRecoverMulti( - packedMulti, keyRing, &chanRestorer, &peerConnector, + ctx, packedMulti, keyRing, &chanRestorer, &peerConnector, ) require.ErrorContains(t, err, "fail") diff --git a/channel_notifier.go b/channel_notifier.go index 88a05ac4ce..8affd48f08 100644 --- a/channel_notifier.go +++ b/channel_notifier.go @@ -1,6 +1,7 @@ package lnd import ( + "context" "fmt" "github.com/btcsuite/btcd/wire" @@ -31,7 +32,8 @@ type channelNotifier struct { // the channel subscription. // // NOTE: This is part of the chanbackup.ChannelNotifier interface. 
-func (c *channelNotifier) SubscribeChans(startingChans map[wire.OutPoint]struct{}) ( +func (c *channelNotifier) SubscribeChans(ctx context.Context, + startingChans map[wire.OutPoint]struct{}) ( *chanbackup.ChannelSubscription, error) { ltndLog.Infof("Channel backup proxy channel notifier starting") @@ -46,7 +48,7 @@ func (c *channelNotifier) SubscribeChans(startingChans map[wire.OutPoint]struct{ // confirmed channels. sendChanOpenUpdate := func(newOrPendingChan *channeldb.OpenChannel) { _, nodeAddrs, err := c.addrs.AddrsForNode( - newOrPendingChan.IdentityPub, + ctx, newOrPendingChan.IdentityPub, ) if err != nil { pub := newOrPendingChan.IdentityPub diff --git a/channeldb/addr_source.go b/channeldb/addr_source.go index de933ed496..f6b1760909 100644 --- a/channeldb/addr_source.go +++ b/channeldb/addr_source.go @@ -1,6 +1,7 @@ package channeldb import ( + "context" "errors" "net" @@ -13,7 +14,8 @@ type AddrSource interface { // AddrsForNode returns all known addresses for the target node public // key. The returned boolean must indicate if the given node is unknown // to the backing source. - AddrsForNode(nodePub *btcec.PublicKey) (bool, []net.Addr, error) + AddrsForNode(ctx context.Context, nodePub *btcec.PublicKey) (bool, + []net.Addr, error) } // multiAddrSource is an implementation of AddrSource which gathers all the @@ -38,8 +40,8 @@ func NewMultiAddrSource(sources ...AddrSource) AddrSource { // node. // // NOTE: this implements the AddrSource interface. -func (c *multiAddrSource) AddrsForNode(nodePub *btcec.PublicKey) (bool, - []net.Addr, error) { +func (c *multiAddrSource) AddrsForNode(ctx context.Context, + nodePub *btcec.PublicKey) (bool, []net.Addr, error) { if len(c.sources) == 0 { return false, nil, errors.New("no address sources") @@ -55,7 +57,7 @@ func (c *multiAddrSource) AddrsForNode(nodePub *btcec.PublicKey) (bool, // Iterate over all the address sources and query each one for the // addresses it has for the node in question. 
for _, src := range c.sources { - isKnown, addrs, err := src.AddrsForNode(nodePub) + isKnown, addrs, err := src.AddrsForNode(ctx, nodePub) if err != nil { return false, nil, err } diff --git a/channeldb/addr_source_test.go b/channeldb/addr_source_test.go index 85ee30bf53..19f46aebe4 100644 --- a/channeldb/addr_source_test.go +++ b/channeldb/addr_source_test.go @@ -1,6 +1,7 @@ package channeldb import ( + "context" "net" "testing" @@ -20,7 +21,10 @@ var ( func TestMultiAddrSource(t *testing.T) { t.Parallel() - var pk1 = newTestPubKey(t) + var ( + ctx = context.Background() + pk1 = newTestPubKey(t) + ) t.Run("both sources have results", func(t *testing.T) { t.Parallel() @@ -35,12 +39,12 @@ func TestMultiAddrSource(t *testing.T) { }) // Let source 1 know of 2 addresses (addr 1 and 2) for node 1. - src1.On("AddrsForNode", pk1).Return( + src1.On("AddrsForNode", ctx, pk1).Return( true, []net.Addr{addr1, addr2}, nil, ).Once() // Let source 2 know of 2 addresses (addr 2 and 3) for node 1. - src2.On("AddrsForNode", pk1).Return( + src2.On("AddrsForNode", ctx, pk1).Return( true, []net.Addr{addr2, addr3}, nil, []net.Addr{addr2, addr3}, nil, ).Once() @@ -51,7 +55,7 @@ func TestMultiAddrSource(t *testing.T) { // Query it for the addresses known for node 1. The results // should contain addr 1, 2 and 3. - known, addrs, err := multiSrc.AddrsForNode(pk1) + known, addrs, err := multiSrc.AddrsForNode(ctx, pk1) require.NoError(t, err) require.True(t, known) require.ElementsMatch(t, addrs, []net.Addr{addr1, addr2, addr3}) @@ -70,10 +74,10 @@ func TestMultiAddrSource(t *testing.T) { }) // Let source 1 know of address 1 for node 1. - src1.On("AddrsForNode", pk1).Return( + src1.On("AddrsForNode", ctx, pk1).Return( true, []net.Addr{addr1}, nil, ).Once() - src2.On("AddrsForNode", pk1).Return(false, nil, nil).Once() + src2.On("AddrsForNode", ctx, pk1).Return(false, nil, nil).Once() // Create a multi-addr source that consists of both source 1 // and 2. 
@@ -81,7 +85,7 @@ func TestMultiAddrSource(t *testing.T) { // Query it for the addresses known for node 1. The results // should contain addr 1. - known, addrs, err := multiSrc.AddrsForNode(pk1) + known, addrs, err := multiSrc.AddrsForNode(ctx, pk1) require.NoError(t, err) require.True(t, known) require.ElementsMatch(t, addrs, []net.Addr{addr1}) @@ -103,13 +107,13 @@ func TestMultiAddrSource(t *testing.T) { // and 2. Neither source known of node 1. multiSrc := NewMultiAddrSource(src1, src2) - src1.On("AddrsForNode", pk1).Return(false, nil, nil).Once() - src2.On("AddrsForNode", pk1).Return(false, nil, nil).Once() + src1.On("AddrsForNode", ctx, pk1).Return(false, nil, nil).Once() + src2.On("AddrsForNode", ctx, pk1).Return(false, nil, nil).Once() // Query it for the addresses known for node 1. It should return // false to indicate that the node is unknown to all backing // sources. - known, addrs, err := multiSrc.AddrsForNode(pk1) + known, addrs, err := multiSrc.AddrsForNode(ctx, pk1) require.NoError(t, err) require.False(t, known) require.Empty(t, addrs) @@ -127,10 +131,10 @@ func newMockAddrSource(t *testing.T) *mockAddrSource { return &mockAddrSource{t: t} } -func (m *mockAddrSource) AddrsForNode(pub *btcec.PublicKey) (bool, []net.Addr, - error) { +func (m *mockAddrSource) AddrsForNode(ctx context.Context, + pub *btcec.PublicKey) (bool, []net.Addr, error) { - args := m.Called(pub) + args := m.Called(ctx, pub) if args.Get(1) == nil { return args.Bool(0), nil, args.Error(2) } diff --git a/channeldb/db.go b/channeldb/db.go index bf7909ba52..f884cd4a69 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -2,6 +2,7 @@ package channeldb import ( "bytes" + "context" "encoding/binary" "fmt" "net" @@ -1344,7 +1345,9 @@ func (c *ChannelStateDB) RestoreChannelShells(channelShells ...*ChannelShell) er // unknown to the channel DB or not. // // NOTE: this is part of the AddrSource interface. 
-func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) (bool, []net.Addr, error) { +func (d *DB) AddrsForNode(_ context.Context, nodePub *btcec.PublicKey) (bool, + []net.Addr, error) { + linkNode, err := d.channelStateDB.linkNodeDB.FetchLinkNode(nodePub) // Only if the error is something other than ErrNodeNotFound do we // return it. diff --git a/channeldb/db_test.go b/channeldb/db_test.go index 9fed9934ba..740e293386 100644 --- a/channeldb/db_test.go +++ b/channeldb/db_test.go @@ -1,6 +1,7 @@ package channeldb import ( + "context" "image/color" "math" "math/rand" @@ -182,6 +183,8 @@ func TestFetchClosedChannelForID(t *testing.T) { func TestMultiSourceAddrsForNode(t *testing.T) { t.Parallel() + ctx := context.Background() + fullDB, err := MakeTestDB(t) require.NoError(t, err, "unable to make test database") @@ -212,7 +215,7 @@ func TestMultiSourceAddrsForNode(t *testing.T) { // Now that we've created a link node, as well as a vertex for the // node, we'll query for all its addresses. - known, nodeAddrs, err := addrSource.AddrsForNode(nodePub) + known, nodeAddrs, err := addrSource.AddrsForNode(ctx, nodePub) require.NoError(t, err, "unable to obtain node addrs") require.True(t, known) diff --git a/chanrestore.go b/chanrestore.go index 5b221c105a..6daf3922c9 100644 --- a/chanrestore.go +++ b/chanrestore.go @@ -1,6 +1,7 @@ package lnd import ( + "context" "fmt" "math" "net" @@ -309,7 +310,9 @@ var _ chanbackup.ChannelRestorer = (*chanDBRestorer)(nil) // as a persistent attempt. // // NOTE: Part of the chanbackup.PeerConnector interface. -func (s *server) ConnectPeer(nodePub *btcec.PublicKey, addrs []net.Addr) error { +func (s *server) ConnectPeer(ctx context.Context, nodePub *btcec.PublicKey, + addrs []net.Addr) error { + // Before we connect to the remote peer, we'll remove any connections // to ensure the new connection is created after this new link/channel // is known. 
@@ -333,7 +336,9 @@ func (s *server) ConnectPeer(nodePub *btcec.PublicKey, addrs []net.Addr) error { // Attempt to connect to the peer using this full address. If // we're unable to connect to them, then we'll try the next // address in place of it. - err := s.ConnectToPeer(netAddr, true, s.cfg.ConnectionTimeout) + err := s.ConnectToPeer( + ctx, netAddr, true, s.cfg.ConnectionTimeout, + ) // If we're already connected to this peer, then we don't // consider this an error, so we'll exit here. diff --git a/config.go b/config.go index 04a4917658..ef0a8fc1e3 100644 --- a/config.go +++ b/config.go @@ -1802,6 +1802,7 @@ func (c *Config) ImplementationConfig( ), WalletConfigBuilder: rpcImpl, ChainControlBuilder: rpcImpl, + GraphProvider: rpcImpl, } } @@ -1813,6 +1814,7 @@ func (c *Config) ImplementationConfig( DatabaseBuilder: NewDefaultDatabaseBuilder(c, ltndLog), WalletConfigBuilder: defaultImpl, ChainControlBuilder: defaultImpl, + GraphProvider: defaultImpl, } } diff --git a/config_builder.go b/config_builder.go index 42650bb68b..30173a6de5 100644 --- a/config_builder.go +++ b/config_builder.go @@ -36,6 +36,7 @@ import ( "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/funding" graphdb "github.com/lightningnetwork/lnd/graph/db" + "github.com/lightningnetwork/lnd/graph/sources" "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" @@ -125,6 +126,14 @@ type ChainControlBuilder interface { *btcwallet.Config) (*chainreg.ChainControl, func(), error) } +// GraphProvider is an interface that must be satisfied by any external system +// that wants to provide LND with graph information. +type GraphProvider interface { + // Graph returns the GraphSource that LND will use for read-only graph + // related queries. 
+ Graph(context.Context, *DatabaseInstances) (sources.GraphSource, error) +} + // ImplementationCfg is a struct that holds all configuration items for // components that can be implemented outside lnd itself. type ImplementationCfg struct { @@ -155,6 +164,10 @@ type ImplementationCfg struct { // AuxComponents is a set of auxiliary components that can be used by // lnd for certain custom channel types. AuxComponents + + // GraphProvider is a type that can provide a custom GraphSource for LND + // to use for read-only graph calls. + GraphProvider } // AuxComponents is a set of auxiliary components that can be used by lnd for @@ -249,6 +262,17 @@ func (d *DefaultWalletImpl) RegisterGrpcSubserver(s *grpc.Server) error { return nil } +// Graph returns the GraphSource that LND will use for read-only graph related +// queries. By default, the GraphSource implementation is this LND node's +// backing graphdb.ChannelGraph. +// +// NOTE: this is part of the GraphProvider interface. +func (d *DefaultWalletImpl) Graph(_ context.Context, + dbs *DatabaseInstances) (sources.GraphSource, error) { + + return sources.NewDBGSource(dbs.GraphDB), nil +} + // ValidateMacaroon extracts the macaroon from the context's gRPC metadata, // checks its signature, makes sure all specified permissions for the called // method are contained within and finally ensures all caveat conditions are diff --git a/discovery/bootstrapper.go b/discovery/bootstrapper.go index 0d370663d6..8eea900a92 100644 --- a/discovery/bootstrapper.go +++ b/discovery/bootstrapper.go @@ -2,6 +2,7 @@ package discovery import ( "bytes" + "context" "crypto/rand" "crypto/sha256" "errors" @@ -36,12 +37,13 @@ type NetworkPeerBootstrapper interface { // denotes how many valid peer addresses to return. The passed set of // node nodes allows the caller to ignore a set of nodes perhaps // because they already have connections established. 
- SampleNodeAddrs(numAddrs uint32, - ignore map[autopilot.NodeID]struct{}) ([]*lnwire.NetAddress, error) + SampleNodeAddrs(ctx context.Context, numAddrs uint32, + ignore map[autopilot.NodeID]struct{}) ([]*lnwire.NetAddress, + error) // Name returns a human readable string which names the concrete // implementation of the NetworkPeerBootstrapper. - Name() string + Name(ctx context.Context) string } // MultiSourceBootstrap attempts to utilize a set of NetworkPeerBootstrapper @@ -50,7 +52,8 @@ type NetworkPeerBootstrapper interface { // bootstrapper will be queried successively until the target amount is met. If // the ignore map is populated, then the bootstrappers will be instructed to // skip those nodes. -func MultiSourceBootstrap(ignore map[autopilot.NodeID]struct{}, numAddrs uint32, +func MultiSourceBootstrap(ctx context.Context, + ignore map[autopilot.NodeID]struct{}, numAddrs uint32, bootstrappers ...NetworkPeerBootstrapper) ([]*lnwire.NetAddress, error) { // We'll randomly shuffle our bootstrappers before querying them in @@ -67,19 +70,23 @@ func MultiSourceBootstrap(ignore map[autopilot.NodeID]struct{}, numAddrs uint32, break } - log.Infof("Attempting to bootstrap with: %v", bootstrapper.Name()) + name := bootstrapper.Name(ctx) + + log.Infof("Attempting to bootstrap with: %v", name) // If we still need additional addresses, then we'll compute // the number of address remaining that we need to fetch. numAddrsLeft := numAddrs - uint32(len(addrs)) log.Tracef("Querying for %v addresses", numAddrsLeft) - netAddrs, err := bootstrapper.SampleNodeAddrs(numAddrsLeft, ignore) + netAddrs, err := bootstrapper.SampleNodeAddrs( + ctx, numAddrsLeft, ignore, + ) if err != nil { // If we encounter an error with a bootstrapper, then // we'll continue on to the next available // bootstrapper. 
- log.Errorf("Unable to query bootstrapper %v: %v", - bootstrapper.Name(), err) + log.Errorf("Unable to query bootstrapper %v: %v", name, + err) continue } @@ -152,8 +159,9 @@ func NewGraphBootstrapper(cg autopilot.ChannelGraph) (NetworkPeerBootstrapper, e // many valid peer addresses to return. // // NOTE: Part of the NetworkPeerBootstrapper interface. -func (c *ChannelGraphBootstrapper) SampleNodeAddrs(numAddrs uint32, - ignore map[autopilot.NodeID]struct{}) ([]*lnwire.NetAddress, error) { +func (c *ChannelGraphBootstrapper) SampleNodeAddrs(_ context.Context, + numAddrs uint32, ignore map[autopilot.NodeID]struct{}) ( + []*lnwire.NetAddress, error) { // We'll merge the ignore map with our currently selected map in order // to ensure we don't return any duplicate nodes. @@ -269,7 +277,7 @@ func (c *ChannelGraphBootstrapper) SampleNodeAddrs(numAddrs uint32, // of the NetworkPeerBootstrapper. // // NOTE: Part of the NetworkPeerBootstrapper interface. -func (c *ChannelGraphBootstrapper) Name() string { +func (c *ChannelGraphBootstrapper) Name(_ context.Context) string { return "Authenticated Channel Graph" } @@ -382,8 +390,9 @@ func (d *DNSSeedBootstrapper) fallBackSRVLookup(soaShim string, // network peer bootstrapper source. The num addrs field passed in denotes how // many valid peer addresses to return. The set of DNS seeds are used // successively to retrieve eligible target nodes. -func (d *DNSSeedBootstrapper) SampleNodeAddrs(numAddrs uint32, - ignore map[autopilot.NodeID]struct{}) ([]*lnwire.NetAddress, error) { +func (d *DNSSeedBootstrapper) SampleNodeAddrs(_ context.Context, + numAddrs uint32, ignore map[autopilot.NodeID]struct{}) ( + []*lnwire.NetAddress, error) { var netAddrs []*lnwire.NetAddress @@ -532,6 +541,6 @@ search: // Name returns a human readable string which names the concrete // implementation of the NetworkPeerBootstrapper. 
-func (d *DNSSeedBootstrapper) Name() string { +func (d *DNSSeedBootstrapper) Name(_ context.Context) string { return fmt.Sprintf("BOLT-0010 DNS Seed: %v", d.dnsSeeds) } diff --git a/docs/release-notes/release-notes-0.19.0.md b/docs/release-notes/release-notes-0.19.0.md index 377255bbf9..f1a9d3a0a7 100644 --- a/docs/release-notes/release-notes-0.19.0.md +++ b/docs/release-notes/release-notes-0.19.0.md @@ -204,6 +204,11 @@ The underlying functionality between those two options remain the same. `channeldb` package](https://github.com/lightningnetwork/lnd/pull/9236) and into the `graph/db` package. +* Add a [graph source abstraction](https://github.com/lightningnetwork/lnd/pull/9243) + and use it throughout LND. This is so that callers of LND can choose to provide + it with an external graph source rather than requiring it to first sync its + own graph. + ## Tooling and Documentation * [Improved `lncli create` command help text](https://github.com/lightningnetwork/lnd/pull/9077) diff --git a/graph/builder.go b/graph/builder.go index 8c2ba2e3b8..658ee61993 100644 --- a/graph/builder.go +++ b/graph/builder.go @@ -1629,18 +1629,6 @@ func (b *Builder) FetchLightningNode( return b.cfg.Graph.FetchLightningNode(node) } -// ForEachNode is used to iterate over every node in router topology. -// -// NOTE: This method is part of the ChannelGraphSource interface. -func (b *Builder) ForEachNode( - cb func(*models.LightningNode) error) error { - - return b.cfg.Graph.ForEachNode( - func(_ kvdb.RTx, n *models.LightningNode) error { - return cb(n) - }) -} - // ForAllOutgoingChannels is used to iterate over all outgoing channels owned by // the router. 
// diff --git a/graph/db/graph.go b/graph/db/graph.go index fc1b26ad0f..9f240de98e 100644 --- a/graph/db/graph.go +++ b/graph/db/graph.go @@ -2,6 +2,7 @@ package graphdb import ( "bytes" + "context" "crypto/sha256" "encoding/binary" "errors" @@ -418,8 +419,8 @@ func (c *ChannelGraph) NewPathFindTx() (kvdb.RTx, error) { // unknown to the graph DB or not. // // NOTE: this is part of the channeldb.AddrSource interface. -func (c *ChannelGraph) AddrsForNode(nodePub *btcec.PublicKey) (bool, []net.Addr, - error) { +func (c *ChannelGraph) AddrsForNode(_ context.Context, + nodePub *btcec.PublicKey) (bool, []net.Addr, error) { pubKey, err := route.NewVertexFromBytes(nodePub.SerializeCompressed()) if err != nil { @@ -503,9 +504,12 @@ func (c *ChannelGraph) ForEachChannel(cb func(*models.ChannelEdgeInfo, // ForEachNodeDirectedChannel iterates through all channels of a given node, // executing the passed callback on the directed edge representing the channel // and its incoming policy. If the callback returns an error, then the iteration -// is halted with the error propagated back up to the caller. +// is halted with the error propagated back up to the caller. An optional read +// transaction may be provided. If none is provided, a new one will be created. // // Unknown policies are passed into the callback as nil values. +// +// NOTE: this is part of the graphsession.graph interface. func (c *ChannelGraph) ForEachNodeDirectedChannel(tx kvdb.RTx, node route.Vertex, cb func(channel *DirectedChannel) error) error { @@ -517,7 +521,7 @@ func (c *ChannelGraph) ForEachNodeDirectedChannel(tx kvdb.RTx, toNodeCallback := func() route.Vertex { return node } - toNodeFeatures, err := c.FetchNodeFeatures(node) + toNodeFeatures, err := c.FetchNodeFeatures(tx, node) if err != nil { return err } @@ -562,8 +566,11 @@ func (c *ChannelGraph) ForEachNodeDirectedChannel(tx kvdb.RTx, } // FetchNodeFeatures returns the features of a given node. 
If no features are -// known for the node, an empty feature vector is returned. -func (c *ChannelGraph) FetchNodeFeatures( +// known for the node, an empty feature vector is returned. An optional read +// transaction may be provided. If none is provided, a new one will be created. +// +// NOTE: this is part of the graphsession.graph interface. +func (c *ChannelGraph) FetchNodeFeatures(tx kvdb.RTx, node route.Vertex) (*lnwire.FeatureVector, error) { if c.graphCache != nil { @@ -571,7 +578,7 @@ func (c *ChannelGraph) FetchNodeFeatures( } // Fallback that uses the database. - targetNode, err := c.FetchLightningNode(node) + targetNode, err := c.FetchLightningNodeTx(tx, node) switch err { // If the node exists and has features, return them directly. case nil: @@ -618,7 +625,7 @@ func (c *ChannelGraph) ForEachNodeCached(cb func(node route.Vertex, return node.PubKeyBytes } toNodeFeatures, err := c.FetchNodeFeatures( - node.PubKeyBytes, + tx, node.PubKeyBytes, ) if err != nil { return err diff --git a/graph/db/models/stats.go b/graph/db/models/stats.go new file mode 100644 index 0000000000..a737fe7d4a --- /dev/null +++ b/graph/db/models/stats.go @@ -0,0 +1,47 @@ +package models + +import "github.com/btcsuite/btcd/btcutil" + +// NetworkStats represents various statistics about the state of the Lightning +// network graph. +type NetworkStats struct { + // Diameter is the diameter of the graph, which is the length of the + // longest shortest path between any two nodes in the graph. + Diameter uint32 + + // MaxChanOut is the maximum number of outgoing channels from a single + // node. + MaxChanOut uint32 + + // NumNodes is the total number of nodes in the graph. + NumNodes uint32 + + // NumChannels is the total number of channels in the graph. + NumChannels uint32 + + // TotalNetworkCapacity is the total capacity of all channels in the + // graph. + TotalNetworkCapacity btcutil.Amount + + // MinChanSize is the smallest channel size in the graph. 
+ MinChanSize btcutil.Amount + + // MaxChanSize is the largest channel size in the graph. + MaxChanSize btcutil.Amount + + // MedianChanSize is the median channel size in the graph. + MedianChanSize btcutil.Amount + + // NumZombies is the number of zombie channels in the graph. + NumZombies uint64 +} + +// BetweennessCentrality represents the betweenness centrality of a node in the +// graph. +type BetweennessCentrality struct { + // Normalized is the normalized betweenness centrality of a node. + Normalized float64 + + // NonNormalized is the non-normalized betweenness centrality of a node. + NonNormalized float64 +} diff --git a/graph/interfaces.go b/graph/interfaces.go index eb7f56603a..10ca200f37 100644 --- a/graph/interfaces.go +++ b/graph/interfaces.go @@ -85,9 +85,6 @@ type ChannelGraphSource interface { // public key. channeldb.ErrGraphNodeNotFound is returned if the node // doesn't exist within the graph. FetchLightningNode(route.Vertex) (*models.LightningNode, error) - - // ForEachNode is used to iterate over every node in the known graph. - ForEachNode(func(node *models.LightningNode) error) error } // DB is an interface describing a persisted Lightning Network graph. @@ -241,12 +238,6 @@ type DB interface { FetchLightningNode(nodePub route.Vertex) (*models.LightningNode, error) - // ForEachNode iterates through all the stored vertices/nodes in the - // graph, executing the passed callback with each node encountered. If - // the callback returns an error, then the transaction is aborted and - // the iteration stops early. - ForEachNode(cb func(kvdb.RTx, *models.LightningNode) error) error - // ForEachNodeChannel iterates through all channels of the given node, // executing the passed callback with an edge info structure and the // policies of each end of the channel. 
The first edge policy is the diff --git a/graph/graphsession/graph_session.go b/graph/session/graph_session.go similarity index 81% rename from graph/graphsession/graph_session.go rename to graph/session/graph_session.go index 6976fad79b..d1617bb593 100644 --- a/graph/graphsession/graph_session.go +++ b/graph/session/graph_session.go @@ -1,10 +1,10 @@ -package graphsession +package session import ( + "context" "fmt" graphdb "github.com/lightningnetwork/lnd/graph/db" - "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing" "github.com/lightningnetwork/lnd/routing/route" @@ -30,8 +30,10 @@ func NewGraphSessionFactory(graph ReadOnlyGraph) routing.GraphSessionFactory { // was created at Graph construction time. // // NOTE: This is part of the routing.GraphSessionFactory interface. -func (g *Factory) NewGraphSession() (routing.Graph, func() error, error) { - tx, err := g.graph.NewPathFindTx() +func (g *Factory) NewGraphSession(ctx context.Context) (routing.Graph, + func() error, error) { + + tx, err := g.graph.NewPathFindTx(ctx) if err != nil { return nil, nil, err } @@ -53,7 +55,7 @@ var _ routing.GraphSessionFactory = (*Factory)(nil) // access the backing channel graph. type session struct { graph graph - tx kvdb.RTx + tx RTx } // NewRoutingGraph constructs a session that which does not first start a @@ -72,7 +74,7 @@ func (g *session) close() error { return nil } - err := g.tx.Rollback() + err := g.tx.Close() if err != nil { return fmt.Errorf("error closing db tx: %w", err) } @@ -83,20 +85,20 @@ func (g *session) close() error { // ForEachNodeChannel calls the callback for every channel of the given node. // // NOTE: Part of the routing.Graph interface. 
-func (g *session) ForEachNodeChannel(nodePub route.Vertex, +func (g *session) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex, cb func(channel *graphdb.DirectedChannel) error) error { - return g.graph.ForEachNodeDirectedChannel(g.tx, nodePub, cb) + return g.graph.ForEachNodeDirectedChannel(ctx, g.tx, nodePub, cb) } // FetchNodeFeatures returns the features of the given node. If the node is // unknown, assume no additional features are supported. // // NOTE: Part of the routing.Graph interface. -func (g *session) FetchNodeFeatures(nodePub route.Vertex) ( - *lnwire.FeatureVector, error) { +func (g *session) FetchNodeFeatures(ctx context.Context, + nodePub route.Vertex) (*lnwire.FeatureVector, error) { - return g.graph.FetchNodeFeatures(nodePub) + return g.graph.FetchNodeFeatures(ctx, g.tx, nodePub) } // A compile-time check to ensure that *session implements the @@ -109,7 +111,7 @@ type ReadOnlyGraph interface { // NewPathFindTx returns a new read transaction that can be used for a // single path finding session. Will return nil if the graph cache is // enabled. - NewPathFindTx() (kvdb.RTx, error) + NewPathFindTx(ctx context.Context) (RTx, error) graph } @@ -128,14 +130,15 @@ type graph interface { // // NOTE: if a nil tx is provided, then it is expected that the // implementation create a read only tx. - ForEachNodeDirectedChannel(tx kvdb.RTx, node route.Vertex, + ForEachNodeDirectedChannel(ctx context.Context, tx RTx, + node route.Vertex, cb func(channel *graphdb.DirectedChannel) error) error // FetchNodeFeatures returns the features of a given node. If no // features are known for the node, an empty feature vector is returned. - FetchNodeFeatures(node route.Vertex) (*lnwire.FeatureVector, error) + // + // NOTE: if a nil tx is provided, then it is expected that the + // implementation create a read only tx. 
+ FetchNodeFeatures(ctx context.Context, tx RTx, node route.Vertex) ( + *lnwire.FeatureVector, error) } - -// A compile-time check to ensure that *channeldb.ChannelGraph implements the -// graph interface. -var _ graph = (*graphdb.ChannelGraph)(nil) diff --git a/graph/session/read_tx.go b/graph/session/read_tx.go new file mode 100644 index 0000000000..f675e59db9 --- /dev/null +++ b/graph/session/read_tx.go @@ -0,0 +1,14 @@ +package session + +// RTx represents a read-only transaction that can only be used for graph +// reads during a path-finding session. +type RTx interface { + // Close closes the transaction. + Close() error + + // MustImplementRTx is a helper method that ensures that the RTx + // interface is implemented by the underlying type. This is useful since + // the other methods in the interface are quite generic and so many + // types will satisfy the interface if it only contains those methods. + MustImplementRTx() +} diff --git a/graph/sources/chan_graph.go b/graph/sources/chan_graph.go new file mode 100644 index 0000000000..99ee0f0eb9 --- /dev/null +++ b/graph/sources/chan_graph.go @@ -0,0 +1,449 @@ +package sources + +import ( + "context" + "fmt" + "math" + "net" + "runtime" + "time" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/autopilot" + "github.com/lightningnetwork/lnd/discovery" + graphdb "github.com/lightningnetwork/lnd/graph/db" + "github.com/lightningnetwork/lnd/graph/db/models" + "github.com/lightningnetwork/lnd/graph/session" + "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" +) + +// DBSource is an implementation of the GraphSource interface backed by a local +// persistence layer holding graph related data. 
+type DBSource struct { + db *graphdb.ChannelGraph +} + +// A compile-time check to ensure that sources.DBSource implements the +// GraphSource interface. +var _ GraphSource = (*DBSource)(nil) + +// NewDBGSource returns a new instance of the DBSource backed by a +// graphdb.ChannelGraph instance. +func NewDBGSource(db *graphdb.ChannelGraph) *DBSource { + return &DBSource{ + db: db, + } +} + +// NewPathFindTx returns a new read transaction that can be used for a single +// path finding session. Will return nil if the graph cache is enabled for the +// underlying graphdb.ChannelGraph. +// +// NOTE: this is part of the session.ReadOnlyGraph interface. +func (s *DBSource) NewPathFindTx(_ context.Context) (session.RTx, error) { + tx, err := s.db.NewPathFindTx() + if err != nil { + return nil, err + } + + return newKVDBRTx(tx), nil +} + +// ForEachNodeDirectedChannel iterates through all channels of a given node, +// executing the passed callback on the directed edge representing the channel +// and its incoming policy. If the callback returns an error, then the +// iteration is halted with the error propagated back up to the caller. An +// optional read transaction may be provided. If it is, then it will be cast +// into a kvdb.RTx and passed into the callback. +// +// Unknown policies are passed into the callback as nil values. +// +// NOTE: this is part of the session.ReadOnlyGraph interface. +func (s *DBSource) ForEachNodeDirectedChannel(_ context.Context, tx session.RTx, + node route.Vertex, + cb func(channel *graphdb.DirectedChannel) error) error { + + kvdbTx, err := extractKVDBRTx(tx) + if err != nil { + return err + } + + return s.db.ForEachNodeDirectedChannel(kvdbTx, node, cb) +} + +// FetchNodeFeatures returns the features of a given node. If no features are +// known for the node, an empty feature vector is returned. An optional read +// transaction may be provided. If it is, then it will be cast into a kvdb.RTx +// and passed into the callback. 
+// +// NOTE: this is part of the session.ReadOnlyGraph interface. +func (s *DBSource) FetchNodeFeatures(_ context.Context, tx session.RTx, + node route.Vertex) (*lnwire.FeatureVector, error) { + + kvdbTx, err := extractKVDBRTx(tx) + if err != nil { + return nil, err + } + + return s.db.FetchNodeFeatures(kvdbTx, node) +} + +// FetchChannelEdgesByID attempts to look up the two directed edges for the +// channel identified by the channel ID. If the channel can't be found, then +// graphdb.ErrEdgeNotFound is returned. +// +// NOTE: this is part of the invoicesrpc.GraphSource interface. +func (s *DBSource) FetchChannelEdgesByID(_ context.Context, + chanID uint64) (*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy, error) { + + return s.db.FetchChannelEdgesByID(chanID) +} + +// IsPublicNode determines whether the node with the given public key is seen as +// a public node in the graph from the graph's source node's point of view. +// +// NOTE: this is part of the invoicesrpc.GraphSource interface. +func (s *DBSource) IsPublicNode(_ context.Context, + pubKey [33]byte) (bool, error) { + + return s.db.IsPublicNode(pubKey) +} + +// FetchChannelEdgesByOutpoint returns the channel edge info and most recent +// channel edge policies for a given outpoint. +// +// NOTE: this is part of the netann.ChannelGraph interface. +func (s *DBSource) FetchChannelEdgesByOutpoint(_ context.Context, + point *wire.OutPoint) (*models.ChannelEdgeInfo, + *models.ChannelEdgePolicy, *models.ChannelEdgePolicy, error) { + + return s.db.FetchChannelEdgesByOutpoint(point) +} + +// AddrsForNode returns all known addresses for the target node public key. The +// returned boolean indicates if the given node is unknown to the backing +// source. +// +// NOTE: this is part of the channeldb.AddrSource interface.
+func (s *DBSource) AddrsForNode(ctx context.Context, + nodePub *btcec.PublicKey) (bool, []net.Addr, error) { + + return s.db.AddrsForNode(ctx, nodePub) +} + +// ForEachChannel iterates through all the channel edges stored within the graph +// and invokes the passed callback for each edge. If the callback returns an +// error, then the transaction is aborted and the iteration stops early. An +// edge's policy structs may be nil if the ChannelUpdate in question has not yet +// been received for the channel. +// +// NOTE: this is part of the GraphSource interface. +func (s *DBSource) ForEachChannel(_ context.Context, + cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy) error) error { + + return s.db.ForEachChannel(cb) +} + +// ForEachNode iterates through all the stored vertices/nodes in the graph, +// executing the passed callback with each node encountered. If the callback +// returns an error, then the transaction is aborted and the iteration stops +// early. +// +// NOTE: this is part of the GraphSource interface. +func (s *DBSource) ForEachNode(_ context.Context, + cb func(*models.LightningNode) error) error { + + return s.db.ForEachNode(func(_ kvdb.RTx, + node *models.LightningNode) error { + + return cb(node) + }) +} + +// HasLightningNode determines if the graph has a vertex identified by the +// target node identity public key. If the node exists in the database, a +// timestamp of when the data for the node was lasted updated is returned along +// with a true boolean. Otherwise, an empty time.Time is returned with a false +// boolean. +// +// NOTE: this is part of the GraphSource interface. +func (s *DBSource) HasLightningNode(_ context.Context, + nodePub [33]byte) (time.Time, bool, error) { + + return s.db.HasLightningNode(nodePub) +} + +// LookupAlias attempts to return the alias as advertised by the target node. +// graphdb.ErrNodeAliasNotFound is returned if the alias is not found. 
+// +// NOTE: this is part of the GraphSource interface. +func (s *DBSource) LookupAlias(_ context.Context, + pub *btcec.PublicKey) (string, error) { + + return s.db.LookupAlias(pub) +} + +// ForEachNodeChannel iterates through all channels of the given node, executing +// the passed callback with an edge info structure and the policies of each end +// of the channel. The first edge policy is the outgoing edge *to* the +// connecting node, while the second is the incoming edge *from* the connecting +// node. If the callback returns an error, then the iteration is halted with the +// error propagated back up to the caller. Unknown policies are passed into the +// callback as nil values. +// +// NOTE: this is part of the GraphSource interface. +func (s *DBSource) ForEachNodeChannel(_ context.Context, + nodePub route.Vertex, cb func(*models.ChannelEdgeInfo, + *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy) error) error { + + return s.db.ForEachNodeChannel(nodePub, func(_ kvdb.RTx, + info *models.ChannelEdgeInfo, policy *models.ChannelEdgePolicy, + policy2 *models.ChannelEdgePolicy) error { + + return cb(info, policy, policy2) + }) +} + +// FetchLightningNode attempts to look up a target node by its identity public +// key. If the node isn't found in the database, then +// graphdb.ErrGraphNodeNotFound is returned. +// +// NOTE: this is part of the GraphSource interface. +func (s *DBSource) FetchLightningNode(_ context.Context, + nodePub route.Vertex) (*models.LightningNode, error) { + + return s.db.FetchLightningNode(nodePub) +} + +// GraphBootstrapper returns a NetworkPeerBootstrapper instance backed by the +// ChannelGraph instance. +// +// NOTE: this is part of the GraphSource interface. 
+func (s *DBSource) GraphBootstrapper(_ context.Context) ( + discovery.NetworkPeerBootstrapper, error) { + + chanGraph := autopilot.ChannelGraphFromDatabase(s.db) + + return discovery.NewGraphBootstrapper(chanGraph) +} + +// NetworkStats returns statistics concerning the current state of the known +// channel graph within the network. +// +// NOTE: this is part of the GraphSource interface. +func (s *DBSource) NetworkStats(_ context.Context) (*models.NetworkStats, + error) { + + var ( + numNodes uint32 + numChannels uint32 + maxChanOut uint32 + totalNetworkCapacity btcutil.Amount + minChannelSize btcutil.Amount = math.MaxInt64 + maxChannelSize btcutil.Amount + medianChanSize btcutil.Amount + ) + + // We'll use this map to de-duplicate channels during our traversal. + // This is needed since channels are directional, so there will be two + // edges for each channel within the graph. + seenChans := make(map[uint64]struct{}) + + // We also keep a list of all encountered capacities, in order to + // calculate the median channel size. + var allChans []btcutil.Amount + + // We'll run through all the known nodes within our view of the + // network, tallying up the total number of nodes, and also gathering + // each node so we can measure the graph diameter and degree stats + // below. + err := s.db.ForEachNodeCached(func(node route.Vertex, + edges map[uint64]*graphdb.DirectedChannel) error { + + // Increment the total number of nodes with each iteration. + numNodes++ + + // For each channel we'll compute the out degree of each node, + // and also update our running tallies of the min/max channel + // capacity, as well as the total channel capacity. + var outDegree uint32 + for _, edge := range edges { + // Bump up the out degree for this node for each + // channel encountered.
+ outDegree++ + + // If we've already seen this channel, then we'll + // return early to ensure that we don't double-count + // stats. + if _, ok := seenChans[edge.ChannelID]; ok { + return nil + } + + // Compare the capacity of this channel against the + // running min/max to see if we should update the + // extrema. + chanCapacity := edge.Capacity + if chanCapacity < minChannelSize { + minChannelSize = chanCapacity + } + if chanCapacity > maxChannelSize { + maxChannelSize = chanCapacity + } + + // Accumulate the total capacity of this channel to the + // network wide-capacity. + totalNetworkCapacity += chanCapacity + + numChannels++ + + seenChans[edge.ChannelID] = struct{}{} + allChans = append(allChans, edge.Capacity) + } + + // Finally, if the out degree of this node is greater than what + // we've seen so far, update the maxChanOut variable. + if outDegree > maxChanOut { + maxChanOut = outDegree + } + + return nil + }) + if err != nil { + return nil, err + } + + // Find the median. + medianChanSize = autopilot.Median(allChans) + + // If we don't have any channels, then reset the minChannelSize to zero + // to avoid outputting NaN in encoded JSON. + if numChannels == 0 { + minChannelSize = 0 + } + + // Graph diameter. + channelGraph := autopilot.ChannelGraphFromCachedDatabase(s.db) + simpleGraph, err := autopilot.NewSimpleGraph(channelGraph) + if err != nil { + return nil, err + } + start := time.Now() + diameter := simpleGraph.DiameterRadialCutoff() + + log.Infof("Elapsed time for diameter (%d) calculation: %v", diameter, + time.Since(start)) + + // Query the graph for the current number of zombie channels. 
+ numZombies, err := s.db.NumZombies() + if err != nil { + return nil, err + } + + return &models.NetworkStats{ + Diameter: diameter, + MaxChanOut: maxChanOut, + NumNodes: numNodes, + NumChannels: numChannels, + TotalNetworkCapacity: totalNetworkCapacity, + MinChanSize: minChannelSize, + MaxChanSize: maxChannelSize, + MedianChanSize: medianChanSize, + NumZombies: numZombies, + }, nil +} + +// BetweennessCentrality computes the normalised and non-normalised betweenness +// centrality for each node in the graph. +// +// NOTE: this is part of the GraphSource interface. +func (s *DBSource) BetweennessCentrality(_ context.Context) ( + map[autopilot.NodeID]*models.BetweennessCentrality, error) { + + channelGraph := autopilot.ChannelGraphFromDatabase(s.db) + centralityMetric, err := autopilot.NewBetweennessCentralityMetric( + runtime.NumCPU(), + ) + if err != nil { + return nil, err + } + + if err := centralityMetric.Refresh(channelGraph); err != nil { + return nil, err + } + + centrality := make(map[autopilot.NodeID]*models.BetweennessCentrality) + + for nodeID, val := range centralityMetric.GetMetric(true) { + centrality[nodeID] = &models.BetweennessCentrality{ + Normalized: val, + } + } + + for nodeID, val := range centralityMetric.GetMetric(false) { + if _, ok := centrality[nodeID]; !ok { + centrality[nodeID] = &models.BetweennessCentrality{ + Normalized: val, + } + + continue + } + centrality[nodeID].NonNormalized = val + } + + return centrality, nil +} + +// kvdbRTx is an implementation of graphdb.RTx backed by a KVDB database read +// transaction. +type kvdbRTx struct { + kvdb.RTx +} + +// newKVDBRTx constructs a kvdbRTx instance backed by the given kvdb.RTx. +func newKVDBRTx(tx kvdb.RTx) *kvdbRTx { + return &kvdbRTx{tx} +} + +// Close closes the underlying transaction. +// +// NOTE: this is part of the graphdb.RTx interface. 
+func (t *kvdbRTx) Close() error { + if t.RTx == nil { + return nil + } + + return t.RTx.Rollback() +} + +// MustImplementRTx is a helper method that ensures that the kvdbRTx type +// implements the RTx interface. +// +// NOTE: this is part of the session.RTx interface. +func (t *kvdbRTx) MustImplementRTx() {} + +// A compile-time assertion to ensure that kvdbRTx implements the RTx interface. +var _ session.RTx = (*kvdbRTx)(nil) + +// extractKVDBRTx is a helper function that casts an RTx into a kvdbRTx and +// errors if the cast fails. +func extractKVDBRTx(tx session.RTx) (kvdb.RTx, error) { + if tx == nil { + return nil, nil + } + + kvdbTx, ok := tx.(*kvdbRTx) + if !ok { + return nil, fmt.Errorf("expected a sources.kvdbRTx, got %T", tx) + } + + return kvdbTx, nil +} diff --git a/graph/sources/interfaces.go b/graph/sources/interfaces.go new file mode 100644 index 0000000000..448c5bb00c --- /dev/null +++ b/graph/sources/interfaces.go @@ -0,0 +1,89 @@ +package sources + +import ( + "context" + "time" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/lightningnetwork/lnd/autopilot" + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/discovery" + "github.com/lightningnetwork/lnd/graph/db/models" + "github.com/lightningnetwork/lnd/graph/session" + "github.com/lightningnetwork/lnd/lnrpc/invoicesrpc" + "github.com/lightningnetwork/lnd/netann" + "github.com/lightningnetwork/lnd/routing/route" +) + +// GraphSource defines the read-only graph interface required by LND for graph +// related queries. +// +//nolint:interfacebloat +type GraphSource interface { + session.ReadOnlyGraph + invoicesrpc.GraphSource + netann.ChannelGraph + channeldb.AddrSource + + // ForEachChannel iterates through all the channel edges stored within + // the graph and invokes the passed callback for each edge. If the + // callback returns an error, then the transaction is aborted and the + // iteration stops early.
An edge's policy structs may be nil if the + // ChannelUpdate in question has not yet been received for the channel. + ForEachChannel(ctx context.Context, cb func(*models.ChannelEdgeInfo, + *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy) error) error + + // ForEachNode iterates through all the stored vertices/nodes in the + // graph, executing the passed callback with each node encountered. If + // the callback returns an error, then the transaction is aborted and + // the iteration stops early. + ForEachNode(ctx context.Context, + cb func(*models.LightningNode) error) error + + // HasLightningNode determines if the graph has a vertex identified by + // the target node identity public key. If the node exists in the + // database, a timestamp of when the data for the node was lasted + // updated is returned along with a true boolean. Otherwise, an empty + // time.Time is returned with a false boolean. + HasLightningNode(ctx context.Context, nodePub [33]byte) (time.Time, + bool, error) + + // LookupAlias attempts to return the alias as advertised by the target + // node. graphdb.ErrNodeAliasNotFound is returned if the alias is not + // found. + LookupAlias(ctx context.Context, pub *btcec.PublicKey) (string, error) + + // ForEachNodeChannel iterates through all channels of the given node, + // executing the passed callback with an edge info structure and the + // policies of each end of the channel. The first edge policy is the + // outgoing edge *to* the connecting node, while the second is the + // incoming edge *from* the connecting node. If the callback returns an + // error, then the iteration is halted with the error propagated back up + // to the caller. Unknown policies are passed into the callback as nil + // values. 
+ ForEachNodeChannel(ctx context.Context, + nodePub route.Vertex, cb func(*models.ChannelEdgeInfo, + *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy) error) error + + // FetchLightningNode attempts to look up a target node by its identity + // public key. If the node isn't found in the database, then + // graphdb.ErrGraphNodeNotFound is returned. + FetchLightningNode(ctx context.Context, nodePub route.Vertex) ( + *models.LightningNode, error) + + // GraphBootstrapper returns a network peer bootstrapper that can be + // used to discover new peers to connect to. + GraphBootstrapper(ctx context.Context) ( + discovery.NetworkPeerBootstrapper, error) + + // NetworkStats returns statistics concerning the current state of the + // known channel graph within the network. + NetworkStats(ctx context.Context) (*models.NetworkStats, error) + + // BetweennessCentrality computes the normalised and non-normalised + // betweenness centrality for each node in the graph. + BetweennessCentrality(ctx context.Context) ( + map[autopilot.NodeID]*models.BetweennessCentrality, error) +} diff --git a/graph/sources/log.go b/graph/sources/log.go new file mode 100644 index 0000000000..b05d38928a --- /dev/null +++ b/graph/sources/log.go @@ -0,0 +1,31 @@ +package sources + +import ( + "github.com/btcsuite/btclog/v2" + "github.com/lightningnetwork/lnd/build" +) + +// log is a logger that is initialized with no output filters. This means the +// package will not perform any logging by default until the caller requests +// it. +var log btclog.Logger + +const Subsystem = "GRSR" + +// The default amount of logging is none. +func init() { + UseLogger(build.NewSubLogger(Subsystem, nil)) +} + +// DisableLog disables all library log output. Logging output is disabled by +// default until UseLogger is called. +func DisableLog() { + UseLogger(btclog.Disabled) +} + +// UseLogger uses a specified Logger to output package logging info. 
This +// should be used in preference to SetLogWriter if the caller is also using +// btclog. +func UseLogger(logger btclog.Logger) { + log = logger +} diff --git a/kvdb/etcd/db_test.go b/kvdb/etcd/db_test.go index 59e29fc94a..9ef68d9fe5 100644 --- a/kvdb/etcd/db_test.go +++ b/kvdb/etcd/db_test.go @@ -19,7 +19,7 @@ func TestDump(t *testing.T) { f := NewEtcdTestFixture(t) - db, err := newEtcdBackend(context.TODO(), f.BackendConfig()) + db, err := newEtcdBackend(context.Background(), f.BackendConfig()) require.NoError(t, err) err = db.Update(func(tx walletdb.ReadWriteTx) error { diff --git a/kvdb/etcd/readwrite_tx_test.go b/kvdb/etcd/readwrite_tx_test.go index b5758f7a3f..d66b2e2512 100644 --- a/kvdb/etcd/readwrite_tx_test.go +++ b/kvdb/etcd/readwrite_tx_test.go @@ -16,7 +16,7 @@ func TestChangeDuringManualTx(t *testing.T) { f := NewEtcdTestFixture(t) - db, err := newEtcdBackend(context.TODO(), f.BackendConfig()) + db, err := newEtcdBackend(context.Background(), f.BackendConfig()) require.NoError(t, err) tx, err := db.BeginReadWriteTx() @@ -44,7 +44,7 @@ func TestChangeDuringUpdate(t *testing.T) { f := NewEtcdTestFixture(t) - db, err := newEtcdBackend(context.TODO(), f.BackendConfig()) + db, err := newEtcdBackend(context.Background(), f.BackendConfig()) require.NoError(t, err) count := 0 diff --git a/kvdb/etcd/walletdb_interface_test.go b/kvdb/etcd/walletdb_interface_test.go index 13c57e337d..483becbb2f 100644 --- a/kvdb/etcd/walletdb_interface_test.go +++ b/kvdb/etcd/walletdb_interface_test.go @@ -15,5 +15,5 @@ import ( func TestWalletDBInterface(t *testing.T) { f := NewEtcdTestFixture(t) cfg := f.BackendConfig() - walletdbtest.TestInterface(t, dbType, context.TODO(), &cfg) + walletdbtest.TestInterface(t, dbType, context.Background(), &cfg) } diff --git a/lnd.go b/lnd.go index f511811950..55d5b63869 100644 --- a/lnd.go +++ b/lnd.go @@ -485,6 +485,11 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg, return mkErr("error deriving node key: %v", 
err) } + graphSource, err := implCfg.Graph(ctx, dbs) + if err != nil { + return mkErr("error obtaining graph source: %v", err) + } + if cfg.Tor.StreamIsolation && cfg.Tor.SkipProxyForClearNetTargets { return errStreamIsolationWithProxySkip } @@ -598,10 +603,10 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg, // Set up the core server which will listen for incoming peer // connections. server, err := newServer( - cfg, cfg.Listeners, dbs, activeChainControl, &idKeyDesc, + ctx, cfg, cfg.Listeners, dbs, activeChainControl, &idKeyDesc, activeChainControl.Cfg.WalletUnlockParams.ChansToRestore, multiAcceptor, torController, tlsManager, leaderElector, - implCfg, + graphSource, implCfg, ) if err != nil { return mkErr("unable to create server: %v", err) @@ -611,7 +616,7 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg, // used to manage the underlying autopilot agent, starting and stopping // it at will. atplCfg, err := initAutoPilot( - server, cfg.Autopilot, activeChainControl.MinHtlcIn, + ctx, server, cfg.Autopilot, activeChainControl.MinHtlcIn, cfg.ActiveNetParams, ) if err != nil { @@ -731,7 +736,7 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg, // case the startup of the subservers do not behave as expected. 
errChan := make(chan error) go func() { - errChan <- server.Start() + errChan <- server.Start(ctx) }() defer func() { diff --git a/lnrpc/invoicesrpc/addinvoice.go b/lnrpc/invoicesrpc/addinvoice.go index 59f7df610b..c69db4f5c0 100644 --- a/lnrpc/invoicesrpc/addinvoice.go +++ b/lnrpc/invoicesrpc/addinvoice.go @@ -18,7 +18,6 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/lntypes" @@ -75,8 +74,9 @@ type AddInvoiceConfig struct { // channel graph. ChanDB *channeldb.ChannelStateDB - // Graph holds a reference to the ChannelGraph database. - Graph *graphdb.ChannelGraph + // Graph holds a reference to a GraphSource that can be queried for + // graph related data. + Graph GraphSource // GenInvoiceFeatures returns a feature containing feature bits that // should be advertised on freshly generated invoices. @@ -96,7 +96,8 @@ type AddInvoiceConfig struct { // QueryBlindedRoutes can be used to generate a few routes to this node // that can then be used in the construction of a blinded payment path. - QueryBlindedRoutes func(lnwire.MilliSatoshi) ([]*route.Route, error) + QueryBlindedRoutes func(context.Context, lnwire.MilliSatoshi) ( + []*route.Route, error) } // AddInvoiceData contains the required data to create a new invoice. 
@@ -462,7 +463,7 @@ func AddInvoice(ctx context.Context, cfg *AddInvoiceConfig, hopHintsCfg := newSelectHopHintsCfg(cfg, totalHopHints) hopHints, err := PopulateHopHints( - hopHintsCfg, amtMSat, invoice.RouteHints, + ctx, hopHintsCfg, amtMSat, invoice.RouteHints, ) if err != nil { return nil, nil, fmt.Errorf("unable to populate hop "+ @@ -521,7 +522,7 @@ func AddInvoice(ctx context.Context, cfg *AddInvoiceConfig, //nolint:lll paths, err := blindedpath.BuildBlindedPaymentPaths( - &blindedpath.BuildBlindedPathCfg{ + ctx, &blindedpath.BuildBlindedPathCfg{ FindRoutes: cfg.QueryBlindedRoutes, FetchChannelEdgesByID: cfg.Graph.FetchChannelEdgesByID, FetchOurOpenChannels: cfg.ChanDB.FetchAllOpenChannels, @@ -624,8 +625,8 @@ func AddInvoice(ctx context.Context, cfg *AddInvoiceConfig, // chanCanBeHopHint returns true if the target channel is eligible to be a hop // hint. -func chanCanBeHopHint(channel *HopHintInfo, cfg *SelectHopHintsCfg) ( - *models.ChannelEdgePolicy, bool) { +func chanCanBeHopHint(ctx context.Context, channel *HopHintInfo, + cfg *SelectHopHintsCfg) (*models.ChannelEdgePolicy, bool) { // Since we're only interested in our private channels, we'll skip // public ones. @@ -648,7 +649,7 @@ func chanCanBeHopHint(channel *HopHintInfo, cfg *SelectHopHintsCfg) ( // channels. var remotePub [33]byte copy(remotePub[:], channel.RemotePubkey.SerializeCompressed()) - isRemoteNodePublic, err := cfg.IsPublicNode(remotePub) + isRemoteNodePublic, err := cfg.IsPublicNode(ctx, remotePub) if err != nil { log.Errorf("Unable to determine if node %x "+ "is advertised: %v", remotePub, err) @@ -663,13 +664,17 @@ func chanCanBeHopHint(channel *HopHintInfo, cfg *SelectHopHintsCfg) ( } // Fetch the policies for each end of the channel. 
- info, p1, p2, err := cfg.FetchChannelEdgesByID(channel.ShortChannelID) + info, p1, p2, err := cfg.FetchChannelEdgesByID( + ctx, channel.ShortChannelID, + ) if err != nil { // In the case of zero-conf channels, it may be the case that // the alias SCID was deleted from the graph, and replaced by // the confirmed SCID. Check the Graph for the confirmed SCID. confirmedScid := channel.ConfirmedScidZC - info, p1, p2, err = cfg.FetchChannelEdgesByID(confirmedScid) + info, p1, p2, err = cfg.FetchChannelEdgesByID( + ctx, confirmedScid, + ) if err != nil { log.Errorf("Unable to fetch the routing policies for "+ "the edges of the channel %v: %v", @@ -759,13 +764,13 @@ type SelectHopHintsCfg struct { // IsPublicNode is returns a bool indicating whether the node with the // given public key is seen as a public node in the graph from the // graph's source node's point of view. - IsPublicNode func(pubKey [33]byte) (bool, error) + IsPublicNode func(ctx context.Context, pubKey [33]byte) (bool, error) // FetchChannelEdgesByID attempts to lookup the two directed edges for // the channel identified by the channel ID. - FetchChannelEdgesByID func(chanID uint64) (*models.ChannelEdgeInfo, - *models.ChannelEdgePolicy, *models.ChannelEdgePolicy, - error) + FetchChannelEdgesByID func(ctx context.Context, + chanID uint64) (*models.ChannelEdgeInfo, + *models.ChannelEdgePolicy, *models.ChannelEdgePolicy, error) // GetAlias allows the peer's alias SCID to be retrieved for private // option_scid_alias channels. @@ -856,7 +861,7 @@ func getPotentialHints(cfg *SelectHopHintsCfg) ([]*channeldb.OpenChannel, // shouldIncludeChannel returns true if the channel passes all the checks to // be a hopHint in a given invoice. 
-func shouldIncludeChannel(cfg *SelectHopHintsCfg, +func shouldIncludeChannel(ctx context.Context, cfg *SelectHopHintsCfg, channel *channeldb.OpenChannel, alreadyIncluded map[uint64]bool) (zpay32.HopHint, lnwire.MilliSatoshi, bool) { @@ -872,7 +877,7 @@ func shouldIncludeChannel(cfg *SelectHopHintsCfg, hopHintInfo := newHopHintInfo(channel, cfg.IsChannelActive(chanID)) // If this channel can't be a hop hint, then skip it. - edgePolicy, canBeHopHint := chanCanBeHopHint(hopHintInfo, cfg) + edgePolicy, canBeHopHint := chanCanBeHopHint(ctx, hopHintInfo, cfg) if edgePolicy == nil || !canBeHopHint { return zpay32.HopHint{}, 0, false } @@ -901,7 +906,7 @@ func shouldIncludeChannel(cfg *SelectHopHintsCfg, // // NOTE: selectHopHints expects potentialHints to be already sorted in // descending priority. -func selectHopHints(cfg *SelectHopHintsCfg, nHintsLeft int, +func selectHopHints(ctx context.Context, cfg *SelectHopHintsCfg, nHintsLeft int, targetBandwidth lnwire.MilliSatoshi, potentialHints []*channeldb.OpenChannel, alreadyIncluded map[uint64]bool) [][]zpay32.HopHint { @@ -917,7 +922,7 @@ func selectHopHints(cfg *SelectHopHintsCfg, nHintsLeft int, } hopHint, remoteBalance, include := shouldIncludeChannel( - cfg, channel, alreadyIncluded, + ctx, cfg, channel, alreadyIncluded, ) if include { @@ -945,8 +950,9 @@ func selectHopHints(cfg *SelectHopHintsCfg, nHintsLeft int, // options that'll append the route hint to the set of all route hints. // // TODO(roasbeef): do proper sub-set sum max hints usually << numChans. 
-func PopulateHopHints(cfg *SelectHopHintsCfg, amtMSat lnwire.MilliSatoshi, - forcedHints [][]zpay32.HopHint) ([][]zpay32.HopHint, error) { +func PopulateHopHints(ctx context.Context, cfg *SelectHopHintsCfg, + amtMSat lnwire.MilliSatoshi, forcedHints [][]zpay32.HopHint) ( + [][]zpay32.HopHint, error) { hopHints := forcedHints @@ -968,7 +974,7 @@ func PopulateHopHints(cfg *SelectHopHintsCfg, amtMSat lnwire.MilliSatoshi, targetBandwidth := amtMSat * hopHintFactor selectedHints := selectHopHints( - cfg, nHintsLeft, targetBandwidth, potentialHints, + ctx, cfg, nHintsLeft, targetBandwidth, potentialHints, alreadyIncluded, ) diff --git a/lnrpc/invoicesrpc/addinvoice_test.go b/lnrpc/invoicesrpc/addinvoice_test.go index 546b9cc725..8aadff354b 100644 --- a/lnrpc/invoicesrpc/addinvoice_test.go +++ b/lnrpc/invoicesrpc/addinvoice_test.go @@ -1,6 +1,7 @@ package invoicesrpc import ( + "context" "encoding/hex" "fmt" "testing" @@ -35,8 +36,10 @@ func newHopHintsConfigMock(t *testing.T) *hopHintsConfigMock { } // IsPublicNode mocks node public state lookup. -func (h *hopHintsConfigMock) IsPublicNode(pubKey [33]byte) (bool, error) { - args := h.Mock.Called(pubKey) +func (h *hopHintsConfigMock) IsPublicNode(ctx context.Context, + pubKey [33]byte) (bool, error) { + + args := h.Mock.Called(ctx, pubKey) return args.Bool(0), args.Error(1) } @@ -66,11 +69,11 @@ func (h *hopHintsConfigMock) FetchAllChannels() ([]*channeldb.OpenChannel, // FetchChannelEdgesByID attempts to lookup the two directed edges for // the channel identified by the channel ID. -func (h *hopHintsConfigMock) FetchChannelEdgesByID(chanID uint64) ( - *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, +func (h *hopHintsConfigMock) FetchChannelEdgesByID(ctx context.Context, + chanID uint64) (*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, *models.ChannelEdgePolicy, error) { - args := h.Mock.Called(chanID) + args := h.Mock.Called(ctx, chanID) // If our error is non-nil, we expect nil responses otherwise. 
Our // casts below will fail with nil values, so we check our error and @@ -161,7 +164,7 @@ var shouldIncludeChannelTestCases = []struct { ).Once().Return(true) h.Mock.On( - "IsPublicNode", mock.Anything, + "IsPublicNode", mock.Anything, mock.Anything, ).Once().Return(false, nil) }, channel: &channeldb.OpenChannel{ @@ -185,18 +188,18 @@ var shouldIncludeChannelTestCases = []struct { ).Once().Return(true) h.Mock.On( - "IsPublicNode", mock.Anything, + "IsPublicNode", mock.Anything, mock.Anything, ).Once().Return(true, nil) h.Mock.On( - "FetchChannelEdgesByID", mock.Anything, + "FetchChannelEdgesByID", mock.Anything, mock.Anything, ).Once().Return(nil, nil, nil, fmt.Errorf("no edge")) // TODO(positiveblue): check that the func is called with the // right scid when we have access to the `confirmedscid` form // here. h.Mock.On( - "FetchChannelEdgesByID", mock.Anything, + "FetchChannelEdgesByID", mock.Anything, mock.Anything, ).Once().Return(nil, nil, nil, fmt.Errorf("no edge")) }, channel: &channeldb.OpenChannel{ @@ -220,11 +223,11 @@ var shouldIncludeChannelTestCases = []struct { ).Once().Return(true) h.Mock.On( - "IsPublicNode", mock.Anything, + "IsPublicNode", mock.Anything, mock.Anything, ).Once().Return(true, nil) h.Mock.On( - "FetchChannelEdgesByID", mock.Anything, + "FetchChannelEdgesByID", mock.Anything, mock.Anything, ).Once().Return( &models.ChannelEdgeInfo{}, &models.ChannelEdgePolicy{}, @@ -258,11 +261,11 @@ var shouldIncludeChannelTestCases = []struct { ).Once().Return(true) h.Mock.On( - "IsPublicNode", mock.Anything, + "IsPublicNode", mock.Anything, mock.Anything, ).Once().Return(true, nil) h.Mock.On( - "FetchChannelEdgesByID", mock.Anything, + "FetchChannelEdgesByID", mock.Anything, mock.Anything, ).Once().Return( &models.ChannelEdgeInfo{}, &models.ChannelEdgePolicy{}, @@ -296,14 +299,14 @@ var shouldIncludeChannelTestCases = []struct { ).Once().Return(true) h.Mock.On( - "IsPublicNode", mock.Anything, + "IsPublicNode", mock.Anything, mock.Anything, 
).Once().Return(true, nil) var selectedPolicy [33]byte copy(selectedPolicy[:], getTestPubKey().SerializeCompressed()) h.Mock.On( - "FetchChannelEdgesByID", mock.Anything, + "FetchChannelEdgesByID", mock.Anything, mock.Anything, ).Once().Return( &models.ChannelEdgeInfo{ NodeKey1Bytes: selectedPolicy, @@ -347,11 +350,11 @@ var shouldIncludeChannelTestCases = []struct { ).Once().Return(true) h.Mock.On( - "IsPublicNode", mock.Anything, + "IsPublicNode", mock.Anything, mock.Anything, ).Once().Return(true, nil) h.Mock.On( - "FetchChannelEdgesByID", mock.Anything, + "FetchChannelEdgesByID", mock.Anything, mock.Anything, ).Once().Return( &models.ChannelEdgeInfo{}, &models.ChannelEdgePolicy{}, @@ -392,11 +395,11 @@ var shouldIncludeChannelTestCases = []struct { ).Once().Return(true) h.Mock.On( - "IsPublicNode", mock.Anything, + "IsPublicNode", mock.Anything, mock.Anything, ).Once().Return(true, nil) h.Mock.On( - "FetchChannelEdgesByID", mock.Anything, + "FetchChannelEdgesByID", mock.Anything, mock.Anything, ).Once().Return( &models.ChannelEdgeInfo{}, &models.ChannelEdgePolicy{}, @@ -432,6 +435,7 @@ var shouldIncludeChannelTestCases = []struct { }} func TestShouldIncludeChannel(t *testing.T) { + ctx := context.Background() for _, tc := range shouldIncludeChannelTestCases { tc := tc @@ -453,7 +457,7 @@ func TestShouldIncludeChannel(t *testing.T) { } hopHint, remoteBalance, include := shouldIncludeChannel( - cfg, tc.channel, tc.alreadyIncluded, + ctx, cfg, tc.channel, tc.alreadyIncluded, ) require.Equal(t, tc.include, include) @@ -559,11 +563,11 @@ var populateHopHintsTestCases = []struct { ).Once().Return(true) h.Mock.On( - "IsPublicNode", mock.Anything, + "IsPublicNode", mock.Anything, mock.Anything, ).Once().Return(true, nil) h.Mock.On( - "FetchChannelEdgesByID", mock.Anything, + "FetchChannelEdgesByID", mock.Anything, mock.Anything, ).Once().Return( &models.ChannelEdgeInfo{}, &models.ChannelEdgePolicy{}, @@ -609,11 +613,11 @@ var populateHopHintsTestCases = []struct { 
).Once().Return(true) h.Mock.On( - "IsPublicNode", mock.Anything, + "IsPublicNode", mock.Anything, mock.Anything, ).Once().Return(true, nil) h.Mock.On( - "FetchChannelEdgesByID", mock.Anything, + "FetchChannelEdgesByID", mock.Anything, mock.Anything, ).Once().Return( &models.ChannelEdgeInfo{}, &models.ChannelEdgePolicy{}, @@ -660,11 +664,11 @@ var populateHopHintsTestCases = []struct { ).Once().Return(true) h.Mock.On( - "IsPublicNode", mock.Anything, + "IsPublicNode", mock.Anything, mock.Anything, ).Once().Return(true, nil) h.Mock.On( - "FetchChannelEdgesByID", mock.Anything, + "FetchChannelEdgesByID", mock.Anything, mock.Anything, ).Once().Return( &models.ChannelEdgeInfo{}, &models.ChannelEdgePolicy{}, @@ -693,11 +697,11 @@ var populateHopHintsTestCases = []struct { ).Once().Return(true) h.Mock.On( - "IsPublicNode", mock.Anything, + "IsPublicNode", mock.Anything, mock.Anything, ).Once().Return(true, nil) h.Mock.On( - "FetchChannelEdgesByID", mock.Anything, + "FetchChannelEdgesByID", mock.Anything, mock.Anything, ).Once().Return( &models.ChannelEdgeInfo{}, &models.ChannelEdgePolicy{}, @@ -710,11 +714,11 @@ var populateHopHintsTestCases = []struct { ).Once().Return(true) h.Mock.On( - "IsPublicNode", mock.Anything, + "IsPublicNode", mock.Anything, mock.Anything, ).Once().Return(true, nil) h.Mock.On( - "FetchChannelEdgesByID", mock.Anything, + "FetchChannelEdgesByID", mock.Anything, mock.Anything, ).Once().Return( &models.ChannelEdgeInfo{}, &models.ChannelEdgePolicy{}, @@ -747,11 +751,11 @@ var populateHopHintsTestCases = []struct { ).Once().Return(true) h.Mock.On( - "IsPublicNode", mock.Anything, + "IsPublicNode", mock.Anything, mock.Anything, ).Once().Return(true, nil) h.Mock.On( - "FetchChannelEdgesByID", mock.Anything, + "FetchChannelEdgesByID", mock.Anything, mock.Anything, ).Once().Return( &models.ChannelEdgeInfo{}, &models.ChannelEdgePolicy{}, @@ -764,11 +768,11 @@ var populateHopHintsTestCases = []struct { ).Once().Return(true) h.Mock.On( - "IsPublicNode", 
mock.Anything, + "IsPublicNode", mock.Anything, mock.Anything, ).Once().Return(true, nil) h.Mock.On( - "FetchChannelEdgesByID", mock.Anything, + "FetchChannelEdgesByID", mock.Anything, mock.Anything, ).Once().Return( &models.ChannelEdgeInfo{}, &models.ChannelEdgePolicy{}, @@ -802,11 +806,11 @@ var populateHopHintsTestCases = []struct { ).Once().Return(true) h.Mock.On( - "IsPublicNode", mock.Anything, + "IsPublicNode", mock.Anything, mock.Anything, ).Once().Return(true, nil) h.Mock.On( - "FetchChannelEdgesByID", mock.Anything, + "FetchChannelEdgesByID", mock.Anything, mock.Anything, ).Once().Return( &models.ChannelEdgeInfo{}, &models.ChannelEdgePolicy{}, @@ -865,6 +869,7 @@ func setupMockTwoChannels(h *hopHintsConfigMock) (lnwire.ChannelID, } func TestPopulateHopHints(t *testing.T) { + ctx := context.Background() for _, tc := range populateHopHintsTestCases { tc := tc @@ -887,7 +892,7 @@ func TestPopulateHopHints(t *testing.T) { MaxHopHints: tc.maxHopHints, } hopHints, err := PopulateHopHints( - cfg, tc.amount, tc.forcedHints, + ctx, cfg, tc.amount, tc.forcedHints, ) require.NoError(t, err) // We shuffle the elements in the hop hint list so we diff --git a/lnrpc/invoicesrpc/config_active.go b/lnrpc/invoicesrpc/config_active.go index 14799c67ba..60051fe36a 100644 --- a/lnrpc/invoicesrpc/config_active.go +++ b/lnrpc/invoicesrpc/config_active.go @@ -6,7 +6,6 @@ package invoicesrpc import ( "github.com/btcsuite/btcd/chaincfg" "github.com/lightningnetwork/lnd/channeldb" - graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/macaroons" @@ -52,9 +51,8 @@ type Config struct { // specified. DefaultCLTVExpiry uint32 - // GraphDB is a global database instance which is needed to access the - // channel graph. - GraphDB *graphdb.ChannelGraph + // Graph can be used for graph related queries. 
+ Graph GraphSource // ChanStateDB is a possibly replicated db instance which contains the // channel state diff --git a/lnrpc/invoicesrpc/interfaces.go b/lnrpc/invoicesrpc/interfaces.go new file mode 100644 index 0000000000..99a8bef23d --- /dev/null +++ b/lnrpc/invoicesrpc/interfaces.go @@ -0,0 +1,22 @@ +package invoicesrpc + +import ( + "context" + + "github.com/lightningnetwork/lnd/graph/db/models" +) + +// GraphSource defines the graph interface required by the invoice rpc server. +type GraphSource interface { + // FetchChannelEdgesByID attempts to look up the two directed edges for + // the channel identified by the channel ID. If the channel can't be + // found, then graphdb.ErrEdgeNotFound is returned. + FetchChannelEdgesByID(ctx context.Context, chanID uint64) ( + *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy, error) + + // IsPublicNode is a helper method that determines whether the node with + // the given public key is seen as a public node in the graph from the + // graph's source node's point of view. 
+ IsPublicNode(ctx context.Context, pubKey [33]byte) (bool, error) +} diff --git a/lnrpc/invoicesrpc/invoices_server.go b/lnrpc/invoicesrpc/invoices_server.go index 8ca02b260d..7b00b9b643 100644 --- a/lnrpc/invoicesrpc/invoices_server.go +++ b/lnrpc/invoicesrpc/invoices_server.go @@ -346,7 +346,7 @@ func (s *Server) AddHoldInvoice(ctx context.Context, NodeSigner: s.cfg.NodeSigner, DefaultCLTVExpiry: s.cfg.DefaultCLTVExpiry, ChanDB: s.cfg.ChanStateDB, - Graph: s.cfg.GraphDB, + Graph: s.cfg.Graph, GenInvoiceFeatures: s.cfg.GenInvoiceFeatures, GenAmpInvoiceFeatures: s.cfg.GenAmpInvoiceFeatures, GetAlias: s.cfg.GetAlias, diff --git a/lnrpc/routerrpc/router_backend.go b/lnrpc/routerrpc/router_backend.go index 9421e991b6..5aabbf618d 100644 --- a/lnrpc/routerrpc/router_backend.go +++ b/lnrpc/routerrpc/router_backend.go @@ -50,7 +50,8 @@ type RouterBackend struct { // FetchAmountPairCapacity determines the maximal channel capacity // between two nodes given a certain amount. - FetchAmountPairCapacity func(nodeFrom, nodeTo route.Vertex, + FetchAmountPairCapacity func(ctx context.Context, nodeFrom, + nodeTo route.Vertex, amount lnwire.MilliSatoshi) (btcutil.Amount, error) // FetchChannelEndpoints returns the pubkeys of both endpoints of the @@ -60,7 +61,8 @@ type RouterBackend struct { // FindRoute is a closure that abstracts away how we locate/query for // routes. 
- FindRoute func(*routing.RouteRequest) (*route.Route, float64, error) + FindRoute func(context.Context, *routing.RouteRequest) (*route.Route, + float64, error) MissionControl MissionControl @@ -169,7 +171,7 @@ func (r *RouterBackend) QueryRoutes(ctx context.Context, // Query the channel router for a possible path to the destination that // can carry `in.Amt` satoshis _including_ the total fee required on // the route - route, successProb, err := r.FindRoute(routeReq) + route, successProb, err := r.FindRoute(ctx, routeReq) if err != nil { return nil, err } diff --git a/lnrpc/routerrpc/router_backend_test.go b/lnrpc/routerrpc/router_backend_test.go index 877a3cc171..78e25f7e5a 100644 --- a/lnrpc/routerrpc/router_backend_test.go +++ b/lnrpc/routerrpc/router_backend_test.go @@ -120,8 +120,8 @@ func testQueryRoutes(t *testing.T, useMissionControl bool, useMsat bool, } } - findRoute := func(req *routing.RouteRequest) (*route.Route, float64, - error) { + findRoute := func(_ context.Context, req *routing.RouteRequest) ( + *route.Route, float64, error) { if int64(req.Amount) != amtSat*1000 { t.Fatal("unexpected amount") @@ -200,7 +200,8 @@ func testQueryRoutes(t *testing.T, useMissionControl bool, useMsat bool, return 1, nil }, - FetchAmountPairCapacity: func(nodeFrom, nodeTo route.Vertex, + FetchAmountPairCapacity: func(_ context.Context, nodeFrom, + nodeTo route.Vertex, amount lnwire.MilliSatoshi) (btcutil.Amount, error) { return 1, nil diff --git a/lnrpc/routerrpc/router_server.go b/lnrpc/routerrpc/router_server.go index a4112ba646..b4460d87da 100644 --- a/lnrpc/routerrpc/router_server.go +++ b/lnrpc/routerrpc/router_server.go @@ -426,7 +426,7 @@ func (s *Server) EstimateRouteFee(ctx context.Context, return nil, errors.New("amount must be greater than 0") default: - return s.probeDestination(req.Dest, req.AmtSat) + return s.probeDestination(ctx, req.Dest, req.AmtSat) } case isProbeInvoice: @@ -440,8 +440,8 @@ func (s *Server) EstimateRouteFee(ctx context.Context, // 
probeDestination estimates fees along a route to a destination based on the // contents of the local graph. -func (s *Server) probeDestination(dest []byte, amtSat int64) (*RouteFeeResponse, - error) { +func (s *Server) probeDestination(ctx context.Context, dest []byte, + amtSat int64) (*RouteFeeResponse, error) { destNode, err := route.NewVertexFromBytes(dest) if err != nil { @@ -469,7 +469,7 @@ func (s *Server) probeDestination(dest []byte, amtSat int64) (*RouteFeeResponse, return nil, err } - route, _, err := s.cfg.Router.FindRoute(routeReq) + route, _, err := s.cfg.Router.FindRoute(ctx, routeReq) if err != nil { return nil, err } @@ -1429,7 +1429,7 @@ func (s *Server) trackPaymentStream(context context.Context, } // BuildRoute builds a route from a list of hop addresses. -func (s *Server) BuildRoute(_ context.Context, +func (s *Server) BuildRoute(ctx context.Context, req *BuildRouteRequest) (*BuildRouteResponse, error) { if len(req.HopPubkeys) == 0 { @@ -1490,7 +1490,7 @@ func (s *Server) BuildRoute(_ context.Context, // Build the route and return it to the caller. route, err := s.cfg.Router.BuildRoute( - amt, hops, outgoingChan, req.FinalCltvDelta, payAddr, + ctx, amt, hops, outgoingChan, req.FinalCltvDelta, payAddr, firstHopBlob, ) if err != nil { diff --git a/lnrpc/routerrpc/router_server_deprecated.go b/lnrpc/routerrpc/router_server_deprecated.go index 7be1e3d919..fee5bcab58 100644 --- a/lnrpc/routerrpc/router_server_deprecated.go +++ b/lnrpc/routerrpc/router_server_deprecated.go @@ -123,7 +123,7 @@ func (s *Server) SendToRoute(ctx context.Context, // QueryProbability returns the current success probability estimate for a // given node pair and amount. 
-func (s *Server) QueryProbability(_ context.Context, +func (s *Server) QueryProbability(ctx context.Context, req *QueryProbabilityRequest) (*QueryProbabilityResponse, error) { fromNode, err := route.NewVertexFromBytes(req.FromNode) @@ -142,7 +142,7 @@ func (s *Server) QueryProbability(_ context.Context, var prob float64 mc := s.cfg.RouterBackend.MissionControl capacity, err := s.cfg.RouterBackend.FetchAmountPairCapacity( - fromNode, toNode, amt, + ctx, fromNode, toNode, amt, ) // If we cannot query the capacity this means that either we don't have diff --git a/log.go b/log.go index 795fb4d729..720343913b 100644 --- a/log.go +++ b/log.go @@ -22,6 +22,7 @@ import ( "github.com/lightningnetwork/lnd/funding" "github.com/lightningnetwork/lnd/graph" graphdb "github.com/lightningnetwork/lnd/graph/db" + "github.com/lightningnetwork/lnd/graph/sources" "github.com/lightningnetwork/lnd/healthcheck" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/invoices" @@ -196,6 +197,7 @@ func SetupLoggers(root *build.SubLoggerManager, interceptor signal.Interceptor) root, blindedpath.Subsystem, interceptor, blindedpath.UseLogger, ) AddV1SubLogger(root, graphdb.Subsystem, interceptor, graphdb.UseLogger) + AddSubLogger(root, sources.Subsystem, interceptor, sources.UseLogger) } // AddSubLogger is a helper method to conveniently create and register the diff --git a/macaroons/service_test.go b/macaroons/service_test.go index aad8af8db0..e28a03589c 100644 --- a/macaroons/service_test.go +++ b/macaroons/service_test.go @@ -55,6 +55,7 @@ func setupTestRootKeyStorage(t *testing.T) kvdb.Backend { // TestNewService tests the creation of the macaroon service. func TestNewService(t *testing.T) { t.Parallel() + ctx := context.Background() // First, initialize a dummy DB file with a store that the service // can read from. Make sure the file is removed in the end. 
@@ -74,13 +75,13 @@ func TestNewService(t *testing.T) { require.NoError(t, err, "Error unlocking root key storage") // Third, check if the created service can bake macaroons. - _, err = service.NewMacaroon(context.TODO(), nil, testOperation) + _, err = service.NewMacaroon(ctx, nil, testOperation) if err != macaroons.ErrMissingRootKeyID { t.Fatalf("Received %v instead of ErrMissingRootKeyID", err) } macaroon, err := service.NewMacaroon( - context.TODO(), macaroons.DefaultRootKeyID, testOperation, + ctx, macaroons.DefaultRootKeyID, testOperation, ) require.NoError(t, err, "Error creating macaroon from service") if macaroon.Namespace().String() != "std:" { @@ -108,6 +109,7 @@ func TestNewService(t *testing.T) { // incoming context. func TestValidateMacaroon(t *testing.T) { t.Parallel() + ctx := context.Background() // First, initialize the service and unlock it. db := setupTestRootKeyStorage(t) @@ -124,7 +126,7 @@ func TestValidateMacaroon(t *testing.T) { // Then, create a new macaroon that we can serialize. macaroon, err := service.NewMacaroon( - context.TODO(), macaroons.DefaultRootKeyID, testOperation, + ctx, macaroons.DefaultRootKeyID, testOperation, testOperationURI, ) require.NoError(t, err, "Error creating macaroon from service") @@ -155,6 +157,7 @@ func TestValidateMacaroon(t *testing.T) { // TestListMacaroonIDs checks that ListMacaroonIDs returns the expected result. func TestListMacaroonIDs(t *testing.T) { t.Parallel() + ctx := context.Background() // First, initialize a dummy DB file with a store that the service // can read from. Make sure the file is removed in the end. @@ -176,12 +179,12 @@ func TestListMacaroonIDs(t *testing.T) { // Third, make 3 new macaroons with different root key IDs. 
expectedIDs := [][]byte{{1}, {2}, {3}} for _, v := range expectedIDs { - _, err := service.NewMacaroon(context.TODO(), v, testOperation) + _, err := service.NewMacaroon(ctx, v, testOperation) require.NoError(t, err, "Error creating macaroon from service") } // Finally, check that calling List return the expected values. - ids, _ := service.ListMacaroonIDs(context.TODO()) + ids, _ := service.ListMacaroonIDs(ctx) require.Equal(t, expectedIDs, ids, "root key IDs mismatch") } diff --git a/macaroons/store_test.go b/macaroons/store_test.go index ace1e764a1..2246f63222 100644 --- a/macaroons/store_test.go +++ b/macaroons/store_test.go @@ -58,12 +58,13 @@ func openTestStore(t *testing.T, tempDir string) *macaroons.RootKeyStorage { // TestStore tests the normal use cases of the store like creating, unlocking, // reading keys and closing it. func TestStore(t *testing.T) { + ctx := context.Background() tempDir, store := newTestStore(t) - _, _, err := store.RootKey(context.TODO()) + _, _, err := store.RootKey(ctx) require.Equal(t, macaroons.ErrStoreLocked, err) - _, err = store.Get(context.TODO(), nil) + _, err = store.Get(ctx, nil) require.Equal(t, macaroons.ErrStoreLocked, err) pw := []byte("weks") @@ -72,18 +73,18 @@ func TestStore(t *testing.T) { // Check ErrContextRootKeyID is returned when no root key ID found in // context. - _, _, err = store.RootKey(context.TODO()) + _, _, err = store.RootKey(ctx) require.Equal(t, macaroons.ErrContextRootKeyID, err) // Check ErrMissingRootKeyID is returned when empty root key ID is used. emptyKeyID := make([]byte, 0) - badCtx := macaroons.ContextWithRootKeyID(context.TODO(), emptyKeyID) + badCtx := macaroons.ContextWithRootKeyID(ctx, emptyKeyID) _, _, err = store.RootKey(badCtx) require.Equal(t, macaroons.ErrMissingRootKeyID, err) // Create a context with illegal root key ID value. 
encryptedKeyID := []byte("enckey") - badCtx = macaroons.ContextWithRootKeyID(context.TODO(), encryptedKeyID) + badCtx = macaroons.ContextWithRootKeyID(ctx, encryptedKeyID) _, _, err = store.RootKey(badCtx) require.Equal(t, macaroons.ErrKeyValueForbidden, err) diff --git a/netann/chan_status_manager.go b/netann/chan_status_manager.go index feb3a5dd19..0578bcc9d8 100644 --- a/netann/chan_status_manager.go +++ b/netann/chan_status_manager.go @@ -1,6 +1,7 @@ package netann import ( + "context" "errors" "sync" "time" @@ -8,6 +9,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" @@ -127,8 +129,9 @@ type ChanStatusManager struct { // become inactive. statusSampleTicker *time.Ticker - wg sync.WaitGroup - quit chan struct{} + wg sync.WaitGroup + quit chan struct{} + cancel fn.Option[context.CancelFunc] } // NewChanStatusManager initializes a new ChanStatusManager using the given @@ -176,12 +179,16 @@ func (m *ChanStatusManager) Start() error { var err error m.started.Do(func() { log.Info("Channel Status Manager starting") - err = m.start() + + ctx, cancel := context.WithCancel(context.Background()) + m.cancel = fn.Some(cancel) + + err = m.start(ctx) }) return err } -func (m *ChanStatusManager) start() error { +func (m *ChanStatusManager) start(ctx context.Context) error { channels, err := m.fetchChannels() if err != nil { return err @@ -189,7 +196,7 @@ func (m *ChanStatusManager) start() error { // Populate the initial states of all confirmed, public channels. 
for _, c := range channels { - _, err := m.getOrInitChanStatus(c.FundingOutpoint) + _, err := m.getOrInitChanStatus(ctx, c.FundingOutpoint) switch { // If we can't retrieve the edge info for this channel, it may @@ -218,7 +225,7 @@ func (m *ChanStatusManager) start() error { } m.wg.Add(1) - go m.statusManager() + go m.statusManager(ctx) return nil } @@ -229,6 +236,10 @@ func (m *ChanStatusManager) Stop() error { log.Info("Channel Status Manager shutting down...") defer log.Debug("Channel Status Manager shutdown complete") + m.cancel.WhenSome(func(cancel context.CancelFunc) { + cancel() + }) + close(m.quit) m.wg.Wait() }) @@ -332,7 +343,7 @@ func (m *ChanStatusManager) submitRequest(reqChan chan statusRequest, // should be scheduled or broadcast. // // NOTE: This method MUST be run as a goroutine. -func (m *ChanStatusManager) statusManager() { +func (m *ChanStatusManager) statusManager(ctx context.Context) { defer m.wg.Done() for { @@ -340,15 +351,20 @@ func (m *ChanStatusManager) statusManager() { // Process any requests to mark channel as enabled. case req := <-m.enableRequests: - req.errChan <- m.processEnableRequest(req.outpoint, req.manual) + req.errChan <- m.processEnableRequest( + ctx, req.outpoint, req.manual, + ) // Process any requests to mark channel as disabled. case req := <-m.disableRequests: - req.errChan <- m.processDisableRequest(req.outpoint, req.manual) + req.errChan <- m.processDisableRequest( + ctx, req.outpoint, req.manual, + ) - // Process any requests to restore automatic channel state management. + // Process any requests to restore automatic channel state + // management. case req := <-m.autoRequests: - req.errChan <- m.processAutoRequest(req.outpoint) + req.errChan <- m.processAutoRequest(ctx, req.outpoint) // Use long-polling to detect when channels become inactive. case <-m.statusSampleTicker.C: @@ -357,12 +373,12 @@ func (m *ChanStatusManager) statusManager() { // ChanStatusPendingDisabled. 
The channel will then be // disabled if no request to enable is received before // the ChanDisableTimeout expires. - m.markPendingInactiveChannels() + m.markPendingInactiveChannels(ctx) // Now, do another sweep to disable any channels that // were marked in a prior iteration as pending inactive // if the inactive chan timeout has elapsed. - m.disableInactiveChannels() + m.disableInactiveChannels(ctx) case <-m.quit: return @@ -382,10 +398,10 @@ func (m *ChanStatusManager) statusManager() { // // An update will be broadcast only if the channel is currently disabled, // otherwise no update will be sent on the network. -func (m *ChanStatusManager) processEnableRequest(outpoint wire.OutPoint, - manual bool) error { +func (m *ChanStatusManager) processEnableRequest(ctx context.Context, + outpoint wire.OutPoint, manual bool) error { - curState, err := m.getOrInitChanStatus(outpoint) + curState, err := m.getOrInitChanStatus(ctx, outpoint) if err != nil { return err } @@ -422,7 +438,7 @@ func (m *ChanStatusManager) processEnableRequest(outpoint wire.OutPoint, case ChanStatusDisabled: log.Infof("Announcing channel(%v) enabled", outpoint) - err := m.signAndSendNextUpdate(outpoint, false) + err := m.signAndSendNextUpdate(ctx, outpoint, false) if err != nil { return err } @@ -440,10 +456,10 @@ func (m *ChanStatusManager) processEnableRequest(outpoint wire.OutPoint, // // An update will only be sent if the channel has a status other than // ChanStatusEnabled, otherwise no update will be sent on the network. 
-func (m *ChanStatusManager) processDisableRequest(outpoint wire.OutPoint, - manual bool) error { +func (m *ChanStatusManager) processDisableRequest(ctx context.Context, + outpoint wire.OutPoint, manual bool) error { - curState, err := m.getOrInitChanStatus(outpoint) + curState, err := m.getOrInitChanStatus(ctx, outpoint) if err != nil { return err } @@ -453,7 +469,7 @@ func (m *ChanStatusManager) processDisableRequest(outpoint wire.OutPoint, log.Infof("Announcing channel(%v) disabled [requested]", outpoint) - err := m.signAndSendNextUpdate(outpoint, true) + err := m.signAndSendNextUpdate(ctx, outpoint, true) if err != nil { return err } @@ -482,8 +498,10 @@ func (m *ChanStatusManager) processDisableRequest(outpoint wire.OutPoint, // which automatic / background requests are ignored). // // No update will be sent on the network. -func (m *ChanStatusManager) processAutoRequest(outpoint wire.OutPoint) error { - curState, err := m.getOrInitChanStatus(outpoint) +func (m *ChanStatusManager) processAutoRequest(ctx context.Context, + outpoint wire.OutPoint) error { + + curState, err := m.getOrInitChanStatus(ctx, outpoint) if err != nil { return err } @@ -504,7 +522,7 @@ func (m *ChanStatusManager) processAutoRequest(outpoint wire.OutPoint) error { // request to enable is received before the scheduled disable is broadcast, or // the channel is successfully re-enabled and channel is returned to an active // state from the POV of the ChanStatusManager. -func (m *ChanStatusManager) markPendingInactiveChannels() { +func (m *ChanStatusManager) markPendingInactiveChannels(ctx context.Context) { channels, err := m.fetchChannels() if err != nil { log.Errorf("Unable to load active channels: %v", err) @@ -514,7 +532,7 @@ func (m *ChanStatusManager) markPendingInactiveChannels() { for _, c := range channels { // Determine the initial status of the active channel, and // populate the entry in the chanStates map. 
- curState, err := m.getOrInitChanStatus(c.FundingOutpoint) + curState, err := m.getOrInitChanStatus(ctx, c.FundingOutpoint) if err != nil { log.Errorf("Unable to retrieve chan status for "+ "Channel(%v): %v", c.FundingOutpoint, err) @@ -553,7 +571,7 @@ func (m *ChanStatusManager) markPendingInactiveChannels() { // disableInactiveChannels scans through the set of monitored channels, and // broadcast a disable update for any pending inactive channels whose // SendDisableTime has been superseded by the current time. -func (m *ChanStatusManager) disableInactiveChannels() { +func (m *ChanStatusManager) disableInactiveChannels(ctx context.Context) { // Now, disable any channels whose inactive chan timeout has elapsed. now := time.Now() for outpoint, state := range m.chanStates { @@ -572,7 +590,7 @@ func (m *ChanStatusManager) disableInactiveChannels() { "[detected]", outpoint) // Sign an update disabling the channel. - err := m.signAndSendNextUpdate(outpoint, true) + err := m.signAndSendNextUpdate(ctx, outpoint, true) if err != nil { log.Errorf("Unable to sign update disabling "+ "channel(%v): %v", outpoint, err) @@ -625,12 +643,14 @@ func (m *ChanStatusManager) fetchChannels() ([]*channeldb.OpenChannel, error) { // use the current time as the update's timestamp, or increment the old // timestamp by 1 to ensure the update can propagate. If signing is successful, // the new update will be sent out on the network. -func (m *ChanStatusManager) signAndSendNextUpdate(outpoint wire.OutPoint, - disabled bool) error { +func (m *ChanStatusManager) signAndSendNextUpdate(ctx context.Context, + outpoint wire.OutPoint, disabled bool) error { // Retrieve the latest update for this channel. We'll use this // as our starting point to send the new update. 
- chanUpdate, private, err := m.fetchLastChanUpdateByOutPoint(outpoint) + chanUpdate, private, err := m.fetchLastChanUpdateByOutPoint( + ctx, outpoint, + ) if err != nil { return err } @@ -650,11 +670,13 @@ func (m *ChanStatusManager) signAndSendNextUpdate(outpoint wire.OutPoint, // a channel, and crafts a new ChannelUpdate with this policy. Returns an error // in case our ChannelEdgePolicy is not found in the database. Also returns if // the channel is private by checking AuthProof for nil. -func (m *ChanStatusManager) fetchLastChanUpdateByOutPoint(op wire.OutPoint) ( - *lnwire.ChannelUpdate1, bool, error) { +func (m *ChanStatusManager) fetchLastChanUpdateByOutPoint(ctx context.Context, + op wire.OutPoint) (*lnwire.ChannelUpdate1, bool, error) { // Get the edge info and policies for this channel from the graph. - info, edge1, edge2, err := m.cfg.Graph.FetchChannelEdgesByOutpoint(&op) + info, edge1, edge2, err := m.cfg.Graph.FetchChannelEdgesByOutpoint( + ctx, &op, + ) if err != nil { return nil, false, err } @@ -670,10 +692,10 @@ func (m *ChanStatusManager) fetchLastChanUpdateByOutPoint(op wire.OutPoint) ( // ChanStatusEnabled or ChanStatusDisabled, determined by inspecting the bits on // the most recent announcement. An error is returned if the latest update could // not be retrieved. -func (m *ChanStatusManager) loadInitialChanState( +func (m *ChanStatusManager) loadInitialChanState(ctx context.Context, outpoint *wire.OutPoint) (ChannelState, error) { - lastUpdate, _, err := m.fetchLastChanUpdateByOutPoint(*outpoint) + lastUpdate, _, err := m.fetchLastChanUpdateByOutPoint(ctx, *outpoint) if err != nil { return ChannelState{}, err } @@ -696,7 +718,7 @@ func (m *ChanStatusManager) loadInitialChanState( // outpoint. If the chanStates map already contains an entry for the outpoint, // the value in the map is returned. Otherwise, the outpoint's initial status is // computed and updated in the chanStates map before being returned. 
-func (m *ChanStatusManager) getOrInitChanStatus( +func (m *ChanStatusManager) getOrInitChanStatus(ctx context.Context, outpoint wire.OutPoint) (ChannelState, error) { // Return the current ChannelState from the chanStates map if it is @@ -707,7 +729,7 @@ func (m *ChanStatusManager) getOrInitChanStatus( // Otherwise, determine the initial state based on the last update we // sent for the outpoint. - initialState, err := m.loadInitialChanState(&outpoint) + initialState, err := m.loadInitialChanState(ctx, &outpoint) if err != nil { return ChannelState{}, err } diff --git a/netann/chan_status_manager_test.go b/netann/chan_status_manager_test.go index 320981d630..4c3cc25aeb 100644 --- a/netann/chan_status_manager_test.go +++ b/netann/chan_status_manager_test.go @@ -2,6 +2,7 @@ package netann_test import ( "bytes" + "context" "crypto/rand" "encoding/binary" "fmt" @@ -160,7 +161,7 @@ func (g *mockGraph) FetchAllOpenChannels() ([]*channeldb.OpenChannel, error) { return g.chans(), nil } -func (g *mockGraph) FetchChannelEdgesByOutpoint( +func (g *mockGraph) FetchChannelEdgesByOutpoint(ctx context.Context, op *wire.OutPoint) (*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, *models.ChannelEdgePolicy, error) { diff --git a/netann/interface.go b/netann/interface.go index aa559435d4..78acc24cdd 100644 --- a/netann/interface.go +++ b/netann/interface.go @@ -1,6 +1,8 @@ package netann import ( + "context" + "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/graph/db/models" @@ -19,6 +21,7 @@ type DB interface { type ChannelGraph interface { // FetchChannelEdgesByOutpoint returns the channel edge info and most // recent channel edge policies for a given outpoint. 
- FetchChannelEdgesByOutpoint(*wire.OutPoint) (*models.ChannelEdgeInfo, - *models.ChannelEdgePolicy, *models.ChannelEdgePolicy, error) + FetchChannelEdgesByOutpoint(context.Context, *wire.OutPoint) ( + *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy, error) } diff --git a/peer/test_utils.go b/peer/test_utils.go index eb510a53b1..ce92dadff3 100644 --- a/peer/test_utils.go +++ b/peer/test_utils.go @@ -20,6 +20,7 @@ import ( "github.com/lightningnetwork/lnd/channelnotifier" "github.com/lightningnetwork/lnd/fn" graphdb "github.com/lightningnetwork/lnd/graph/db" + "github.com/lightningnetwork/lnd/graph/sources" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" @@ -628,7 +629,7 @@ func createTestPeer(t *testing.T) *peerTestCtx { ChanEnableTimeout: chanActiveTimeout, ChanDisableTimeout: 2 * time.Minute, DB: dbAliceChannel.ChannelStateDB(), - Graph: dbAliceGraph, + Graph: sources.NewDBGSource(dbAliceGraph), MessageSigner: nodeSignerAlice, OurPubKey: aliceKeyPub, OurKeyLoc: testKeyLoc, diff --git a/pilot.go b/pilot.go index 11333a0722..f91dd21a3a 100644 --- a/pilot.go +++ b/pilot.go @@ -1,6 +1,7 @@ package lnd import ( + "context" "errors" "fmt" "net" @@ -136,7 +137,7 @@ var _ autopilot.ChannelController = (*chanController)(nil) // Agent instance based on the passed configuration structs. The agent and all // interfaces needed to drive it won't be launched before the Manager's // StartAgent method is called. 
-func initAutoPilot(svr *server, cfg *lncfg.AutoPilot, +func initAutoPilot(ctx context.Context, svr *server, cfg *lncfg.AutoPilot, minHTLCIn lnwire.MilliSatoshi, netParams chainreg.BitcoinNetParams) ( *autopilot.ManagerCfg, error) { @@ -224,7 +225,8 @@ func initAutoPilot(svr *server, cfg *lncfg.AutoPilot, } err := svr.ConnectToPeer( - lnAddr, false, svr.cfg.ConnectionTimeout, + ctx, lnAddr, false, + svr.cfg.ConnectionTimeout, ) if err != nil { // If we weren't able to connect to the diff --git a/routing/bandwidth.go b/routing/bandwidth.go index 12e82131dc..186417327b 100644 --- a/routing/bandwidth.go +++ b/routing/bandwidth.go @@ -1,6 +1,7 @@ package routing import ( + "context" "fmt" "github.com/lightningnetwork/lnd/fn" @@ -82,8 +83,9 @@ type bandwidthManager struct { // hints for the edges we directly have open ourselves. Obtaining these hints // allows us to reduce the number of extraneous attempts as we can skip channels // that are inactive, or just don't have enough bandwidth to carry the payment. -func newBandwidthManager(graph Graph, sourceNode route.Vertex, - linkQuery getLinkQuery, firstHopBlob fn.Option[tlv.Blob], +func newBandwidthManager(ctx context.Context, graph Graph, + sourceNode route.Vertex, linkQuery getLinkQuery, + firstHopBlob fn.Option[tlv.Blob], trafficShaper fn.Option[TlvTrafficShaper]) (*bandwidthManager, error) { manager := &bandwidthManager{ @@ -95,7 +97,7 @@ func newBandwidthManager(graph Graph, sourceNode route.Vertex, // First, we'll collect the set of outbound edges from the target // source node and add them to our bandwidth manager's map of channels. 
- err := graph.ForEachNodeChannel(sourceNode, + err := graph.ForEachNodeChannel(ctx, sourceNode, func(channel *graphdb.DirectedChannel) error { shortID := lnwire.NewShortChanIDFromInt( channel.ChannelID, diff --git a/routing/bandwidth_test.go b/routing/bandwidth_test.go index 4872b5a7ec..083559f79f 100644 --- a/routing/bandwidth_test.go +++ b/routing/bandwidth_test.go @@ -1,6 +1,7 @@ package routing import ( + "context" "testing" "github.com/btcsuite/btcd/btcutil" @@ -116,7 +117,8 @@ func TestBandwidthManager(t *testing.T) { ) m, err := newBandwidthManager( - g, sourceNode.pubkey, testCase.linkQuery, + context.Background(), g, sourceNode.pubkey, + testCase.linkQuery, fn.None[[]byte](), fn.Some[TlvTrafficShaper](&mockTrafficShaper{}), ) diff --git a/routing/blindedpath/blinded_path.go b/routing/blindedpath/blinded_path.go index a1f9db7b6b..e3d8378b9f 100644 --- a/routing/blindedpath/blinded_path.go +++ b/routing/blindedpath/blinded_path.go @@ -2,6 +2,7 @@ package blindedpath import ( "bytes" + "context" "errors" "fmt" "math" @@ -38,11 +39,13 @@ type BuildBlindedPathCfg struct { // various lengths and may even contain only a single hop. Any route // shorter than MinNumHops will be padded with dummy hops during route // construction. - FindRoutes func(value lnwire.MilliSatoshi) ([]*route.Route, error) + FindRoutes func(ctx context.Context, value lnwire.MilliSatoshi) ( + []*route.Route, error) // FetchChannelEdgesByID attempts to look up the two directed edges for // the channel identified by the channel ID. - FetchChannelEdgesByID func(chanID uint64) (*models.ChannelEdgeInfo, + FetchChannelEdgesByID func(ctx context.Context, + chanID uint64) (*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, *models.ChannelEdgePolicy, error) // FetchOurOpenChannels fetches this node's set of open channels. 
@@ -111,12 +114,12 @@ type BuildBlindedPathCfg struct { // BuildBlindedPaymentPaths uses the passed config to construct a set of blinded // payment paths that can be added to the invoice. -func BuildBlindedPaymentPaths(cfg *BuildBlindedPathCfg) ( +func BuildBlindedPaymentPaths(ctx context.Context, cfg *BuildBlindedPathCfg) ( []*zpay32.BlindedPaymentPath, error) { // Find some appropriate routes for the value to be routed. This will // return a set of routes made up of real nodes. - routes, err := cfg.FindRoutes(cfg.ValueMsat) + routes, err := cfg.FindRoutes(ctx, cfg.ValueMsat) if err != nil { return nil, err } @@ -141,7 +144,7 @@ func BuildBlindedPaymentPaths(cfg *BuildBlindedPathCfg) ( // of hops is met. candidatePath.padWithDummyHops(cfg.MinNumHops) - path, err := buildBlindedPaymentPath(cfg, candidatePath) + path, err := buildBlindedPaymentPath(ctx, cfg, candidatePath) if errors.Is(err, errInvalidBlindedPath) { log.Debugf("Not using route (%s) as a blinded path "+ "since it resulted in an invalid blinded path", @@ -169,10 +172,10 @@ func BuildBlindedPaymentPaths(cfg *BuildBlindedPathCfg) ( // buildBlindedPaymentPath takes a route from an introduction node to this node // and uses the given config to convert it into a blinded payment path. -func buildBlindedPaymentPath(cfg *BuildBlindedPathCfg, path *candidatePath) ( - *zpay32.BlindedPaymentPath, error) { +func buildBlindedPaymentPath(ctx context.Context, cfg *BuildBlindedPathCfg, + path *candidatePath) (*zpay32.BlindedPaymentPath, error) { - hops, minHTLC, maxHTLC, err := collectRelayInfo(cfg, path) + hops, minHTLC, maxHTLC, err := collectRelayInfo(ctx, cfg, path) if err != nil { return nil, fmt.Errorf("could not collect blinded path relay "+ "info: %w", err) @@ -353,8 +356,9 @@ type hopRelayInfo struct { // policy values. If there are no real hops (in other words we are the // introduction node), then we use some default routing values and we use the // average of our channel capacities for the MaxHTLC value. 
-func collectRelayInfo(cfg *BuildBlindedPathCfg, path *candidatePath) ( - []*hopRelayInfo, lnwire.MilliSatoshi, lnwire.MilliSatoshi, error) { +func collectRelayInfo(ctx context.Context, cfg *BuildBlindedPathCfg, + path *candidatePath) ([]*hopRelayInfo, lnwire.MilliSatoshi, + lnwire.MilliSatoshi, error) { var ( // The first pub key is that of the introduction node. @@ -381,7 +385,7 @@ func collectRelayInfo(cfg *BuildBlindedPathCfg, path *candidatePath) ( // channel ID in the direction pointing away from the hopSource // node. policy, err := getNodeChannelPolicy( - cfg, hop.channelID, hopSource, + ctx, cfg, hop.channelID, hopSource, ) if err != nil { return nil, 0, 0, err @@ -638,12 +642,12 @@ func buildFinalHopRouteData(node route.Vertex, pathID []byte, // getNodeChanPolicy fetches the routing policy info for the given channel and // node pair. -func getNodeChannelPolicy(cfg *BuildBlindedPathCfg, chanID uint64, - nodeID route.Vertex) (*BlindedHopPolicy, error) { +func getNodeChannelPolicy(ctx context.Context, cfg *BuildBlindedPathCfg, + chanID uint64, nodeID route.Vertex) (*BlindedHopPolicy, error) { // Attempt to fetch channel updates for the given channel. We will have // at most two updates for a given channel. - _, update1, update2, err := cfg.FetchChannelEdgesByID(chanID) + _, update1, update2, err := cfg.FetchChannelEdgesByID(ctx, chanID) if err != nil { return nil, err } diff --git a/routing/blindedpath/blinded_path_test.go b/routing/blindedpath/blinded_path_test.go index 35db89afb7..f2fa7b4a07 100644 --- a/routing/blindedpath/blinded_path_test.go +++ b/routing/blindedpath/blinded_path_test.go @@ -2,6 +2,7 @@ package blindedpath import ( "bytes" + "context" "encoding/hex" "fmt" "math/rand" @@ -548,6 +549,8 @@ func genBlindedRouteData(rand *rand.Rand) *record.BlindedRouteData { // https://github.com/lightning/bolts/blob/master/proposals/route-blinding.md // This example does not use any dummy hops. 
func TestBuildBlindedPath(t *testing.T) { + ctx := context.Background() + // Alice chooses the following path to herself for blinded path // construction: // Carol -> Bob -> Alice. @@ -591,13 +594,13 @@ func TestBuildBlindedPath(t *testing.T) { }, } - paths, err := BuildBlindedPaymentPaths(&BuildBlindedPathCfg{ - FindRoutes: func(_ lnwire.MilliSatoshi) ([]*route.Route, - error) { + paths, err := BuildBlindedPaymentPaths(ctx, &BuildBlindedPathCfg{ + FindRoutes: func(_ context.Context, _ lnwire.MilliSatoshi) ( + []*route.Route, error) { return []*route.Route{realRoute}, nil }, - FetchChannelEdgesByID: func(chanID uint64) ( + FetchChannelEdgesByID: func(_ context.Context, chanID uint64) ( *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, *models.ChannelEdgePolicy, error) { @@ -716,6 +719,8 @@ func TestBuildBlindedPath(t *testing.T) { // TestBuildBlindedPathWithDummyHops tests the construction of a blinded path // which includes dummy hops. func TestBuildBlindedPathWithDummyHops(t *testing.T) { + ctx := context.Background() + // Alice chooses the following path to herself for blinded path // construction: // Carol -> Bob -> Alice. @@ -759,13 +764,13 @@ func TestBuildBlindedPathWithDummyHops(t *testing.T) { }, } - paths, err := BuildBlindedPaymentPaths(&BuildBlindedPathCfg{ - FindRoutes: func(_ lnwire.MilliSatoshi) ([]*route.Route, - error) { + paths, err := BuildBlindedPaymentPaths(ctx, &BuildBlindedPathCfg{ + FindRoutes: func(_ context.Context, _ lnwire.MilliSatoshi) ( + []*route.Route, error) { return []*route.Route{realRoute}, nil }, - FetchChannelEdgesByID: func(chanID uint64) ( + FetchChannelEdgesByID: func(_ context.Context, chanID uint64) ( *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, *models.ChannelEdgePolicy, error) { @@ -929,14 +934,14 @@ func TestBuildBlindedPathWithDummyHops(t *testing.T) { // the first 2 calls. FindRoutes returns 3 routes and so by the end, we // still get 1 valid path. 
var errCount int - paths, err = BuildBlindedPaymentPaths(&BuildBlindedPathCfg{ - FindRoutes: func(_ lnwire.MilliSatoshi) ([]*route.Route, - error) { + paths, err = BuildBlindedPaymentPaths(ctx, &BuildBlindedPathCfg{ + FindRoutes: func(_ context.Context, _ lnwire.MilliSatoshi) ( + []*route.Route, error) { return []*route.Route{realRoute, realRoute, realRoute}, nil }, - FetchChannelEdgesByID: func(chanID uint64) ( + FetchChannelEdgesByID: func(_ context.Context, chanID uint64) ( *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, *models.ChannelEdgePolicy, error) { @@ -998,6 +1003,7 @@ func TestBuildBlindedPathWithDummyHops(t *testing.T) { // node. func TestSingleHopBlindedPath(t *testing.T) { var ( + ctx = context.Background() _, pkC = btcec.PrivKeyFromBytes([]byte{1}) carol = route.NewVertex(pkC) ) @@ -1009,9 +1015,9 @@ func TestSingleHopBlindedPath(t *testing.T) { Hops: []*route.Hop{}, } - paths, err := BuildBlindedPaymentPaths(&BuildBlindedPathCfg{ - FindRoutes: func(_ lnwire.MilliSatoshi) ([]*route.Route, - error) { + paths, err := BuildBlindedPaymentPaths(ctx, &BuildBlindedPathCfg{ + FindRoutes: func(_ context.Context, _ lnwire.MilliSatoshi) ( + []*route.Route, error) { return []*route.Route{realRoute}, nil }, diff --git a/routing/graph.go b/routing/graph.go index 7608ee92bb..dc17fea9fd 100644 --- a/routing/graph.go +++ b/routing/graph.go @@ -1,6 +1,7 @@ package routing import ( + "context" "fmt" "github.com/btcsuite/btcd/btcutil" @@ -14,11 +15,12 @@ import ( type Graph interface { // ForEachNodeChannel calls the callback for every channel of the given // node. - ForEachNodeChannel(nodePub route.Vertex, + ForEachNodeChannel(ctx context.Context, nodePub route.Vertex, cb func(channel *graphdb.DirectedChannel) error) error // FetchNodeFeatures returns the features of the given node. 
- FetchNodeFeatures(nodePub route.Vertex) (*lnwire.FeatureVector, error) + FetchNodeFeatures(ctx context.Context, nodePub route.Vertex) ( + *lnwire.FeatureVector, error) } // GraphSessionFactory can be used to produce a new Graph instance which can @@ -30,13 +32,14 @@ type GraphSessionFactory interface { // session. It returns the Graph along with a call-back that must be // called once Graph access is complete. This call-back will close any // read-only transaction that was created at Graph construction time. - NewGraphSession() (Graph, func() error, error) + NewGraphSession(ctx context.Context) (Graph, func() error, error) } // FetchAmountPairCapacity determines the maximal public capacity between two // nodes depending on the amount we try to send. -func FetchAmountPairCapacity(graph Graph, source, nodeFrom, nodeTo route.Vertex, - amount lnwire.MilliSatoshi) (btcutil.Amount, error) { +func FetchAmountPairCapacity(ctx context.Context, graph Graph, source, nodeFrom, + nodeTo route.Vertex, amount lnwire.MilliSatoshi) (btcutil.Amount, + error) { // Create unified edges for all incoming connections. // @@ -44,7 +47,7 @@ func FetchAmountPairCapacity(graph Graph, source, nodeFrom, nodeTo route.Vertex, // by a deprecated router rpc. u := newNodeEdgeUnifier(source, nodeTo, false, nil) - err := u.addGraphPolicies(graph) + err := u.addGraphPolicies(ctx, graph) if err != nil { return 0, err } diff --git a/routing/integrated_routing_context_test.go b/routing/integrated_routing_context_test.go index 315b0dff22..7bb1677b4c 100644 --- a/routing/integrated_routing_context_test.go +++ b/routing/integrated_routing_context_test.go @@ -1,6 +1,7 @@ package routing import ( + "context" "fmt" "math" "os" @@ -122,6 +123,8 @@ func (h htlcAttempt) String() string { func (c *integratedRoutingContext) testPayment(maxParts uint32, destFeatureBits ...lnwire.FeatureBit) ([]htlcAttempt, error) { + ctx := context.Background() + // We start out with the base set of MPP feature bits. 
If the caller // overrides this set of bits, then we'll use their feature bits // entirely. @@ -173,7 +176,9 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32, ) require.NoError(c.t, err) - getBandwidthHints := func(_ Graph) (bandwidthHints, error) { + getBandwidthHints := func(_ context.Context, _ Graph) (bandwidthHints, + error) { + // Create bandwidth hints based on local channel balances. bandwidthHints := map[uint64]lnwire.MilliSatoshi{} for _, ch := range c.graph.nodes[c.source.pubkey].channels { @@ -235,8 +240,8 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32, // Find a route. route, err := session.RequestRoute( - amtRemaining, lnwire.MaxMilliSatoshi, inFlightHtlcs, 0, - lnwire.CustomRecords{ + ctx, amtRemaining, lnwire.MaxMilliSatoshi, + inFlightHtlcs, 0, lnwire.CustomRecords{ lnwire.MinCustomRecordsTlvType: []byte{1, 2, 3}, }, ) @@ -326,8 +331,8 @@ func newMockGraphSessionFactory(graph Graph) GraphSessionFactory { return &mockGraphSessionFactory{Graph: graph} } -func (m *mockGraphSessionFactory) NewGraphSession() (Graph, func() error, - error) { +func (m *mockGraphSessionFactory) NewGraphSession(_ context.Context) (Graph, + func() error, error) { return m, func() error { return nil @@ -349,8 +354,8 @@ func newMockGraphSessionFactoryFromChanDB( } } -func (g *mockGraphSessionFactoryChanDB) NewGraphSession() (Graph, func() error, - error) { +func (g *mockGraphSessionFactoryChanDB) NewGraphSession(_ context.Context) ( + Graph, func() error, error) { tx, err := g.graph.NewPathFindTx() if err != nil { @@ -391,14 +396,15 @@ func (g *mockGraphSessionChanDB) close() error { return nil } -func (g *mockGraphSessionChanDB) ForEachNodeChannel(nodePub route.Vertex, +func (g *mockGraphSessionChanDB) ForEachNodeChannel(_ context.Context, + nodePub route.Vertex, cb func(channel *graphdb.DirectedChannel) error) error { return g.graph.ForEachNodeDirectedChannel(g.tx, nodePub, cb) } -func (g *mockGraphSessionChanDB) FetchNodeFeatures(nodePub 
route.Vertex) ( - *lnwire.FeatureVector, error) { +func (g *mockGraphSessionChanDB) FetchNodeFeatures(_ context.Context, + nodePub route.Vertex) (*lnwire.FeatureVector, error) { - return g.graph.FetchNodeFeatures(nodePub) + return g.graph.FetchNodeFeatures(g.tx, nodePub) } diff --git a/routing/mock_graph_test.go b/routing/mock_graph_test.go index cab7c97266..74fc483338 100644 --- a/routing/mock_graph_test.go +++ b/routing/mock_graph_test.go @@ -2,6 +2,7 @@ package routing import ( "bytes" + "context" "fmt" "testing" @@ -165,7 +166,7 @@ func (m *mockGraph) addChannel(id uint64, node1id, node2id byte, // forEachNodeChannel calls the callback for every channel of the given node. // // NOTE: Part of the Graph interface. -func (m *mockGraph) ForEachNodeChannel(nodePub route.Vertex, +func (m *mockGraph) ForEachNodeChannel(_ context.Context, nodePub route.Vertex, cb func(channel *graphdb.DirectedChannel) error) error { // Look up the mock node. @@ -221,7 +222,7 @@ func (m *mockGraph) sourceNode() route.Vertex { // fetchNodeFeatures returns the features of the given node. // // NOTE: Part of the Graph interface. 
-func (m *mockGraph) FetchNodeFeatures(nodePub route.Vertex) ( +func (m *mockGraph) FetchNodeFeatures(_ context.Context, _ route.Vertex) ( *lnwire.FeatureVector, error) { return lnwire.EmptyFeatureVector(), nil diff --git a/routing/mock_test.go b/routing/mock_test.go index 3cdb5ebaf2..b6d24d83a4 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -1,6 +1,7 @@ package routing import ( + "context" "fmt" "sync" @@ -168,9 +169,9 @@ type mockPaymentSessionOld struct { var _ PaymentSession = (*mockPaymentSessionOld)(nil) -func (m *mockPaymentSessionOld) RequestRoute(_, _ lnwire.MilliSatoshi, - _, height uint32, _ lnwire.CustomRecords) (*route.Route, - error) { +func (m *mockPaymentSessionOld) RequestRoute(_ context.Context, _, + _ lnwire.MilliSatoshi, _, height uint32, + _ lnwire.CustomRecords) (*route.Route, error) { if m.release != nil { m.release <- struct{}{} @@ -694,12 +695,13 @@ type mockPaymentSession struct { var _ PaymentSession = (*mockPaymentSession)(nil) -func (m *mockPaymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, - activeShards, height uint32, +func (m *mockPaymentSession) RequestRoute(ctx context.Context, maxAmt, + feeLimit lnwire.MilliSatoshi, activeShards, height uint32, firstHopCustomRecords lnwire.CustomRecords) (*route.Route, error) { args := m.Called( - maxAmt, feeLimit, activeShards, height, firstHopCustomRecords, + ctx, maxAmt, feeLimit, activeShards, height, + firstHopCustomRecords, ) // Type assertion on nil will fail, so we check and return here. diff --git a/routing/pathfind.go b/routing/pathfind.go index db474e1e80..160752e223 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -3,6 +3,7 @@ package routing import ( "bytes" "container/heap" + "context" "errors" "fmt" "math" @@ -50,7 +51,7 @@ const ( ) // pathFinder defines the interface of a path finding algorithm. 
-type pathFinder = func(g *graphParams, r *RestrictParams, +type pathFinder = func(ctx context.Context, g *graphParams, r *RestrictParams, cfg *PathFindingConfig, self, source, target route.Vertex, amt lnwire.MilliSatoshi, timePref float64, finalHtlcExpiry int32) ( []*unifiedEdge, float64, error) @@ -491,8 +492,8 @@ type PathFindingConfig struct { // getOutgoingBalance returns the maximum available balance in any of the // channels of the given node. The second return parameters is the total // available balance. -func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{}, - bandwidthHints bandwidthHints, +func getOutgoingBalance(ctx context.Context, node route.Vertex, + outgoingChans map[uint64]struct{}, bandwidthHints bandwidthHints, g Graph) (lnwire.MilliSatoshi, lnwire.MilliSatoshi, error) { var max, total lnwire.MilliSatoshi @@ -540,7 +541,7 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{}, } // Iterate over all channels of the to node. - err := g.ForEachNodeChannel(node, cb) + err := g.ForEachNodeChannel(ctx, node, cb) if err != nil { return 0, 0, err } @@ -558,10 +559,10 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{}, // source. This is to properly accumulate fees that need to be paid along the // path and accurately check the amount to forward at every node against the // available bandwidth. -func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, - self, source, target route.Vertex, amt lnwire.MilliSatoshi, - timePref float64, finalHtlcExpiry int32) ([]*unifiedEdge, float64, - error) { +func findPath(ctx context.Context, g *graphParams, r *RestrictParams, + cfg *PathFindingConfig, self, source, target route.Vertex, + amt lnwire.MilliSatoshi, timePref float64, finalHtlcExpiry int32) ( + []*unifiedEdge, float64, error) { // Pathfinding can be a significant portion of the total payment // latency, especially on low-powered devices. 
Log several metrics to @@ -580,7 +581,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, features := r.DestFeatures if features == nil { var err error - features, err = g.graph.FetchNodeFeatures(target) + features, err = g.graph.FetchNodeFeatures(ctx, target) if err != nil { return nil, 0, err } @@ -624,7 +625,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // balance available. if source == self { max, total, err := getOutgoingBalance( - self, outgoingChanMap, g.bandwidthHints, g.graph, + ctx, self, outgoingChanMap, g.bandwidthHints, g.graph, ) if err != nil { return nil, 0, err @@ -968,7 +969,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, } // Fetch node features fresh from the graph. - fromFeatures, err := g.graph.FetchNodeFeatures(node) + fromFeatures, err := g.graph.FetchNodeFeatures(ctx, node) if err != nil { return nil, err } @@ -1008,7 +1009,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, self, pivot, !isExitHop, outgoingChanMap, ) - err := u.addGraphPolicies(g.graph) + err := u.addGraphPolicies(ctx, g.graph) if err != nil { return nil, 0, err } @@ -1183,7 +1184,7 @@ type blindedHop struct { // _and_ the introduction node for the path has more than one public channel. // Any filtering of paths based on payment value or success probabilities is // left to the caller. -func findBlindedPaths(g Graph, target route.Vertex, +func findBlindedPaths(ctx context.Context, g Graph, target route.Vertex, restrictions *blindedPathRestrictions) ([][]blindedHop, error) { // Sanity check the restrictions. 
@@ -1202,7 +1203,7 @@ func findBlindedPaths(g Graph, target route.Vertex, return true, nil } - features, err := g.FetchNodeFeatures(node) + features, err := g.FetchNodeFeatures(ctx, node) if err != nil { return false, err } @@ -1216,7 +1217,7 @@ func findBlindedPaths(g Graph, target route.Vertex, // a node that doesn't have any other edges - in that final case, the // whole path should be ignored. paths, _, err := processNodeForBlindedPath( - g, target, supportsRouteBlinding, nil, restrictions, + ctx, g, target, supportsRouteBlinding, nil, restrictions, ) if err != nil { return nil, err @@ -1251,7 +1252,7 @@ func findBlindedPaths(g Graph, target route.Vertex, // processNodeForBlindedPath is a recursive function that traverses the graph // in a depth first manner searching for a set of blinded paths to the given // node. -func processNodeForBlindedPath(g Graph, node route.Vertex, +func processNodeForBlindedPath(ctx context.Context, g Graph, node route.Vertex, supportsRouteBlinding func(vertex route.Vertex) (bool, error), alreadyVisited map[route.Vertex]bool, restrictions *blindedPathRestrictions) ([][]blindedHop, bool, error) { @@ -1298,7 +1299,7 @@ func processNodeForBlindedPath(g Graph, node route.Vertex, // Now, iterate over the node's channels in search for paths to this // node that can be used for blinded paths - err = g.ForEachNodeChannel(node, + err = g.ForEachNodeChannel(ctx, node, func(channel *graphdb.DirectedChannel) error { // Keep track of how many incoming channels this node // has. We only use a node as an introduction node if it @@ -1308,8 +1309,8 @@ func processNodeForBlindedPath(g Graph, node route.Vertex, // Process each channel peer to gather any paths that // lead to the peer. 
nextPaths, hasMoreChans, err := processNodeForBlindedPath( //nolint:lll - g, channel.OtherNode, supportsRouteBlinding, - visited, restrictions, + ctx, g, channel.OtherNode, + supportsRouteBlinding, visited, restrictions, ) if err != nil { return err diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index 81708d3930..92b36e63e1 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -2,6 +2,7 @@ package routing import ( "bytes" + "context" "crypto/sha256" "encoding/hex" "encoding/json" @@ -2221,7 +2222,7 @@ func TestPathFindSpecExample(t *testing.T) { ) require.NoError(t, err, "invalid route request") - route, _, err := ctx.router.FindRoute(req) + route, _, err := ctx.router.FindRoute(context.Background(), req) require.NoError(t, err, "unable to find route") // Now we'll examine the route returned for correctness. @@ -2248,7 +2249,7 @@ func TestPathFindSpecExample(t *testing.T) { ) require.NoError(t, err, "invalid route request") - route, _, err = ctx.router.FindRoute(req) + route, _, err = ctx.router.FindRoute(context.Background(), req) require.NoError(t, err, "unable to find routes") // The route should be two hops. 
@@ -3112,6 +3113,8 @@ func dbFindPath(graph *graphdb.ChannelGraph, source, target route.Vertex, amt lnwire.MilliSatoshi, timePref float64, finalHtlcExpiry int32) ([]*unifiedEdge, error) { + ctx := context.Background() + sourceNode, err := graph.SourceNode() if err != nil { return nil, err @@ -3119,7 +3122,7 @@ func dbFindPath(graph *graphdb.ChannelGraph, graphSessFactory := newMockGraphSessionFactoryFromChanDB(graph) - graphSess, closeGraphSess, err := graphSessFactory.NewGraphSession() + graphSess, closeGraphSess, err := graphSessFactory.NewGraphSession(ctx) if err != nil { return nil, err } @@ -3131,7 +3134,7 @@ func dbFindPath(graph *graphdb.ChannelGraph, }() route, _, err := findPath( - &graphParams{ + ctx, &graphParams{ additionalEdges: additionalEdges, bandwidthHints: bandwidthHints, graph: graphSess, @@ -3154,8 +3157,8 @@ func dbFindBlindedPaths(graph *graphdb.ChannelGraph, } return findBlindedPaths( - newMockGraphSessionChanDB(graph), sourceNode.PubKeyBytes, - restrictions, + context.Background(), newMockGraphSessionChanDB(graph), + sourceNode.PubKeyBytes, restrictions, ) } diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 267ce3965d..aa5de10440 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -259,7 +259,7 @@ lifecycle: } // Now request a route to be used to create our HTLC attempt. - rt, err := p.requestRoute(ps) + rt, err := p.requestRoute(ctx, ps) if err != nil { return exitWithErr(err) } @@ -366,14 +366,14 @@ func (p *paymentLifecycle) checkContext(ctx context.Context) error { // requestRoute is responsible for finding a route to be used to create an HTLC // attempt. -func (p *paymentLifecycle) requestRoute( +func (p *paymentLifecycle) requestRoute(ctx context.Context, ps *channeldb.MPPaymentState) (*route.Route, error) { remainingFees := p.calcFeeBudget(ps.FeesPaid) // Query our payment session to construct a route. 
rt, err := p.paySession.RequestRoute( - ps.RemainingAmt, remainingFees, + ctx, ps.RemainingAmt, remainingFees, uint32(ps.NumAttemptsInFlight), uint32(p.currentHeight), p.firstHopCustomRecords, ) diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go index 315c1bad58..4b77213441 100644 --- a/routing/payment_lifecycle_test.go +++ b/routing/payment_lifecycle_test.go @@ -380,10 +380,10 @@ func TestRequestRouteSucceed(t *testing.T) { // Mock the paySession's `RequestRoute` method to return no error. paySession.On("RequestRoute", mock.Anything, mock.Anything, mock.Anything, mock.Anything, - mock.Anything, + mock.Anything, mock.Anything, ).Return(dummyRoute, nil) - result, err := p.requestRoute(ps) + result, err := p.requestRoute(context.Background(), ps) require.NoError(t, err, "expect no error") require.Equal(t, dummyRoute, result, "returned route not matched") @@ -417,10 +417,10 @@ func TestRequestRouteHandleCriticalErr(t *testing.T) { // Mock the paySession's `RequestRoute` method to return an error. paySession.On("RequestRoute", mock.Anything, mock.Anything, mock.Anything, mock.Anything, - mock.Anything, + mock.Anything, mock.Anything, ).Return(nil, errDummy) - result, err := p.requestRoute(ps) + result, err := p.requestRoute(context.Background(), ps) // Expect an error is returned since it's critical. require.ErrorIs(t, err, errDummy, "error not matched") @@ -452,7 +452,7 @@ func TestRequestRouteHandleNoRouteErr(t *testing.T) { // type. m.paySession.On("RequestRoute", mock.Anything, mock.Anything, mock.Anything, mock.Anything, - mock.Anything, + mock.Anything, mock.Anything, ).Return(nil, errNoTlvPayload) // The payment should be failed with reason no route. 
@@ -460,7 +460,7 @@ func TestRequestRouteHandleNoRouteErr(t *testing.T) { p.identifier, channeldb.FailureReasonNoRoute, ).Return(nil).Once() - result, err := p.requestRoute(ps) + result, err := p.requestRoute(context.Background(), ps) // Expect no error is returned since it's not critical. require.NoError(t, err, "expected no error") @@ -500,10 +500,10 @@ func TestRequestRouteFailPaymentError(t *testing.T) { // Mock the paySession's `RequestRoute` method to return an error. paySession.On("RequestRoute", mock.Anything, mock.Anything, mock.Anything, mock.Anything, - mock.Anything, + mock.Anything, mock.Anything, ).Return(nil, errNoTlvPayload) - result, err := p.requestRoute(ps) + result, err := p.requestRoute(context.Background(), ps) // Expect an error is returned. require.ErrorIs(t, err, errDummy, "error not matched") @@ -876,7 +876,8 @@ func TestResumePaymentFailOnRequestRouteErr(t *testing.T) { // 4. mock requestRoute to return an error. m.paySession.On("RequestRoute", - paymentAmt, p.feeLimit, uint32(ps.NumAttemptsInFlight), + mock.Anything, paymentAmt, p.feeLimit, + uint32(ps.NumAttemptsInFlight), uint32(p.currentHeight), mock.Anything, ).Return(nil, errDummy).Once() @@ -922,7 +923,8 @@ func TestResumePaymentFailOnRegisterAttemptErr(t *testing.T) { // 4. mock requestRoute to return an route. m.paySession.On("RequestRoute", - paymentAmt, p.feeLimit, uint32(ps.NumAttemptsInFlight), + mock.Anything, paymentAmt, p.feeLimit, + uint32(ps.NumAttemptsInFlight), uint32(p.currentHeight), mock.Anything, ).Return(rt, nil).Once() @@ -982,7 +984,8 @@ func TestResumePaymentFailOnSendAttemptErr(t *testing.T) { // 4. mock requestRoute to return an route. m.paySession.On("RequestRoute", - paymentAmt, p.feeLimit, uint32(ps.NumAttemptsInFlight), + mock.Anything, paymentAmt, p.feeLimit, + uint32(ps.NumAttemptsInFlight), uint32(p.currentHeight), mock.Anything, ).Return(rt, nil).Once() @@ -1074,7 +1077,8 @@ func TestResumePaymentSuccess(t *testing.T) { // 1.4. 
mock requestRoute to return an route. m.paySession.On("RequestRoute", - paymentAmt, p.feeLimit, uint32(ps.NumAttemptsInFlight), + mock.Anything, paymentAmt, p.feeLimit, + uint32(ps.NumAttemptsInFlight), uint32(p.currentHeight), mock.Anything, ).Return(rt, nil).Once() @@ -1175,7 +1179,8 @@ func TestResumePaymentSuccessWithTwoAttempts(t *testing.T) { // 1.4. mock requestRoute to return an route. m.paySession.On("RequestRoute", - paymentAmt, p.feeLimit, uint32(ps.NumAttemptsInFlight), + mock.Anything, paymentAmt, p.feeLimit, + uint32(ps.NumAttemptsInFlight), uint32(p.currentHeight), mock.Anything, ).Return(rt, nil).Once() @@ -1237,7 +1242,8 @@ func TestResumePaymentSuccessWithTwoAttempts(t *testing.T) { // 2.4. mock requestRoute to return an route. m.paySession.On("RequestRoute", - paymentAmt/2, p.feeLimit, uint32(ps.NumAttemptsInFlight), + mock.Anything, paymentAmt/2, p.feeLimit, + uint32(ps.NumAttemptsInFlight), uint32(p.currentHeight), mock.Anything, ).Return(rt, nil).Once() diff --git a/routing/payment_session.go b/routing/payment_session.go index 0afdf822fb..9c962a88df 100644 --- a/routing/payment_session.go +++ b/routing/payment_session.go @@ -1,6 +1,7 @@ package routing import ( + "context" "fmt" "github.com/btcsuite/btcd/btcec/v2" @@ -137,7 +138,7 @@ type PaymentSession interface { // // A noRouteError is returned if a non-critical error is encountered // during path finding. - RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, + RequestRoute(ctx context.Context, maxAmt, feeLimit lnwire.MilliSatoshi, activeShards, height uint32, firstHopCustomRecords lnwire.CustomRecords) (*route.Route, error) @@ -169,7 +170,7 @@ type paymentSession struct { additionalEdges map[route.Vertex][]AdditionalEdge - getBandwidthHints func(Graph) (bandwidthHints, error) + getBandwidthHints func(context.Context, Graph) (bandwidthHints, error) payment *LightningPayment @@ -197,7 +198,7 @@ type paymentSession struct { // newPaymentSession instantiates a new payment session. 
func newPaymentSession(p *LightningPayment, selfNode route.Vertex, - getBandwidthHints func(Graph) (bandwidthHints, error), + getBandwidthHints func(context.Context, Graph) (bandwidthHints, error), graphSessFactory GraphSessionFactory, missionControl MissionControlQuerier, pathFindingConfig PathFindingConfig) (*paymentSession, error) { @@ -244,8 +245,8 @@ func newPaymentSession(p *LightningPayment, selfNode route.Vertex, // // NOTE: This function is safe for concurrent access. // NOTE: Part of the PaymentSession interface. -func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, - activeShards, height uint32, +func (p *paymentSession) RequestRoute(ctx context.Context, maxAmt, + feeLimit lnwire.MilliSatoshi, activeShards, height uint32, firstHopCustomRecords lnwire.CustomRecords) (*route.Route, error) { if p.empty { @@ -297,7 +298,9 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, for { // Get a routing graph session. - graph, closeGraph, err := p.graphSessFactory.NewGraphSession() + graph, closeGraph, err := p.graphSessFactory.NewGraphSession( + ctx, + ) if err != nil { return nil, err } @@ -308,7 +311,7 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, // don't have enough bandwidth to carry the payment. New // bandwidth hints are queried for every new path finding // attempt, because concurrent payments may change balances. - bandwidthHints, err := p.getBandwidthHints(graph) + bandwidthHints, err := p.getBandwidthHints(ctx, graph) if err != nil { // Close routing graph session. if graphErr := closeGraph(); graphErr != nil { @@ -323,7 +326,7 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, // Find a route for the current amount. 
path, _, err := p.pathFinder( - &graphParams{ + ctx, &graphParams{ additionalEdges: p.additionalEdges, bandwidthHints: bandwidthHints, graph: graph, diff --git a/routing/payment_session_source.go b/routing/payment_session_source.go index d5f1a6af41..69f4d0bea4 100644 --- a/routing/payment_session_source.go +++ b/routing/payment_session_source.go @@ -1,6 +1,8 @@ package routing import ( + "context" + "github.com/btcsuite/btcd/btcec/v2" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/graph/db/models" @@ -54,9 +56,11 @@ func (m *SessionSource) NewPaymentSession(p *LightningPayment, firstHopBlob fn.Option[tlv.Blob], trafficShaper fn.Option[TlvTrafficShaper]) (PaymentSession, error) { - getBandwidthHints := func(graph Graph) (bandwidthHints, error) { + getBandwidthHints := func(ctx context.Context, + graph Graph) (bandwidthHints, error) { + return newBandwidthManager( - graph, m.SourceNode.PubKeyBytes, m.GetLink, + ctx, graph, m.SourceNode.PubKeyBytes, m.GetLink, firstHopBlob, trafficShaper, ) } diff --git a/routing/payment_session_test.go b/routing/payment_session_test.go index 278e090440..51cfadb0de 100644 --- a/routing/payment_session_test.go +++ b/routing/payment_session_test.go @@ -1,6 +1,7 @@ package routing import ( + "context" "testing" "time" @@ -115,7 +116,7 @@ func TestUpdateAdditionalEdge(t *testing.T) { // Create the paymentsession. 
session, err := newPaymentSession( payment, route.Vertex{}, - func(Graph) (bandwidthHints, error) { + func(context.Context, Graph) (bandwidthHints, error) { return &mockBandwidthHints{}, nil }, newMockGraphSessionFactory(&sessionGraph{}), @@ -193,7 +194,7 @@ func TestRequestRoute(t *testing.T) { session, err := newPaymentSession( payment, route.Vertex{}, - func(Graph) (bandwidthHints, error) { + func(context.Context, Graph) (bandwidthHints, error) { return &mockBandwidthHints{}, nil }, newMockGraphSessionFactory(&sessionGraph{}), @@ -205,7 +206,8 @@ func TestRequestRoute(t *testing.T) { } // Override pathfinder with a mock. - session.pathFinder = func(_ *graphParams, r *RestrictParams, + session.pathFinder = func(_ context.Context, _ *graphParams, + r *RestrictParams, _ *PathFindingConfig, _, _, _ route.Vertex, _ lnwire.MilliSatoshi, _ float64, _ int32) ([]*unifiedEdge, float64, error) { @@ -233,8 +235,8 @@ func TestRequestRoute(t *testing.T) { } route, err := session.RequestRoute( - payment.Amount, payment.FeeLimit, 0, height, - lnwire.CustomRecords{ + context.Background(), payment.Amount, payment.FeeLimit, 0, + height, lnwire.CustomRecords{ lnwire.MinCustomRecordsTlvType + 123: []byte{1, 2, 3}, }, ) diff --git a/routing/router.go b/routing/router.go index b92aa15023..8eedfaf16d 100644 --- a/routing/router.go +++ b/routing/router.go @@ -515,8 +515,8 @@ func getTargetNode(target *route.Vertex, // FindRoute attempts to query the ChannelRouter for the optimum path to a // particular target destination to which it is able to send `amt` after // factoring in channel capacities and cumulative fees along the route. 
-func (r *ChannelRouter) FindRoute(req *RouteRequest) (*route.Route, float64, - error) { +func (r *ChannelRouter) FindRoute(ctx context.Context, req *RouteRequest) ( + *route.Route, float64, error) { log.Debugf("Searching for path to %v, sending %v", req.Target, req.Amount) @@ -524,7 +524,7 @@ func (r *ChannelRouter) FindRoute(req *RouteRequest) (*route.Route, float64, // We'll attempt to obtain a set of bandwidth hints that can help us // eliminate certain routes early on in the path finding process. bandwidthHints, err := newBandwidthManager( - r.cfg.RoutingGraph, r.cfg.SelfNode, r.cfg.GetLink, + ctx, r.cfg.RoutingGraph, r.cfg.SelfNode, r.cfg.GetLink, fn.None[tlv.Blob](), r.cfg.TrafficShaper, ) if err != nil { @@ -549,7 +549,7 @@ func (r *ChannelRouter) FindRoute(req *RouteRequest) (*route.Route, float64, } path, probability, err := findPath( - &graphParams{ + ctx, &graphParams{ additionalEdges: req.RouteHints, bandwidthHints: bandwidthHints, graph: r.cfg.RoutingGraph, @@ -616,14 +616,15 @@ type BlindedPathRestrictions struct { // FindBlindedPaths finds a selection of paths to the destination node that can // be used in blinded payment paths. -func (r *ChannelRouter) FindBlindedPaths(destination route.Vertex, - amt lnwire.MilliSatoshi, probabilitySrc probabilitySource, +func (r *ChannelRouter) FindBlindedPaths(ctx context.Context, + destination route.Vertex, amt lnwire.MilliSatoshi, + probabilitySrc probabilitySource, restrictions *BlindedPathRestrictions) ([]*route.Route, error) { // First, find a set of candidate paths given the destination node and // path length restrictions. 
paths, err := findBlindedPaths( - r.cfg.RoutingGraph, destination, &blindedPathRestrictions{ + ctx, r.cfg.RoutingGraph, destination, &blindedPathRestrictions{ minNumHops: restrictions.MinDistanceFromIntroNode, maxNumHops: restrictions.NumHops, nodeOmissionSet: restrictions.NodeOmissionSet, @@ -1366,7 +1367,8 @@ func (e ErrNoChannel) Error() string { // BuildRoute returns a fully specified route based on a list of pubkeys. If // amount is nil, the minimum routable amount is used. To force a specific // outgoing channel, use the outgoingChan parameter. -func (r *ChannelRouter) BuildRoute(amt fn.Option[lnwire.MilliSatoshi], +func (r *ChannelRouter) BuildRoute(ctx context.Context, + amt fn.Option[lnwire.MilliSatoshi], hops []route.Vertex, outgoingChan *uint64, finalCltvDelta int32, payAddr fn.Option[[32]byte], firstHopBlob fn.Option[[]byte]) ( *route.Route, error) { @@ -1383,8 +1385,8 @@ func (r *ChannelRouter) BuildRoute(amt fn.Option[lnwire.MilliSatoshi], // We'll attempt to obtain a set of bandwidth hints that helps us select // the best outgoing channel to use in case no outgoing channel is set. bandwidthHints, err := newBandwidthManager( - r.cfg.RoutingGraph, r.cfg.SelfNode, r.cfg.GetLink, firstHopBlob, - r.cfg.TrafficShaper, + ctx, r.cfg.RoutingGraph, r.cfg.SelfNode, + r.cfg.GetLink, firstHopBlob, r.cfg.TrafficShaper, ) if err != nil { return nil, err @@ -1395,7 +1397,7 @@ func (r *ChannelRouter) BuildRoute(amt fn.Option[lnwire.MilliSatoshi], // We check that each node in the route has a connection to others that // can forward in principle. unifiers, err := getEdgeUnifiers( - r.cfg.SelfNode, hops, outgoingChans, r.cfg.RoutingGraph, + ctx, r.cfg.SelfNode, hops, outgoingChans, r.cfg.RoutingGraph, ) if err != nil { return nil, err @@ -1652,8 +1654,8 @@ func (r *ChannelRouter) failStaleAttempt(a channeldb.HTLCAttempt, } // getEdgeUnifiers returns a list of edge unifiers for the given route. 
-func getEdgeUnifiers(source route.Vertex, hops []route.Vertex, - outgoingChans map[uint64]struct{}, +func getEdgeUnifiers(ctx context.Context, source route.Vertex, + hops []route.Vertex, outgoingChans map[uint64]struct{}, graph Graph) ([]*edgeUnifier, error) { // Allocate a list that will contain the edge unifiers for this route. @@ -1678,7 +1680,7 @@ func getEdgeUnifiers(source route.Vertex, hops []route.Vertex, source, toNode, !isExitHop, outgoingChans, ) - err := u.addGraphPolicies(graph) + err := u.addGraphPolicies(ctx, graph) if err != nil { return nil, err } diff --git a/routing/router_test.go b/routing/router_test.go index db72bf266c..4de54af1b0 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -2,6 +2,7 @@ package routing import ( "bytes" + "context" "fmt" "image/color" "math" @@ -271,7 +272,7 @@ func TestFindRoutesWithFeeLimit(t *testing.T) { ) require.NoError(t, err, "invalid route request") - route, _, err := ctx.router.FindRoute(req) + route, _, err := ctx.router.FindRoute(context.Background(), req) require.NoError(t, err, "unable to find any routes") require.Falsef(t, @@ -1530,6 +1531,8 @@ func TestSendToRouteMaxHops(t *testing.T) { // TestBuildRoute tests whether correct routes are built. func TestBuildRoute(t *testing.T) { + ctx := context.Background() + // Setup a three node network. chanCapSat := btcutil.Amount(100000) paymentAddrFeatures := lnwire.NewFeatureVector( @@ -1638,7 +1641,9 @@ func TestBuildRoute(t *testing.T) { const startingBlockHeight = 101 - ctx := createTestCtxFromGraphInstance(t, startingBlockHeight, testGraph) + tctx := createTestCtxFromGraphInstance( + t, startingBlockHeight, testGraph, + ) checkHops := func(rt *route.Route, expected []uint64, payAddr [32]byte) { @@ -1664,27 +1669,28 @@ func TestBuildRoute(t *testing.T) { // Test that we can't build a route when no hops are given. 
hops = []route.Vertex{} - _, err = ctx.router.BuildRoute( - noAmt, hops, nil, 40, fn.None[[32]byte](), fn.None[[]byte](), + _, err = tctx.router.BuildRoute( + ctx, noAmt, hops, nil, 40, fn.None[[32]byte](), + fn.None[[]byte](), ) require.Error(t, err) // Create hop list for an unknown destination. - hops := []route.Vertex{ctx.aliases["b"], ctx.aliases["y"]} - _, err = ctx.router.BuildRoute( - noAmt, hops, nil, 40, fn.Some(payAddr), fn.None[[]byte](), + hops := []route.Vertex{tctx.aliases["b"], tctx.aliases["y"]} + _, err = tctx.router.BuildRoute( + ctx, noAmt, hops, nil, 40, fn.Some(payAddr), fn.None[[]byte](), ) noChanErr := ErrNoChannel{} require.ErrorAs(t, err, &noChanErr) require.Equal(t, 1, noChanErr.position) // Create hop list from the route node pubkeys. - hops = []route.Vertex{ctx.aliases["b"], ctx.aliases["c"]} + hops = []route.Vertex{tctx.aliases["b"], tctx.aliases["c"]} amt := lnwire.NewMSatFromSatoshis(100) // Build the route for the given amount. - rt, err := ctx.router.BuildRoute( - fn.Some(amt), hops, nil, 40, fn.Some(payAddr), + rt, err := tctx.router.BuildRoute( + ctx, fn.Some(amt), hops, nil, 40, fn.Some(payAddr), fn.None[[]byte](), ) require.NoError(t, err) @@ -1696,8 +1702,8 @@ func TestBuildRoute(t *testing.T) { require.Equal(t, lnwire.MilliSatoshi(106000), rt.TotalAmount) // Build the route for the minimum amount. - rt, err = ctx.router.BuildRoute( - noAmt, hops, nil, 40, fn.Some(payAddr), fn.None[[]byte](), + rt, err = tctx.router.BuildRoute( + ctx, noAmt, hops, nil, 40, fn.Some(payAddr), fn.None[[]byte](), ) require.NoError(t, err) @@ -1713,9 +1719,10 @@ func TestBuildRoute(t *testing.T) { // Test a route that contains incompatible channel htlc constraints. // There is no amount that can pass through both channel 5 and 4. 
- hops = []route.Vertex{ctx.aliases["e"], ctx.aliases["c"]} - _, err = ctx.router.BuildRoute( - noAmt, hops, nil, 40, fn.None[[32]byte](), fn.None[[]byte](), + hops = []route.Vertex{tctx.aliases["e"], tctx.aliases["c"]} + _, err = tctx.router.BuildRoute( + ctx, noAmt, hops, nil, 40, fn.None[[32]byte](), + fn.None[[]byte](), ) require.Error(t, err) noChanErr = ErrNoChannel{} @@ -1733,9 +1740,9 @@ func TestBuildRoute(t *testing.T) { // could me more applicable, which is why we don't get back the highest // amount that could be delivered to the receiver of 21819 msat, using // policy of channel 3. - hops = []route.Vertex{ctx.aliases["b"], ctx.aliases["z"]} - rt, err = ctx.router.BuildRoute( - noAmt, hops, nil, 40, fn.Some(payAddr), fn.None[[]byte](), + hops = []route.Vertex{tctx.aliases["b"], tctx.aliases["z"]} + rt, err = tctx.router.BuildRoute( + ctx, noAmt, hops, nil, 40, fn.Some(payAddr), fn.None[[]byte](), ) require.NoError(t, err) checkHops(rt, []uint64{1, 8}, payAddr) @@ -1746,10 +1753,10 @@ func TestBuildRoute(t *testing.T) { // inbound fees. We expect a similar amount as for the above case of // b->c, but reduced by the inbound discount on the channel a->d. // We get 106000 - 1000 (base in) - 0.001 * 106000 (rate in) = 104894. - hops = []route.Vertex{ctx.aliases["d"], ctx.aliases["f"]} + hops = []route.Vertex{tctx.aliases["d"], tctx.aliases["f"]} amt = lnwire.NewMSatFromSatoshis(100) - rt, err = ctx.router.BuildRoute( - fn.Some(amt), hops, nil, 40, fn.Some(payAddr), + rt, err = tctx.router.BuildRoute( + ctx, fn.Some(amt), hops, nil, 40, fn.Some(payAddr), fn.None[[]byte](), ) require.NoError(t, err) @@ -1764,9 +1771,9 @@ func TestBuildRoute(t *testing.T) { // due to rounding. This would not be compatible with the sender amount // of 20179 msat, which results in underpayment of 1 msat in fee. 
There // is a third pass through newRoute in which this gets corrected to end - hops = []route.Vertex{ctx.aliases["d"], ctx.aliases["f"]} - rt, err = ctx.router.BuildRoute( - noAmt, hops, nil, 40, fn.Some(payAddr), fn.None[[]byte](), + hops = []route.Vertex{tctx.aliases["d"], tctx.aliases["f"]} + rt, err = tctx.router.BuildRoute( + ctx, noAmt, hops, nil, 40, fn.Some(payAddr), fn.None[[]byte](), ) require.NoError(t, err) checkHops(rt, []uint64{9, 10}, payAddr) @@ -2894,7 +2901,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { paymentAmt, 0, noRestrictions, nil, nil, nil, MinCLTVDelta, ) require.NoError(t, err, "invalid route request") - _, _, err = ctx.router.FindRoute(req) + _, _, err = ctx.router.FindRoute(context.Background(), req) require.NoError(t, err, "unable to find any routes") // Now check that we can update the node info for the partial node @@ -2933,7 +2940,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { ) require.NoError(t, err, "invalid route request") - _, _, err = ctx.router.FindRoute(req) + _, _, err = ctx.router.FindRoute(context.Background(), req) require.NoError(t, err, "unable to find any routes") copy1, err := ctx.graph.FetchLightningNode(pub1) @@ -3072,6 +3079,8 @@ func createChannelEdge(bitcoinKey1, bitcoinKey2 []byte, func TestFindBlindedPathsWithMC(t *testing.T) { t.Parallel() + ctx := context.Background() + rbFeatureBits := []lnwire.FeatureBit{ lnwire.RouteBlindingOptional, } @@ -3128,15 +3137,15 @@ func TestFindBlindedPathsWithMC(t *testing.T) { ) require.NoError(t, err) - ctx := createTestCtxFromGraphInstance(t, 101, testGraph) + tctx := createTestCtxFromGraphInstance(t, 101, testGraph) var ( - alice = ctx.aliases["alice"] - bob = ctx.aliases["bob"] - charlie = ctx.aliases["charlie"] - dave = ctx.aliases["dave"] - eve = ctx.aliases["eve"] - frank = ctx.aliases["frank"] + alice = tctx.aliases["alice"] + bob = tctx.aliases["bob"] + charlie = tctx.aliases["charlie"] + dave = tctx.aliases["dave"] + eve = tctx.aliases["eve"] + frank = 
tctx.aliases["frank"] ) // Create a mission control store which initially sets the success @@ -3163,8 +3172,8 @@ func TestFindBlindedPathsWithMC(t *testing.T) { // All the probabilities are set to 1. So if we restrict the path length // to 2 and allow a max of 3 routes, then we expect three paths here. - routes, err := ctx.router.FindBlindedPaths( - dave, 1000, probabilitySrc, &BlindedPathRestrictions{ + routes, err := tctx.router.FindBlindedPaths( + ctx, dave, 1000, probabilitySrc, &BlindedPathRestrictions{ MinDistanceFromIntroNode: 2, NumHops: 2, MaxNumPaths: 3, @@ -3181,12 +3190,12 @@ func TestFindBlindedPathsWithMC(t *testing.T) { var actualPaths []string for _, path := range paths { label := getAliasFromPubKey( - path.SourcePubKey, ctx.aliases, + path.SourcePubKey, tctx.aliases, ) + "," for _, hop := range path.Hops { label += getAliasFromPubKey( - hop.PubKeyBytes, ctx.aliases, + hop.PubKeyBytes, tctx.aliases, ) + "," } @@ -3208,8 +3217,8 @@ func TestFindBlindedPathsWithMC(t *testing.T) { // 3) A -> F -> D missionControl[bob][dave] = 0.5 missionControl[frank][dave] = 0.25 - routes, err = ctx.router.FindBlindedPaths( - dave, 1000, probabilitySrc, &BlindedPathRestrictions{ + routes, err = tctx.router.FindBlindedPaths( + ctx, dave, 1000, probabilitySrc, &BlindedPathRestrictions{ MinDistanceFromIntroNode: 2, NumHops: 2, MaxNumPaths: 3, @@ -3225,8 +3234,8 @@ func TestFindBlindedPathsWithMC(t *testing.T) { // Just to show that the above result was not a fluke, let's change // the C->D link to be the weak one. missionControl[charlie][dave] = 0.125 - routes, err = ctx.router.FindBlindedPaths( - dave, 1000, probabilitySrc, &BlindedPathRestrictions{ + routes, err = tctx.router.FindBlindedPaths( + ctx, dave, 1000, probabilitySrc, &BlindedPathRestrictions{ MinDistanceFromIntroNode: 2, NumHops: 2, MaxNumPaths: 3, @@ -3241,8 +3250,8 @@ func TestFindBlindedPathsWithMC(t *testing.T) { // Change the MaxNumPaths to 1 to assert that only the best route is // returned. 
- routes, err = ctx.router.FindBlindedPaths( - dave, 1000, probabilitySrc, &BlindedPathRestrictions{ + routes, err = tctx.router.FindBlindedPaths( + ctx, dave, 1000, probabilitySrc, &BlindedPathRestrictions{ MinDistanceFromIntroNode: 2, NumHops: 2, MaxNumPaths: 1, @@ -3255,8 +3264,8 @@ func TestFindBlindedPathsWithMC(t *testing.T) { // Test the edge case where Dave, the recipient, is also the // introduction node. - routes, err = ctx.router.FindBlindedPaths( - dave, 1000, probabilitySrc, &BlindedPathRestrictions{ + routes, err = tctx.router.FindBlindedPaths( + ctx, dave, 1000, probabilitySrc, &BlindedPathRestrictions{ MinDistanceFromIntroNode: 0, NumHops: 0, MaxNumPaths: 1, @@ -3270,8 +3279,8 @@ func TestFindBlindedPathsWithMC(t *testing.T) { // Finally, we make one of the routes have a probability less than the // minimum. This means we expect that route not to be chosen. missionControl[charlie][dave] = DefaultMinRouteProbability - routes, err = ctx.router.FindBlindedPaths( - dave, 1000, probabilitySrc, &BlindedPathRestrictions{ + routes, err = tctx.router.FindBlindedPaths( + ctx, dave, 1000, probabilitySrc, &BlindedPathRestrictions{ MinDistanceFromIntroNode: 2, NumHops: 2, MaxNumPaths: 3, @@ -3285,8 +3294,8 @@ func TestFindBlindedPathsWithMC(t *testing.T) { // Test that if the user explicitly indicates that we should ignore // the Frank node during path selection, then this is done. 
- routes, err = ctx.router.FindBlindedPaths( - dave, 1000, probabilitySrc, &BlindedPathRestrictions{ + routes, err = tctx.router.FindBlindedPaths( + ctx, dave, 1000, probabilitySrc, &BlindedPathRestrictions{ MinDistanceFromIntroNode: 2, NumHops: 2, MaxNumPaths: 3, diff --git a/routing/unified_edges.go b/routing/unified_edges.go index c2e008e473..5ad777072e 100644 --- a/routing/unified_edges.go +++ b/routing/unified_edges.go @@ -1,6 +1,7 @@ package routing import ( + "context" "math" "github.com/btcsuite/btcd/btcutil" @@ -94,7 +95,7 @@ func (u *nodeEdgeUnifier) addPolicy(fromNode route.Vertex, // addGraphPolicies adds all policies that are known for the toNode in the // graph. -func (u *nodeEdgeUnifier) addGraphPolicies(g Graph) error { +func (u *nodeEdgeUnifier) addGraphPolicies(ctx context.Context, g Graph) error { cb := func(channel *graphdb.DirectedChannel) error { // If there is no edge policy for this candidate node, skip. // Note that we are searching backwards so this node would have @@ -120,7 +121,7 @@ func (u *nodeEdgeUnifier) addGraphPolicies(g Graph) error { } // Iterate over all channels of the to node. 
- return g.ForEachNodeChannel(u.toNode, cb) + return g.ForEachNodeChannel(ctx, u.toNode, cb) } // unifiedEdge is the individual channel data that is kept inside an edgeUnifier diff --git a/rpcserver.go b/rpcserver.go index d7d2e0186c..d448e23c8d 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -12,7 +12,6 @@ import ( "net/http" "os" "path/filepath" - "runtime" "sort" "strconv" "strings" @@ -51,7 +50,7 @@ import ( "github.com/lightningnetwork/lnd/graph" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/graph/db/models" - "github.com/lightningnetwork/lnd/graph/graphsession" + graphsession "github.com/lightningnetwork/lnd/graph/session" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/input" @@ -695,6 +694,7 @@ func (r *rpcServer) addDeps(s *server, macService *macaroons.Service, return err } graph := s.graphDB + graphSource := s.graphSource routerBackend := &routerrpc.RouterBackend{ SelfNode: selfNode.PubKeyBytes, @@ -707,11 +707,12 @@ func (r *rpcServer) addDeps(s *server, macService *macaroons.Service, } return info.Capacity, nil }, - FetchAmountPairCapacity: func(nodeFrom, nodeTo route.Vertex, + FetchAmountPairCapacity: func(ctx context.Context, nodeFrom, + nodeTo route.Vertex, amount lnwire.MilliSatoshi) (btcutil.Amount, error) { return routing.FetchAmountPairCapacity( - graphsession.NewRoutingGraph(graph), + ctx, graphsession.NewRoutingGraph(graphSource), selfNode.PubKeyBytes, nodeFrom, nodeTo, amount, ) }, @@ -795,7 +796,8 @@ func (r *rpcServer) addDeps(s *server, macService *macaroons.Service, err = subServerCgs.PopulateDependencies( r.cfg, s.cc, r.cfg.networkDir, macService, atpl, invoiceRegistry, s.htlcSwitch, r.cfg.ActiveNetParams.Params, s.chanRouter, - routerBackend, s.nodeSigner, s.graphDB, s.chanStateDB, + routerBackend, s.nodeSigner, s.graphDB, s.graphSource, + s.chanStateDB, s.sweeper, tower, s.towerClientMgr, r.cfg.net.ResolveTCPAddr, 
genInvoiceFeatures, genAmpInvoiceFeatures, s.getNodeAnnouncement, s.updateAndBroadcastSelfNode, parseAddr, @@ -1756,8 +1758,8 @@ func (r *rpcServer) VerifyMessage(ctx context.Context, // channels signed the message. // // TODO(phlip9): Require valid nodes to have capital in active channels. - graph := r.server.graphDB - _, active, err := graph.HasLightningNode(pub) + graph := r.server.graphSource + _, active, err := graph.HasLightningNode(ctx, pub) if err != nil { return nil, fmt.Errorf("failed to query graph: %w", err) } @@ -1821,9 +1823,8 @@ func (r *rpcServer) ConnectPeer(ctx context.Context, timeout) } - if err := r.server.ConnectToPeer( - peerAddr, in.Perm, timeout, - ); err != nil { + err = r.server.ConnectToPeer(ctx, peerAddr, in.Perm, timeout) + if err != nil { rpcsLog.Errorf("[connectpeer]: error connecting to peer: %v", err) return nil, err @@ -4538,7 +4539,7 @@ func (r *rpcServer) ListChannels(ctx context.Context, // our list depending on the type of channels requested to us. isActive := peerOnline && linkActive channel, err := createRPCOpenChannel( - r, dbChannel, isActive, in.PeerAliasLookup, + ctx, r, dbChannel, isActive, in.PeerAliasLookup, ) if err != nil { return nil, err @@ -4654,8 +4655,9 @@ func encodeCustomChanData(lnChan *channeldb.OpenChannel) ([]byte, error) { } // createRPCOpenChannel creates an *lnrpc.Channel from the *channeldb.Channel. -func createRPCOpenChannel(r *rpcServer, dbChannel *channeldb.OpenChannel, - isActive, peerAliasLookup bool) (*lnrpc.Channel, error) { +func createRPCOpenChannel(ctx context.Context, r *rpcServer, + dbChannel *channeldb.OpenChannel, isActive, peerAliasLookup bool) ( + *lnrpc.Channel, error) { nodePub := dbChannel.IdentityPub nodeID := hex.EncodeToString(nodePub.SerializeCompressed()) @@ -4756,7 +4758,7 @@ func createRPCOpenChannel(r *rpcServer, dbChannel *channeldb.OpenChannel, // Look up our channel peer's node alias if the caller requests it. 
if peerAliasLookup { - peerAlias, err := r.server.graphDB.LookupAlias(nodePub) + peerAlias, err := r.server.graphSource.LookupAlias(ctx, nodePub) if err != nil { peerAlias = fmt.Sprintf("unable to lookup "+ "peer alias: %v", err) @@ -5172,7 +5174,8 @@ func (r *rpcServer) SubscribeChannelEvents(req *lnrpc.ChannelEventSubscription, } case channelnotifier.OpenChannelEvent: channel, err := createRPCOpenChannel( - r, event.Channel, true, false, + updateStream.Context(), r, + event.Channel, true, false, ) if err != nil { return err @@ -6095,7 +6098,7 @@ func (r *rpcServer) AddInvoice(ctx context.Context, NodeSigner: r.server.nodeSigner, DefaultCLTVExpiry: defaultDelta, ChanDB: r.server.chanStateDB, - Graph: r.server.graphDB, + Graph: r.server.graphSource, GenInvoiceFeatures: func() *lnwire.FeatureVector { v := r.server.featureMgr.Get(feature.SetInvoice) @@ -6123,11 +6126,11 @@ func (r *rpcServer) AddInvoice(ctx context.Context, }, GetAlias: r.server.aliasMgr.GetPeerAlias, BestHeight: r.server.cc.BestBlockTracker.BestHeight, - QueryBlindedRoutes: func(amt lnwire.MilliSatoshi) ( - []*route.Route, error) { + QueryBlindedRoutes: func(ctx context.Context, + amt lnwire.MilliSatoshi) ([]*route.Route, error) { return r.server.chanRouter.FindBlindedPaths( - r.selfNode, amt, + ctx, r.selfNode, amt, r.server.defaultMC.GetProbability, blindingRestrictions, ) @@ -6525,20 +6528,13 @@ func (r *rpcServer) DescribeGraph(ctx context.Context, } } - // Obtain the pointer to the global singleton channel graph, this will - // provide a consistent view of the graph due to bolt db's - // transactional model. - graph := r.server.graphDB + graph := r.server.graphSource // First iterate through all the known nodes (connected or unconnected // within the graph), collating their current state into the RPC // response. 
- err := graph.ForEachNode(func(_ kvdb.RTx, - node *models.LightningNode) error { - - lnNode := marshalNode(node) - - resp.Nodes = append(resp.Nodes, lnNode) + err := graph.ForEachNode(ctx, func(node *models.LightningNode) error { + resp.Nodes = append(resp.Nodes, marshalNode(node)) return nil }) @@ -6549,19 +6545,18 @@ func (r *rpcServer) DescribeGraph(ctx context.Context, // Next, for each active channel we know of within the graph, create a // similar response which details both the edge information as well as // the routing policies of th nodes connecting the two edges. - err = graph.ForEachChannel(func(edgeInfo *models.ChannelEdgeInfo, + err = graph.ForEachChannel(ctx, func(edgeInfo *models.ChannelEdgeInfo, c1, c2 *models.ChannelEdgePolicy) error { // Do not include unannounced channels unless specifically - // requested. Unannounced channels include both private channels as - // well as public channels whose authentication proof were not - // confirmed yet, hence were not announced. + // requested. Unannounced channels include both private channels + // as well as public channels whose authentication proof were + // not confirmed yet, hence were not announced. if !includeUnannounced && edgeInfo.AuthProof == nil { return nil } - edge := marshalDBEdge(edgeInfo, c1, c2) - resp.Edges = append(resp.Edges, edge) + resp.Edges = append(resp.Edges, marshalDBEdge(edgeInfo, c1, c2)) return nil }) @@ -6706,43 +6701,27 @@ func (r *rpcServer) GetNodeMetrics(ctx context.Context, return nil, nil } - resp := &lnrpc.NodeMetricsResponse{ - BetweennessCentrality: make(map[string]*lnrpc.FloatMetric), - } - - // Obtain the pointer to the global singleton channel graph, this will - // provide a consistent view of the graph due to bolt db's - // transactional model. - graph := r.server.graphDB + graph := r.server.graphSource - // Calculate betweenness centrality if requested. Note that depending on the - // graph size, this may take up to a few minutes. 
- channelGraph := autopilot.ChannelGraphFromDatabase(graph) - centralityMetric, err := autopilot.NewBetweennessCentralityMetric( - runtime.NumCPU(), - ) + // Calculate betweenness centrality if requested. Note that depending on + // the graph size, this may take up to a few minutes. + centrality, err := graph.BetweennessCentrality(ctx) if err != nil { return nil, err } - if err := centralityMetric.Refresh(channelGraph); err != nil { - return nil, err - } - - // Fill normalized and non normalized centrality. - centrality := centralityMetric.GetMetric(true) - for nodeID, val := range centrality { - resp.BetweennessCentrality[hex.EncodeToString(nodeID[:])] = - &lnrpc.FloatMetric{ - NormalizedValue: val, - } - } - centrality = centralityMetric.GetMetric(false) - for nodeID, val := range centrality { - resp.BetweennessCentrality[hex.EncodeToString(nodeID[:])].Value = val + result := make(map[string]*lnrpc.FloatMetric) + for nodeID, betweenness := range centrality { + id := hex.EncodeToString(nodeID[:]) + result[id] = &lnrpc.FloatMetric{ + Value: betweenness.NonNormalized, + NormalizedValue: betweenness.Normalized, + } } - return resp, nil + return &lnrpc.NodeMetricsResponse{ + BetweennessCentrality: result, + }, nil } // GetChanInfo returns the latest authenticated network announcement for the @@ -6750,10 +6729,10 @@ func (r *rpcServer) GetNodeMetrics(ctx context.Context, // uniquely identify the location of transaction's funding output within the // blockchain. The former is an 8-byte integer, while the latter is a string // formatted as funding_txid:output_index. 
-func (r *rpcServer) GetChanInfo(_ context.Context, +func (r *rpcServer) GetChanInfo(ctx context.Context, in *lnrpc.ChanInfoRequest) (*lnrpc.ChannelEdge, error) { - graph := r.server.graphDB + graph := r.server.graphSource var ( edgeInfo *models.ChannelEdgeInfo @@ -6764,7 +6743,7 @@ func (r *rpcServer) GetChanInfo(_ context.Context, switch { case in.ChanId != 0: edgeInfo, edge1, edge2, err = graph.FetchChannelEdgesByID( - in.ChanId, + ctx, in.ChanId, ) case in.ChanPoint != "": @@ -6774,7 +6753,7 @@ func (r *rpcServer) GetChanInfo(_ context.Context, return nil, err } edgeInfo, edge1, edge2, err = graph.FetchChannelEdgesByOutpoint( - chanPoint, + ctx, chanPoint, ) default: @@ -6797,7 +6776,7 @@ func (r *rpcServer) GetChanInfo(_ context.Context, func (r *rpcServer) GetNodeInfo(ctx context.Context, in *lnrpc.NodeInfoRequest) (*lnrpc.NodeInfo, error) { - graph := r.server.graphDB + graph := r.server.graphSource // First, parse the hex-encoded public key into a full in-memory public // key object we can work with for querying. @@ -6809,7 +6788,7 @@ func (r *rpcServer) GetNodeInfo(ctx context.Context, // With the public key decoded, attempt to fetch the node corresponding // to this public key. If the node cannot be found, then an error will // be returned. 
- node, err := graph.FetchLightningNode(pubKey) + node, err := graph.FetchLightningNode(ctx, pubKey) switch { case errors.Is(err, graphdb.ErrGraphNodeNotFound): return nil, status.Error(codes.NotFound, err.Error()) @@ -6825,9 +6804,9 @@ func (r *rpcServer) GetNodeInfo(ctx context.Context, channels []*lnrpc.ChannelEdge ) - err = graph.ForEachNodeChannel(node.PubKeyBytes, - func(_ kvdb.RTx, edge *models.ChannelEdgeInfo, - c1, c2 *models.ChannelEdgePolicy) error { + err = graph.ForEachNodeChannel(ctx, node.PubKeyBytes, + func(edge *models.ChannelEdgeInfo, c1, + c2 *models.ChannelEdgePolicy) error { numChannels++ totalCapacity += edge.Capacity @@ -6908,134 +6887,34 @@ func (r *rpcServer) QueryRoutes(ctx context.Context, func (r *rpcServer) GetNetworkInfo(ctx context.Context, _ *lnrpc.NetworkInfoRequest) (*lnrpc.NetworkInfo, error) { - graph := r.server.graphDB - - var ( - numNodes uint32 - numChannels uint32 - maxChanOut uint32 - totalNetworkCapacity btcutil.Amount - minChannelSize btcutil.Amount = math.MaxInt64 - maxChannelSize btcutil.Amount - medianChanSize btcutil.Amount - ) - - // We'll use this map to de-duplicate channels during our traversal. - // This is needed since channels are directional, so there will be two - // edges for each channel within the graph. - seenChans := make(map[uint64]struct{}) - - // We also keep a list of all encountered capacities, in order to - // calculate the median channel size. - var allChans []btcutil.Amount - - // We'll run through all the known nodes in the within our view of the - // network, tallying up the total number of nodes, and also gathering - // each node so we can measure the graph diameter and degree stats - // below. - err := graph.ForEachNodeCached(func(node route.Vertex, - edges map[uint64]*graphdb.DirectedChannel) error { - - // Increment the total number of nodes with each iteration. 
- numNodes++ - - // For each channel we'll compute the out degree of each node, - // and also update our running tallies of the min/max channel - // capacity, as well as the total channel capacity. We pass - // through the db transaction from the outer view so we can - // re-use it within this inner view. - var outDegree uint32 - for _, edge := range edges { - // Bump up the out degree for this node for each - // channel encountered. - outDegree++ - - // If we've already seen this channel, then we'll - // return early to ensure that we don't double-count - // stats. - if _, ok := seenChans[edge.ChannelID]; ok { - return nil - } - - // Compare the capacity of this channel against the - // running min/max to see if we should update the - // extrema. - chanCapacity := edge.Capacity - if chanCapacity < minChannelSize { - minChannelSize = chanCapacity - } - if chanCapacity > maxChannelSize { - maxChannelSize = chanCapacity - } + graph := r.server.graphSource - // Accumulate the total capacity of this channel to the - // network wide-capacity. - totalNetworkCapacity += chanCapacity - - numChannels++ - - seenChans[edge.ChannelID] = struct{}{} - allChans = append(allChans, edge.Capacity) - } - - // Finally, if the out degree of this node is greater than what - // we've seen so far, update the maxChanOut variable. - if outDegree > maxChanOut { - maxChanOut = outDegree - } - - return nil - }) - if err != nil { - return nil, err - } - - // Query the graph for the current number of zombie channels. - numZombies, err := graph.NumZombies() + stats, err := graph.NetworkStats(ctx) if err != nil { return nil, err } - // Find the median. - medianChanSize = autopilot.Median(allChans) - - // If we don't have any channels, then reset the minChannelSize to zero - // to avoid outputting NaN in encoded JSON. - if numChannels == 0 { - minChannelSize = 0 - } - - // Graph diameter. 
- channelGraph := autopilot.ChannelGraphFromCachedDatabase(graph) - simpleGraph, err := autopilot.NewSimpleGraph(channelGraph) - if err != nil { - return nil, err - } - start := time.Now() - diameter := simpleGraph.DiameterRadialCutoff() - rpcsLog.Infof("elapsed time for diameter (%d) calculation: %v", diameter, - time.Since(start)) - // TODO(roasbeef): also add oldest channel? netInfo := &lnrpc.NetworkInfo{ - GraphDiameter: diameter, - MaxOutDegree: maxChanOut, - AvgOutDegree: float64(2*numChannels) / float64(numNodes), - NumNodes: numNodes, - NumChannels: numChannels, - TotalNetworkCapacity: int64(totalNetworkCapacity), - AvgChannelSize: float64(totalNetworkCapacity) / float64(numChannels), - - MinChannelSize: int64(minChannelSize), - MaxChannelSize: int64(maxChannelSize), - MedianChannelSizeSat: int64(medianChanSize), - NumZombieChans: numZombies, + GraphDiameter: stats.Diameter, + MaxOutDegree: stats.MaxChanOut, + AvgOutDegree: float64(2*stats.NumChannels) / + float64(stats.NumNodes), + NumNodes: stats.NumNodes, + NumChannels: stats.NumChannels, + TotalNetworkCapacity: int64(stats.TotalNetworkCapacity), + AvgChannelSize: float64(stats.TotalNetworkCapacity) / + float64(stats.NumChannels), + MinChannelSize: int64(stats.MinChanSize), + MaxChannelSize: int64(stats.MaxChanSize), + MedianChannelSizeSat: int64(stats.MedianChanSize), + NumZombieChans: stats.NumZombies, } // Similarly, if we don't have any channels, then we'll also set the // average channel size to zero in order to avoid weird JSON encoding // outputs. - if numChannels == 0 { + if stats.NumChannels == 0 { netInfo.AvgChannelSize = 0 } @@ -7886,7 +7765,9 @@ func (r *rpcServer) ForwardingHistory(ctx context.Context, return "", err } - peer, err := r.server.graphDB.FetchLightningNode(vertex) + peer, err := r.server.graphSource.FetchLightningNode( + ctx, vertex, + ) if err != nil { return "", err } @@ -7972,7 +7853,7 @@ func (r *rpcServer) ExportChannelBackup(ctx context.Context, // the database. 
If this channel has been closed, or the outpoint is // unknown, then we'll return an error unpackedBackup, err := chanbackup.FetchBackupForChan( - chanPoint, r.server.chanStateDB, r.server.addrSource, + ctx, chanPoint, r.server.chanStateDB, r.server.addrSource, ) if err != nil { return nil, err @@ -8152,7 +8033,7 @@ func (r *rpcServer) ExportAllChannelBackups(ctx context.Context, // First, we'll attempt to read back ups for ALL currently opened // channels from disk. allUnpackedBackups, err := chanbackup.FetchStaticChanBackups( - r.server.chanStateDB, r.server.addrSource, + ctx, r.server.chanStateDB, r.server.addrSource, ) if err != nil { return nil, fmt.Errorf("unable to fetch all static chan "+ @@ -8210,7 +8091,7 @@ func (r *rpcServer) RestoreChannelBackups(ctx context.Context, // out to any peers that we know of which were our prior // channel peers. numRestored, err = chanbackup.UnpackAndRecoverSingles( - chanbackup.PackedSingles(packedBackups), + ctx, chanbackup.PackedSingles(packedBackups), r.server.cc.KeyRing, chanRestorer, r.server, ) if err != nil { @@ -8227,7 +8108,7 @@ func (r *rpcServer) RestoreChannelBackups(ctx context.Context, // channel peers. packedMulti := chanbackup.PackedMulti(packedMultiBackup) numRestored, err = chanbackup.UnpackAndRecoverMulti( - packedMulti, r.server.cc.KeyRing, chanRestorer, + ctx, packedMulti, r.server.cc.KeyRing, chanRestorer, r.server, ) if err != nil { @@ -8287,7 +8168,8 @@ func (r *rpcServer) SubscribeChannelBackups(req *lnrpc.ChannelBackupSubscription // we'll obtains the current set of single channel // backups from disk. 
chanBackups, err := chanbackup.FetchStaticChanBackups( - r.server.chanStateDB, r.server.addrSource, + updateStream.Context(), r.server.chanStateDB, + r.server.addrSource, ) if err != nil { return fmt.Errorf("unable to fetch all "+ diff --git a/server.go b/server.go index d9e86db4b5..431a3ee245 100644 --- a/server.go +++ b/server.go @@ -44,7 +44,8 @@ import ( "github.com/lightningnetwork/lnd/graph" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/graph/db/models" - "github.com/lightningnetwork/lnd/graph/graphsession" + graphsession "github.com/lightningnetwork/lnd/graph/session" + "github.com/lightningnetwork/lnd/graph/sources" "github.com/lightningnetwork/lnd/healthcheck" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/htlcswitch/hop" @@ -261,6 +262,10 @@ type server struct { graphDB *graphdb.ChannelGraph + // graphSource can be used for any read only graph queries. This may be + // implemented by this LND node or some other external source. + graphSource sources.GraphSource + chanStateDB *channeldb.ChannelStateDB addrSource channeldb.AddrSource @@ -501,13 +506,13 @@ func noiseDial(idKey keychain.SingleKeyECDH, // newServer creates a new instance of the server which is to listen using the // passed listener address. 
-func newServer(cfg *Config, listenAddrs []net.Addr, +func newServer(ctx context.Context, cfg *Config, listenAddrs []net.Addr, dbs *DatabaseInstances, cc *chainreg.ChainControl, nodeKeyDesc *keychain.KeyDescriptor, chansToRestore walletunlocker.ChannelsToRecover, chanPredicate chanacceptor.ChannelAcceptor, torController *tor.Controller, tlsManager *TLSManager, - leaderElector cluster.LeaderElector, + leaderElector cluster.LeaderElector, graphSource sources.GraphSource, implCfg *ImplementationCfg) (*server, error) { var ( @@ -607,12 +612,13 @@ func newServer(cfg *Config, listenAddrs []net.Addr, HtlcInterceptor: invoiceHtlcModifier, } - addrSource := channeldb.NewMultiAddrSource(dbs.ChanStateDB, dbs.GraphDB) + addrSource := channeldb.NewMultiAddrSource(dbs.ChanStateDB, graphSource) s := &server{ cfg: cfg, implCfg: implCfg, graphDB: dbs.GraphDB, + graphSource: graphSource, chanStateDB: dbs.ChanStateDB.ChannelStateDB(), addrSource: addrSource, miscDB: dbs.ChanStateDB, @@ -766,7 +772,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, IsChannelActive: s.htlcSwitch.HasActiveLink, ApplyChannelUpdate: s.applyChannelUpdate, DB: s.chanStateDB, - Graph: dbs.GraphDB, + Graph: graphSource, } chanStatusMgr, err := netann.NewChanStatusManager(chanStatusMgrCfg) @@ -1020,7 +1026,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, } paymentSessionSource := &routing.SessionSource{ GraphSessionFactory: graphsession.NewGraphSessionFactory( - dbs.GraphDB, + graphSource, ), SourceNode: sourceNode, MissionControl: s.defaultMC, @@ -1054,7 +1060,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, s.chanRouter, err = routing.New(routing.Config{ SelfNode: selfNode.PubKeyBytes, - RoutingGraph: graphsession.NewRoutingGraph(dbs.GraphDB), + RoutingGraph: graphsession.NewRoutingGraph(graphSource), Chain: cc.ChainIO, Payer: s.htlcSwitch, Control: s.controlTower, @@ -1631,13 +1637,13 @@ func newServer(cfg *Config, listenAddrs []net.Addr, } backupFile := 
chanbackup.NewMultiFile(cfg.BackupFilePath) startingChans, err := chanbackup.FetchStaticChanBackups( - s.chanStateDB, s.addrSource, + ctx, s.chanStateDB, s.addrSource, ) if err != nil { return nil, err } s.chanSubSwapper, err = chanbackup.NewSubSwapper( - startingChans, chanNotifier, s.cc.KeyRing, backupFile, + ctx, startingChans, chanNotifier, s.cc.KeyRing, backupFile, ) if err != nil { return nil, err @@ -1799,14 +1805,18 @@ func newServer(cfg *Config, listenAddrs []net.Addr, // maintaining persistent outbound connections and also accepting new // incoming connections cmgr, err := connmgr.New(&connmgr.Config{ - Listeners: listeners, - OnAccept: s.InboundPeerConnected, + Listeners: listeners, + OnAccept: func(conn net.Conn) { + s.InboundPeerConnected(ctx, conn) + }, RetryDuration: time.Second * 5, TargetOutbound: 100, Dial: noiseDial( nodeKeyECDH, s.cfg.net, s.cfg.ConnectionTimeout, ), - OnConnection: s.OutboundPeerConnected, + OnConnection: func(req *connmgr.ConnReq, conn net.Conn) { + s.OutboundPeerConnected(ctx, req, conn) + }, }) if err != nil { return nil, err @@ -2072,7 +2082,7 @@ func (c cleaner) run() { // NOTE: This function is safe for concurrent access. 
// //nolint:funlen -func (s *server) Start() error { +func (s *server) Start(ctx context.Context) error { var startErr error // If one sub system fails to start, the following code ensures that the @@ -2283,7 +2293,7 @@ func (s *server) Start() error { } if len(s.chansToRestore.PackedSingleChanBackups) != 0 { _, err := chanbackup.UnpackAndRecoverSingles( - s.chansToRestore.PackedSingleChanBackups, + ctx, s.chansToRestore.PackedSingleChanBackups, s.cc.KeyRing, chanRestorer, s, ) if err != nil { @@ -2294,7 +2304,7 @@ func (s *server) Start() error { } if len(s.chansToRestore.PackedMultiChanBackup) != 0 { _, err := chanbackup.UnpackAndRecoverMulti( - s.chansToRestore.PackedMultiChanBackup, + ctx, s.chansToRestore.PackedMultiChanBackup, s.cc.KeyRing, chanRestorer, s, ) if err != nil { @@ -2359,8 +2369,7 @@ func (s *server) Start() error { } err = s.ConnectToPeer( - peerAddr, true, - s.cfg.ConnectionTimeout, + ctx, peerAddr, true, s.cfg.ConnectionTimeout, ) if err != nil { startErr = fmt.Errorf("unable to connect to "+ @@ -2447,14 +2456,16 @@ func (s *server) Start() error { // dedicated goroutine to maintain a set of persistent // connections. if shouldPeerBootstrap(s.cfg) { - bootstrappers, err := initNetworkBootstrappers(s) + bootstrappers, err := initNetworkBootstrappers(ctx, s) if err != nil { startErr = err return } s.wg.Add(1) - go s.peerBootstrapper(defaultMinPeers, bootstrappers) + go s.peerBootstrapper( + ctx, defaultMinPeers, bootstrappers, + ) } else { srvrLog.Infof("Auto peer bootstrapping is disabled") } @@ -2476,6 +2487,7 @@ func (s *server) Start() error { // NOTE: This function is safe for concurrent access. func (s *server) Stop() error { s.stop.Do(func() { + ctx := context.Background() atomic.StoreInt32(&s.stopping, 1) close(s.quit) @@ -2545,7 +2557,7 @@ func (s *server) Stop() error { // Update channel.backup file. Make sure to do it before // stopping chanSubSwapper. 
singles, err := chanbackup.FetchStaticChanBackups( - s.chanStateDB, s.addrSource, + ctx, s.chanStateDB, s.addrSource, ) if err != nil { srvrLog.Warnf("failed to fetch channel states: %v", @@ -2810,7 +2822,9 @@ out: // initNetworkBootstrappers initializes a set of network peer bootstrappers // based on the server, and currently active bootstrap mechanisms as defined // within the current configuration. -func initNetworkBootstrappers(s *server) ([]discovery.NetworkPeerBootstrapper, error) { +func initNetworkBootstrappers(ctx context.Context, + s *server) ([]discovery.NetworkPeerBootstrapper, error) { + srvrLog.Infof("Initializing peer network bootstrappers!") var bootStrappers []discovery.NetworkPeerBootstrapper @@ -2818,8 +2832,7 @@ func initNetworkBootstrappers(s *server) ([]discovery.NetworkPeerBootstrapper, e // First, we'll create an instance of the ChannelGraphBootstrapper as // this can be used by default if we've already partially seeded the // network. - chanGraph := autopilot.ChannelGraphFromDatabase(s.graphDB) - graphBootstrapper, err := discovery.NewGraphBootstrapper(chanGraph) + graphBootstrapper, err := s.graphSource.GraphBootstrapper(ctx) if err != nil { return nil, err } @@ -2884,7 +2897,7 @@ func (s *server) createBootstrapIgnorePeers() map[autopilot.NodeID]struct{} { // invariant, we ensure that our node is connected to a diverse set of peers // and that nodes newly joining the network receive an up to date network view // as soon as possible. -func (s *server) peerBootstrapper(numTargetPeers uint32, +func (s *server) peerBootstrapper(ctx context.Context, numTargetPeers uint32, bootstrappers []discovery.NetworkPeerBootstrapper) { defer s.wg.Done() @@ -2894,7 +2907,7 @@ func (s *server) peerBootstrapper(numTargetPeers uint32, // We'll start off by aggressively attempting connections to peers in // order to be a part of the network as soon as possible. 
- s.initialPeerBootstrap(ignoreList, numTargetPeers, bootstrappers) + s.initialPeerBootstrap(ctx, ignoreList, numTargetPeers, bootstrappers) // Once done, we'll attempt to maintain our target minimum number of // peers. @@ -2972,7 +2985,7 @@ func (s *server) peerBootstrapper(numTargetPeers uint32, ignoreList = s.createBootstrapIgnorePeers() peerAddrs, err := discovery.MultiSourceBootstrap( - ignoreList, numNeeded*2, bootstrappers..., + ctx, ignoreList, numNeeded*2, bootstrappers..., ) if err != nil { srvrLog.Errorf("Unable to retrieve bootstrap "+ @@ -2990,7 +3003,7 @@ func (s *server) peerBootstrapper(numTargetPeers uint32, // country diversity, etc errChan := make(chan error, 1) s.connectToPeer( - a, errChan, + ctx, a, errChan, s.cfg.ConnectionTimeout, ) select { @@ -3021,8 +3034,8 @@ const bootstrapBackOffCeiling = time.Minute * 5 // initialPeerBootstrap attempts to continuously connect to peers on startup // until the target number of peers has been reached. This ensures that nodes // receive an up to date network view as soon as possible. -func (s *server) initialPeerBootstrap(ignore map[autopilot.NodeID]struct{}, - numTargetPeers uint32, +func (s *server) initialPeerBootstrap(ctx context.Context, + ignore map[autopilot.NodeID]struct{}, numTargetPeers uint32, bootstrappers []discovery.NetworkPeerBootstrapper) { srvrLog.Debugf("Init bootstrap with targetPeers=%v, bootstrappers=%v, "+ @@ -3081,7 +3094,7 @@ func (s *server) initialPeerBootstrap(ignore map[autopilot.NodeID]struct{}, // in order to reach our target. 
peersNeeded := numTargetPeers - numActivePeers bootstrapAddrs, err := discovery.MultiSourceBootstrap( - ignore, peersNeeded, bootstrappers..., + ctx, ignore, peersNeeded, bootstrappers..., ) if err != nil { srvrLog.Errorf("Unable to retrieve initial bootstrap "+ @@ -3099,7 +3112,8 @@ func (s *server) initialPeerBootstrap(ignore map[autopilot.NodeID]struct{}, errChan := make(chan error, 1) go s.connectToPeer( - addr, errChan, s.cfg.ConnectionTimeout, + ctx, addr, errChan, + s.cfg.ConnectionTimeout, ) // We'll only allow this connection attempt to @@ -3777,7 +3791,7 @@ func shouldDropLocalConnection(local, remote *btcec.PublicKey) bool { // connection. // // NOTE: This function is safe for concurrent access. -func (s *server) InboundPeerConnected(conn net.Conn) { +func (s *server) InboundPeerConnected(ctx context.Context, conn net.Conn) { // Exit early if we have already been instructed to shutdown, this // prevents any delayed callbacks from accidentally registering peers. if s.Stopped() { @@ -3847,7 +3861,7 @@ func (s *server) InboundPeerConnected(conn net.Conn) { // We were unable to locate an existing connection with the // target peer, proceed to connect. s.cancelConnReqs(pubStr, nil) - s.peerConnected(conn, nil, true) + s.peerConnected(ctx, conn, nil, true) case nil: // We already have a connection with the incoming peer. If the @@ -3879,7 +3893,7 @@ func (s *server) InboundPeerConnected(conn net.Conn) { s.removePeer(connectedPeer) s.ignorePeerTermination[connectedPeer] = struct{}{} s.scheduledPeerConnection[pubStr] = func() { - s.peerConnected(conn, nil, true) + s.peerConnected(ctx, conn, nil, true) } } } @@ -3887,7 +3901,9 @@ func (s *server) InboundPeerConnected(conn net.Conn) { // OutboundPeerConnected initializes a new peer in response to a new outbound // connection. // NOTE: This function is safe for concurrent access. 
-func (s *server) OutboundPeerConnected(connReq *connmgr.ConnReq, conn net.Conn) { +func (s *server) OutboundPeerConnected(ctx context.Context, + connReq *connmgr.ConnReq, conn net.Conn) { + // Exit early if we have already been instructed to shutdown, this // prevents any delayed callbacks from accidentally registering peers. if s.Stopped() { @@ -3985,7 +4001,7 @@ func (s *server) OutboundPeerConnected(connReq *connmgr.ConnReq, conn net.Conn) case ErrPeerNotConnected: // We were unable to locate an existing connection with the // target peer, proceed to connect. - s.peerConnected(conn, connReq, false) + s.peerConnected(ctx, conn, connReq, false) case nil: // We already have a connection with the incoming peer. If the @@ -4019,7 +4035,7 @@ func (s *server) OutboundPeerConnected(connReq *connmgr.ConnReq, conn net.Conn) s.removePeer(connectedPeer) s.ignorePeerTermination[connectedPeer] = struct{}{} s.scheduledPeerConnection[pubStr] = func() { - s.peerConnected(conn, connReq, false) + s.peerConnected(ctx, conn, connReq, false) } } } @@ -4097,8 +4113,8 @@ func (s *server) SubscribeCustomMessages() (*subscribe.Client, error) { // peer by adding it to the server's global list of all active peers, and // starting all the goroutines the peer needs to function properly. The inbound // boolean should be true if the peer initiated the connection to us. -func (s *server) peerConnected(conn net.Conn, connReq *connmgr.ConnReq, - inbound bool) { +func (s *server) peerConnected(ctx context.Context, conn net.Conn, + connReq *connmgr.ConnReq, inbound bool) { brontideConn := conn.(*brontide.Conn) addr := conn.RemoteAddr() @@ -4252,7 +4268,7 @@ func (s *server) peerConnected(conn net.Conn, connReq *connmgr.ConnReq, // includes sending and receiving Init messages, which would be a DOS // vector if we held the server's mutex throughout the procedure. 
s.wg.Add(1) - go s.peerInitializer(p) + go s.peerInitializer(ctx, p) } // addPeer adds the passed peer to the server's global state of all active @@ -4307,7 +4323,7 @@ func (s *server) addPeer(p *peer.Brontide) { // be signaled of the new peer once the method returns. // // NOTE: This MUST be launched as a goroutine. -func (s *server) peerInitializer(p *peer.Brontide) { +func (s *server) peerInitializer(ctx context.Context, p *peer.Brontide) { defer s.wg.Done() pubBytes := p.IdentityKey().SerializeCompressed() @@ -4331,7 +4347,7 @@ func (s *server) peerInitializer(p *peer.Brontide) { // the peer is ever added to the ignorePeerTermination map, indicating // that the server has already handled the removal of this peer. s.wg.Add(1) - go s.peerTerminationWatcher(p, ready) + go s.peerTerminationWatcher(ctx, p, ready) // Start the peer! If an error occurs, we Disconnect the peer, which // will unblock the peerTerminationWatcher. @@ -4376,7 +4392,9 @@ func (s *server) peerInitializer(p *peer.Brontide) { // successfully, otherwise the peer should be disconnected instead. // // NOTE: This MUST be launched as a goroutine. -func (s *server) peerTerminationWatcher(p *peer.Brontide, ready chan struct{}) { +func (s *server) peerTerminationWatcher(ctx context.Context, p *peer.Brontide, + ready chan struct{}) { + defer s.wg.Done() p.WaitForDisconnect(ready) @@ -4465,7 +4483,7 @@ func (s *server) peerTerminationWatcher(p *peer.Brontide, ready chan struct{}) { // We'll ensure that we locate all the peers advertised addresses for // reconnection purposes. - advertisedAddrs, err := s.fetchNodeAdvertisedAddrs(pubKey) + advertisedAddrs, err := s.fetchNodeAdvertisedAddrs(ctx, pubKey) switch { // We found advertised addresses, so use them. case err == nil: @@ -4714,7 +4732,7 @@ func (s *server) removePeer(p *peer.Brontide) { // connection is established, or the initial handshake process fails. // // NOTE: This function is safe for concurrent access. 
-func (s *server) ConnectToPeer(addr *lnwire.NetAddress, +func (s *server) ConnectToPeer(ctx context.Context, addr *lnwire.NetAddress, perm bool, timeout time.Duration) error { targetPub := string(addr.IdentityKey.SerializeCompressed()) @@ -4776,7 +4794,7 @@ func (s *server) ConnectToPeer(addr *lnwire.NetAddress, // the crypto negotiation breaks down, then return an error to the // caller. errChan := make(chan error, 1) - s.connectToPeer(addr, errChan, timeout) + s.connectToPeer(ctx, addr, errChan, timeout) select { case err := <-errChan: @@ -4789,7 +4807,7 @@ func (s *server) ConnectToPeer(addr *lnwire.NetAddress, // connectToPeer establishes a connection to a remote peer. errChan is used to // notify the caller if the connection attempt has failed. Otherwise, it will be // closed. -func (s *server) connectToPeer(addr *lnwire.NetAddress, +func (s *server) connectToPeer(ctx context.Context, addr *lnwire.NetAddress, errChan chan<- error, timeout time.Duration) { conn, err := brontide.Dial( @@ -4809,7 +4827,7 @@ func (s *server) connectToPeer(addr *lnwire.NetAddress, srvrLog.Tracef("Brontide dialer made local=%v, remote=%v", conn.LocalAddr(), conn.RemoteAddr()) - s.OutboundPeerConnected(nil, conn) + s.OutboundPeerConnected(ctx, nil, conn) } // DisconnectPeer sends the request to server to close the connection with peer @@ -4955,13 +4973,15 @@ func computeNextBackoff(currBackoff, maxBackoff time.Duration) time.Duration { var errNoAdvertisedAddr = errors.New("no advertised address found") // fetchNodeAdvertisedAddrs attempts to fetch the advertised addresses of a node. 
-func (s *server) fetchNodeAdvertisedAddrs(pub *btcec.PublicKey) ([]net.Addr, error) { +func (s *server) fetchNodeAdvertisedAddrs(ctx context.Context, + pub *btcec.PublicKey) ([]net.Addr, error) { + vertex, err := route.NewVertexFromBytes(pub.SerializeCompressed()) if err != nil { return nil, err } - node, err := s.graphDB.FetchLightningNode(vertex) + node, err := s.graphSource.FetchLightningNode(ctx, vertex) if err != nil { return nil, err } diff --git a/subrpcserver_config.go b/subrpcserver_config.go index 30755c05e4..435583373a 100644 --- a/subrpcserver_config.go +++ b/subrpcserver_config.go @@ -13,6 +13,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/fn" graphdb "github.com/lightningnetwork/lnd/graph/db" + graphsources "github.com/lightningnetwork/lnd/graph/sources" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/lncfg" @@ -114,6 +115,7 @@ func (s *subRPCServerConfigs) PopulateDependencies(cfg *Config, routerBackend *routerrpc.RouterBackend, nodeSigner *netann.NodeSigner, graphDB *graphdb.ChannelGraph, + graphSource graphsources.GraphSource, chanStateDB *channeldb.ChannelStateDB, sweeper *sweep.UtxoSweeper, tower *watchtower.Standalone, @@ -262,8 +264,8 @@ func (s *subRPCServerConfigs) PopulateDependencies(cfg *Config, subCfgValue.FieldByName("DefaultCLTVExpiry").Set( reflect.ValueOf(defaultDelta), ) - subCfgValue.FieldByName("GraphDB").Set( - reflect.ValueOf(graphDB), + subCfgValue.FieldByName("Graph").Set( + reflect.ValueOf(graphSource), ) subCfgValue.FieldByName("ChanStateDB").Set( reflect.ValueOf(chanStateDB),