diff --git a/chanbackup/backup.go b/chanbackup/backup.go index a5ba497488..183d77e7d6 100644 --- a/chanbackup/backup.go +++ b/chanbackup/backup.go @@ -34,10 +34,13 @@ func assembleChanBackup(addrSource channeldb.AddrSource, // First, we'll query the channel source to obtain all the addresses // that are associated with the peer for this channel. - nodeAddrs, err := addrSource.AddrsForNode(openChan.IdentityPub) + known, nodeAddrs, err := addrSource.AddrsForNode(openChan.IdentityPub) if err != nil { return nil, err } + if !known { + return nil, fmt.Errorf("node unknown by address source") + } single := NewSingle(openChan, nodeAddrs) diff --git a/chanbackup/backup_test.go b/chanbackup/backup_test.go index 511b1081dc..0593b97599 100644 --- a/chanbackup/backup_test.go +++ b/chanbackup/backup_test.go @@ -62,20 +62,19 @@ func (m *mockChannelSource) addAddrsForNode(nodePub *btcec.PublicKey, addrs []ne m.addrs[nodeKey] = addrs } -func (m *mockChannelSource) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error) { +func (m *mockChannelSource) AddrsForNode(nodePub *btcec.PublicKey) (bool, + []net.Addr, error) { + if m.failQuery { - return nil, fmt.Errorf("fail") + return false, nil, fmt.Errorf("fail") } var nodeKey [33]byte copy(nodeKey[:], nodePub.SerializeCompressed()) addrs, ok := m.addrs[nodeKey] - if !ok { - return nil, fmt.Errorf("can't find addr") - } - return addrs, nil + return ok, addrs, nil } // TestFetchBackupForChan tests that we're able to construct a single channel diff --git a/channel_notifier.go b/channel_notifier.go index 6d08620ef6..88a05ac4ce 100644 --- a/channel_notifier.go +++ b/channel_notifier.go @@ -45,7 +45,7 @@ func (c *channelNotifier) SubscribeChans(startingChans map[wire.OutPoint]struct{ // chanUpdates channel to inform subscribers about new pending or // confirmed channels. sendChanOpenUpdate := func(newOrPendingChan *channeldb.OpenChannel) { - nodeAddrs, err := c.addrs.AddrsForNode( + _, nodeAddrs, err := c.addrs.AddrsForNode( newOrPendingChan.IdentityPub, ) if err != nil { diff --git a/channeldb/addr_source.go b/channeldb/addr_source.go index ef06fbe1f1..de933ed496 100644 --- a/channeldb/addr_source.go +++ b/channeldb/addr_source.go @@ -1,6 +1,7 @@ package channeldb import ( + "errors" "net" "github.com/btcsuite/btcd/btcec/v2" @@ -10,8 +11,9 @@ import ( // node. It may combine the results of multiple address sources. type AddrSource interface { // AddrsForNode returns all known addresses for the target node public - // key. - AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error) + // key. The returned boolean must indicate if the given node is unknown + // to the backing source. + AddrsForNode(nodePub *btcec.PublicKey) (bool, []net.Addr, error) } // multiAddrSource is an implementation of AddrSource which gathers all the @@ -31,22 +33,35 @@ func NewMultiAddrSource(sources ...AddrSource) AddrSource { } // AddrsForNode returns all known addresses for the target node public key. It -// queries all the address sources provided and de-duplicates the results. +// queries all the address sources provided and de-duplicates the results. The +// returned boolean is false only if none of the backing sources know of the +// node. // // NOTE: this implements the AddrSource interface. -func (c *multiAddrSource) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, - error) { +func (c *multiAddrSource) AddrsForNode(nodePub *btcec.PublicKey) (bool, + []net.Addr, error) { + + if len(c.sources) == 0 { + return false, nil, errors.New("no address sources") + } // The multiple address sources will likely contain duplicate addresses, // so we use a map here to de-dup them. dedupedAddrs := make(map[string]net.Addr) + // known will be set to true if any backing source is aware of the node. + var known bool + // Iterate over all the address sources and query each one for the // addresses it has for the node in question. for _, src := range c.sources { - addrs, err := src.AddrsForNode(nodePub) + isKnown, addrs, err := src.AddrsForNode(nodePub) if err != nil { - return nil, err + return false, nil, err + } + + if isKnown { + known = true } for _, addr := range addrs { @@ -60,7 +75,7 @@ func (c *multiAddrSource) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, addrs = append(addrs, addr) } - return addrs, nil + return known, addrs, nil } // A compile-time check to ensure that multiAddrSource implements the AddrSource diff --git a/channeldb/addr_source_test.go b/channeldb/addr_source_test.go index 3d7d71ac92..85ee30bf53 100644 --- a/channeldb/addr_source_test.go +++ b/channeldb/addr_source_test.go @@ -36,11 +36,12 @@ func TestMultiAddrSource(t *testing.T) { // Let source 1 know of 2 addresses (addr 1 and 2) for node 1. src1.On("AddrsForNode", pk1).Return( - []net.Addr{addr1, addr2}, nil, + true, []net.Addr{addr1, addr2}, nil, ).Once() // Let source 2 know of 2 addresses (addr 2 and 3) for node 1. src2.On("AddrsForNode", pk1).Return( + true, []net.Addr{addr2, addr3}, nil, []net.Addr{addr2, addr3}, nil, ).Once() @@ -50,8 +51,9 @@ func TestMultiAddrSource(t *testing.T) { // Query it for the addresses known for node 1. The results // should contain addr 1, 2 and 3. - addrs, err := multiSrc.AddrsForNode(pk1) + known, addrs, err := multiSrc.AddrsForNode(pk1) require.NoError(t, err) + require.True(t, known) require.ElementsMatch(t, addrs, []net.Addr{addr1, addr2, addr3}) }) @@ -69,9 +71,9 @@ func TestMultiAddrSource(t *testing.T) { // Let source 1 know of address 1 for node 1. src1.On("AddrsForNode", pk1).Return( - []net.Addr{addr1}, nil, + true, []net.Addr{addr1}, nil, ).Once() - src2.On("AddrsForNode", pk1).Return(nil, nil).Once() + src2.On("AddrsForNode", pk1).Return(false, nil, nil).Once() // Create a multi-addr source that consists of both source 1 // and 2. @@ -79,11 +81,39 @@ func TestMultiAddrSource(t *testing.T) { // Query it for the addresses known for node 1. The results // should contain addr 1. - addrs, err := multiSrc.AddrsForNode(pk1) + known, addrs, err := multiSrc.AddrsForNode(pk1) require.NoError(t, err) + require.True(t, known) require.ElementsMatch(t, addrs, []net.Addr{addr1}) }) + t.Run("unknown address", func(t *testing.T) { + t.Parallel() + + var ( + src1 = newMockAddrSource(t) + src2 = newMockAddrSource(t) + ) + t.Cleanup(func() { + src1.AssertExpectations(t) + src2.AssertExpectations(t) + }) + + // Create a multi-addr source that consists of both source 1 + // and 2. Neither source known of node 1. + multiSrc := NewMultiAddrSource(src1, src2) + + src1.On("AddrsForNode", pk1).Return(false, nil, nil).Once() + src2.On("AddrsForNode", pk1).Return(false, nil, nil).Once() + + // Query it for the addresses known for node 1. It should return + // false to indicate that the node is unknown to all backing + // sources. + known, addrs, err := multiSrc.AddrsForNode(pk1) + require.NoError(t, err) + require.False(t, known) + require.Empty(t, addrs) + }) } type mockAddrSource struct { @@ -97,18 +127,18 @@ func newMockAddrSource(t *testing.T) *mockAddrSource { return &mockAddrSource{t: t} } -func (m *mockAddrSource) AddrsForNode(pub *btcec.PublicKey) ([]net.Addr, +func (m *mockAddrSource) AddrsForNode(pub *btcec.PublicKey) (bool, []net.Addr, error) { args := m.Called(pub) - if args.Get(0) == nil { - return nil, args.Error(1) + if args.Get(1) == nil { + return args.Bool(0), nil, args.Error(2) } - addrs, ok := args.Get(0).([]net.Addr) + addrs, ok := args.Get(1).([]net.Addr) require.True(m.t, ok) - return addrs, args.Error(1) + return args.Bool(0), addrs, args.Error(2) } func newTestPubKey(t *testing.T) *btcec.PublicKey { diff --git a/channeldb/db.go b/channeldb/db.go index 7eea5cbf4d..ef6e899a44 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -1345,12 +1345,32 @@ func (c *ChannelStateDB) RestoreChannelShells(channelShells ...*ChannelShell) er // AddrsForNode consults the graph and channel database for all addresses known // to the passed node public key. -func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, - error) { +func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) (bool, []net.Addr, error) { + var ( + // addrs holds the collection of deduplicated addresses we know + // of for the node. + addrs = make(map[string]net.Addr) + + // known keeps track of if any of the backing sources know of + // this node. + known bool + ) + // First, query the channel DB for its known addresses. linkNode, err := d.channelStateDB.linkNodeDB.FetchLinkNode(nodePub) - if err != nil { - return nil, err + switch { + // If we get back a ErrNodeNotFound error, then this just means that the + // channel DB does not know of the error, but we don't error out here + // because we still want to check the graph db. + case err != nil && !errors.Is(err, ErrNodeNotFound): + return false, nil, err + + // A nil error means the node is known. + case err == nil: + known = true + for _, addr := range linkNode.Addresses { + addrs[addr.String()] = addr + } } // We'll also query the graph for this peer to see if they have any @@ -1358,33 +1378,29 @@ func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, // database. pubKey, err := route.NewVertexFromBytes(nodePub.SerializeCompressed()) if err != nil { - return nil, err + return false, nil, err } graphNode, err := d.graph.FetchLightningNode(pubKey) - if err != nil && err != graphdb.ErrGraphNodeNotFound { - return nil, err - } else if err == graphdb.ErrGraphNodeNotFound { - // If the node isn't found, then that's OK, as we still have the - // link node data. But any other error needs to be returned. - graphNode = &graphdb.LightningNode{} + switch { + // We don't consider it an error if the graph is unaware of the node. + case err != nil && !errors.Is(err, graphdb.ErrGraphNodeNotFound): + return false, nil, err + + // If we do find the node, we add its addresses to our deduplicated set. + case err == nil: + known = true + for _, addr := range graphNode.Addresses { + addrs[addr.String()] = addr + } } - // Now that we have both sources of addrs for this node, we'll use a - // map to de-duplicate any addresses between the two sources, and - // produce a final list of the combined addrs. - addrs := make(map[string]net.Addr) - for _, addr := range linkNode.Addresses { - addrs[addr.String()] = addr - } - for _, addr := range graphNode.Addresses { - addrs[addr.String()] = addr - } + // Convert the deduplicated set into a list. dedupedAddrs := make([]net.Addr, 0, len(addrs)) for _, addr := range addrs { dedupedAddrs = append(dedupedAddrs, addr) } - return dedupedAddrs, nil + return known, dedupedAddrs, nil } // AbandonChannel attempts to remove the target channel from the open channel diff --git a/channeldb/db_test.go b/channeldb/db_test.go index d7e990d627..6717058ce3 100644 --- a/channeldb/db_test.go +++ b/channeldb/db_test.go @@ -219,7 +219,7 @@ func TestAddrsForNode(t *testing.T) { // Now that we've created a link node, as well as a vertex for the // node, we'll query for all its addresses. - nodeAddrs, err := fullDB.AddrsForNode(nodePub) + _, nodeAddrs, err := fullDB.AddrsForNode(nodePub) require.NoError(t, err, "unable to obtain node addrs") expectedAddrs := make(map[string]struct{})