Skip to content

Commit

Permalink
Merge pull request #7623 from ellemouton/queueBackupIDOnly
Browse files Browse the repository at this point in the history
wtclient: queue backup id only
  • Loading branch information
guggero authored Apr 24, 2023
2 parents 588a7eb + 9a1ed8c commit 4355ce6
Show file tree
Hide file tree
Showing 12 changed files with 460 additions and 258 deletions.
121 changes: 104 additions & 17 deletions channeldb/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcwallet/walletdb"
"github.com/go-errors/errors"
mig "github.com/lightningnetwork/lnd/channeldb/migration"
"github.com/lightningnetwork/lnd/channeldb/migration12"
Expand Down Expand Up @@ -654,20 +655,94 @@ func (c *ChannelStateDB) fetchNodeChannels(chainBucket kvdb.RBucket) (

// FetchChannel attempts to locate a channel specified by the passed channel
// point. If the channel cannot be found, then an error will be returned.
// Optionally an existing db tx can be supplied. Optionally an existing db tx
// can be supplied.
// Optionally an existing db tx can be supplied.
func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (
*OpenChannel, error) {

var (
targetChan *OpenChannel
targetChanPoint bytes.Buffer
)

var targetChanPoint bytes.Buffer
if err := writeOutpoint(&targetChanPoint, &chanPoint); err != nil {
return nil, err
}

targetChanPointBytes := targetChanPoint.Bytes()
selector := func(chainBkt walletdb.ReadBucket) ([]byte, *wire.OutPoint,
error) {

return targetChanPointBytes, &chanPoint, nil
}

return c.channelScanner(tx, selector)
}

// FetchChannelByID attempts to locate a channel specified by the passed channel
// ID. If the channel cannot be found, then an error will be returned.
// Optionally an existing db tx can be supplied.
func (c *ChannelStateDB) FetchChannelByID(tx kvdb.RTx, id lnwire.ChannelID) (
*OpenChannel, error) {

selector := func(chainBkt walletdb.ReadBucket) ([]byte, *wire.OutPoint,
error) {

var (
targetChanPointBytes []byte
targetChanPoint *wire.OutPoint

// errChanFound is used to signal that the channel has
// been found so that iteration through the DB buckets
// can stop.
errChanFound = errors.New("channel found")
)
err := chainBkt.ForEach(func(k, _ []byte) error {
var outPoint wire.OutPoint
err := readOutpoint(bytes.NewReader(k), &outPoint)
if err != nil {
return err
}

chanID := lnwire.NewChanIDFromOutPoint(&outPoint)
if chanID != id {
return nil
}

targetChanPoint = &outPoint
targetChanPointBytes = k

return errChanFound
})
if err != nil && !errors.Is(err, errChanFound) {
return nil, nil, err
}
if targetChanPoint == nil {
return nil, nil, ErrChannelNotFound
}

return targetChanPointBytes, targetChanPoint, nil
}

return c.channelScanner(tx, selector)
}

// channelSelector describes a function that takes a chain-hash bucket from
// within the open-channel DB and returns the wanted channel point bytes, and
// channel point. It must return the ErrChannelNotFound error if the wanted
// channel is not in the given bucket.
type channelSelector func(chainBkt walletdb.ReadBucket) ([]byte, *wire.OutPoint,
error)

// channelScanner will traverse the DB to each chain-hash bucket of each node
// pub-key bucket in the open-channel-bucket. The chanSelector will then be used
// to fetch the wanted channel outpoint from the chain bucket.
func (c *ChannelStateDB) channelScanner(tx kvdb.RTx,
chanSelect channelSelector) (*OpenChannel, error) {

var (
targetChan *OpenChannel

// errChanFound is used to signal that the channel has been
// found so that iteration through the DB buckets can stop.
errChanFound = errors.New("channel found")
)

// chanScan will traverse the following bucket structure:
// * nodePub => chainHash => chanPoint
//
Expand All @@ -685,16 +760,18 @@ func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (
}

// Within the node channel bucket, are the set of node pubkeys
// we have channels with, we don't know the entire set, so
// we'll check them all.
// we have channels with, we don't know the entire set, so we'll
// check them all.
return openChanBucket.ForEach(func(nodePub, v []byte) error {
// Ensure that this is a key the same size as a pubkey,
// and also that it leads directly to a bucket.
if len(nodePub) != 33 || v != nil {
return nil
}

nodeChanBucket := openChanBucket.NestedReadBucket(nodePub)
nodeChanBucket := openChanBucket.NestedReadBucket(
nodePub,
)
if nodeChanBucket == nil {
return nil
}
Expand All @@ -715,20 +792,30 @@ func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (
)
if chainBucket == nil {
return fmt.Errorf("unable to read "+
"bucket for chain=%x", chainHash[:])
"bucket for chain=%x",
chainHash)
}

// Finally we reach the leaf bucket that stores
// Finally, we reach the leaf bucket that stores
// all the chanPoints for this node.
targetChanBytes, chanPoint, err := chanSelect(
chainBucket,
)
if errors.Is(err, ErrChannelNotFound) {
return nil
} else if err != nil {
return err
}

chanBucket := chainBucket.NestedReadBucket(
targetChanPoint.Bytes(),
targetChanBytes,
)
if chanBucket == nil {
return nil
}

channel, err := fetchOpenChannel(
chanBucket, &chanPoint,
chanBucket, chanPoint,
)
if err != nil {
return err
Expand All @@ -737,7 +824,7 @@ func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (
targetChan = channel
targetChan.Db = c

return nil
return errChanFound
})
})
}
Expand All @@ -748,7 +835,7 @@ func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (
} else {
err = chanScan(tx)
}
if err != nil {
if err != nil && !errors.Is(err, errChanFound) {
return nil, err
}

Expand All @@ -757,7 +844,7 @@ func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (
}

// If we can't find the channel, then we return with an error, as we
// have nothing to backup.
// have nothing to back up.
return nil, ErrChannelNotFound
}

Expand Down
23 changes: 15 additions & 8 deletions channeldb/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
Expand Down Expand Up @@ -238,10 +237,16 @@ func TestFetchChannel(t *testing.T) {

// The decoded channel state should be identical to what we stored
// above.
if !reflect.DeepEqual(channelState, dbChannel) {
t.Fatalf("channel state doesn't match:: %v vs %v",
spew.Sdump(channelState), spew.Sdump(dbChannel))
}
require.Equal(t, channelState, dbChannel)

// Next, attempt to fetch the channel by its channel ID.
chanID := lnwire.NewChanIDFromOutPoint(&channelState.FundingOutpoint)
dbChannel, err = cdb.FetchChannelByID(nil, chanID)
require.NoError(t, err, "unable to fetch channel")

// The decoded channel state should be identical to what we stored
// above.
require.Equal(t, channelState, dbChannel)

// If we attempt to query for a non-existent channel, then we should
// get an error.
Expand All @@ -252,9 +257,11 @@ func TestFetchChannel(t *testing.T) {
channelState2.FundingOutpoint.Index = uniqueOutputIndex.Load()

_, err = cdb.FetchChannel(nil, channelState2.FundingOutpoint)
if err == nil {
t.Fatalf("expected query to fail")
}
require.ErrorIs(t, err, ErrChannelNotFound)

chanID2 := lnwire.NewChanIDFromOutPoint(&channelState2.FundingOutpoint)
_, err = cdb.FetchChannelByID(nil, chanID2)
require.ErrorIs(t, err, ErrChannelNotFound)
}

func genRandomChannelShell() (*ChannelShell, error) {
Expand Down
8 changes: 8 additions & 0 deletions docs/release-notes/release-notes-0.17.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,15 @@
implementation](https://github.com/lightningnetwork/lnd/pull/7377) logic in
different update types.

## Watchtowers

* Let the task pipeline [only carry
wtdb.BackupIDs](https://github.com/lightningnetwork/lnd/pull/7623) instead of
the entire retribution struct. This reduces the amount of data that needs to
be held in memory.

# Contributors (Alphabetical Order)

* Elle Mouton
* Jordi Montes

8 changes: 2 additions & 6 deletions htlcswitch/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"github.com/lightningnetwork/lnd/invoices"
"github.com/lightningnetwork/lnd/lnpeer"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record"
Expand Down Expand Up @@ -255,11 +254,8 @@ type TowerClient interface {
// state. If the method returns nil, the backup is guaranteed to be
// successful unless the tower is unavailable and client is force quit,
// or the justice transaction would create dust outputs when trying to
// abide by the negotiated policy. If the channel we're trying to back
// up doesn't have a tweak for the remote party's output, then
// isTweakless should be true.
BackupState(*lnwire.ChannelID, *lnwallet.BreachRetribution,
channeldb.ChannelType) error
// abide by the negotiated policy.
BackupState(chanID *lnwire.ChannelID, stateNum uint64) error
}

// InterceptableHtlcForwarder is the interface to set the interceptor
Expand Down
22 changes: 4 additions & 18 deletions htlcswitch/link.go
Original file line number Diff line number Diff line change
Expand Up @@ -2022,11 +2022,6 @@ func (l *channelLink) handleUpstreamMsg(msg lnwire.Message) {
// We've received a revocation from the remote chain, if valid,
// this moves the remote chain forward, and expands our
// revocation window.
//
// Before advancing our remote chain, we will record the
// current commit tx, which is used by the TowerClient to
// create backups.
oldCommitTx := l.channel.State().RemoteCommitment.CommitTx

// We now process the message and advance our remote commit
// chain.
Expand Down Expand Up @@ -2063,24 +2058,15 @@ func (l *channelLink) handleUpstreamMsg(msg lnwire.Message) {
// create a backup for the current state.
if l.cfg.TowerClient != nil {
state := l.channel.State()
breachInfo, err := lnwallet.NewBreachRetribution(
state, state.RemoteCommitment.CommitHeight-1, 0,
// OldCommitTx is the breaching tx at height-1.
oldCommitTx,
)
if err != nil {
l.fail(LinkFailureError{code: ErrInternalError},
"failed to load breach info: %v", err)
return
}

chanID := l.ChanID()

err = l.cfg.TowerClient.BackupState(
&chanID, breachInfo, state.ChanType,
&chanID, state.RemoteCommitment.CommitHeight-1,
)
if err != nil {
l.fail(LinkFailureError{code: ErrInternalError},
"unable to queue breach backup: %v", err)
"unable to queue breach backup: %v",
err)
return
}
}
Expand Down
38 changes: 32 additions & 6 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1523,12 +1523,37 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
)
}

// buildBreachRetribution is a call-back that can be used to
// query the BreachRetribution info and channel type given a
// channel ID and commitment height.
buildBreachRetribution := func(chanID lnwire.ChannelID,
commitHeight uint64) (*lnwallet.BreachRetribution,
channeldb.ChannelType, error) {

channel, err := s.chanStateDB.FetchChannelByID(
nil, chanID,
)
if err != nil {
return nil, 0, err
}

br, err := lnwallet.NewBreachRetribution(
channel, commitHeight, 0, nil,
)
if err != nil {
return nil, 0, err
}

return br, channel.ChanType, nil
}

fetchClosedChannel := s.chanStateDB.FetchClosedChannelForID

s.towerClient, err = wtclient.New(&wtclient.Config{
FetchClosedChannel: fetchClosedChannel,
SessionCloseRange: sessionCloseRange,
ChainNotifier: s.cc.ChainNotifier,
FetchClosedChannel: fetchClosedChannel,
BuildBreachRetribution: buildBreachRetribution,
SessionCloseRange: sessionCloseRange,
ChainNotifier: s.cc.ChainNotifier,
SubscribeChannelEvents: func() (subscribe.Subscription,
error) {

Expand Down Expand Up @@ -1558,9 +1583,10 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
blob.Type(blob.FlagAnchorChannel)

s.anchorTowerClient, err = wtclient.New(&wtclient.Config{
FetchClosedChannel: fetchClosedChannel,
SessionCloseRange: sessionCloseRange,
ChainNotifier: s.cc.ChainNotifier,
FetchClosedChannel: fetchClosedChannel,
BuildBreachRetribution: buildBreachRetribution,
SessionCloseRange: sessionCloseRange,
ChainNotifier: s.cc.ChainNotifier,
SubscribeChannelEvents: func() (subscribe.Subscription,
error) {

Expand Down
Loading

0 comments on commit 4355ce6

Please sign in to comment.