From b668f908aefd74033feaf1b4f5c62241a0e41705 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Thu, 23 Nov 2023 15:33:05 +0200 Subject: [PATCH 1/3] watchtower/wtdb: start populating channel max commitment In this commit, a new key, cChanMaxCommitmentHeight, is added to the channel details bucket. This key will hold the highest commitment number that the tower has been handed for this channel. In this commit, we start writing to it in the two places where a backup is first persisted in the tower client db: 1) CommitUpdate and 2) in the Queue's `addItem` method. A follow up commit will do a migration to back-fill the new field. --- watchtower/wtdb/client_db.go | 68 ++++++++++++++++++++++++++++++++++++ watchtower/wtdb/queue.go | 12 ++++++- 2 files changed, 79 insertions(+), 1 deletion(-) diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index 084f2dcfe0..2c111eb1c3 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -25,6 +25,7 @@ var ( // => cChanDBID -> db-assigned-id // => cChanSessions => db-session-id -> 1 // => cChanClosedHeight -> block-height + // => cChanMaxCommitmentHeight -> commitment-height cChanDetailsBkt = []byte("client-channel-detail-bucket") // cChanSessions is a sub-bucket of cChanDetailsBkt which stores: @@ -45,6 +46,13 @@ var ( // body of ClientChanSummary. cChannelSummary = []byte("client-channel-summary") + // cChanMaxCommitmentHeight is a key used in the cChanDetailsBkt used + // to store the highest commitment height for this channel that the + // tower has been handed. + cChanMaxCommitmentHeight = []byte( + "client-channel-max-commitment-height", + ) + // cSessionBkt is a top-level bucket storing: // session-id => cSessionBody -> encoded ClientSessionBody // => cSessionDBID -> db-assigned-id @@ -1963,6 +1971,12 @@ func (c *ClientDB) CommitUpdate(id *SessionID, return err } + // Update the channel's max commitment height if needed. + err = maybeUpdateMaxCommitHeight(tx, update.BackupID) + if err != nil { + return err + } + // Finally, capture the session's last applied value so it can // be sent in the next state update to the tower. lastApplied = session.TowerLastApplied @@ -2181,6 +2195,8 @@ func (c *ClientDB) GetDBQueue(namespace []byte) Queue[*BackupID] { return NewQueueDB[*BackupID]( c.db, namespace, func() *BackupID { return &BackupID{} + }, func(tx kvdb.RwTx, item *BackupID) error { + return maybeUpdateMaxCommitHeight(tx, *item) }, ) } @@ -2720,6 +2736,58 @@ func getDBSessionID(sessionsBkt kvdb.RBucket, sessionID SessionID) (uint64, return id, idBytes, nil } +// maybeUpdateMaxCommitHeight updates the given channel details bucket with the +// given height if it is larger than the current max height stored for the +// channel. +func maybeUpdateMaxCommitHeight(tx kvdb.RwTx, backupID BackupID) error { + chanDetailsBkt := tx.ReadWriteBucket(cChanDetailsBkt) + if chanDetailsBkt == nil { + return ErrUninitializedDB + } + + // If an entry for this channel does not exist in the channel details + // bucket then we exit here as this means that the channel has been + // closed. + chanDetails := chanDetailsBkt.NestedReadWriteBucket(backupID.ChanID[:]) + if chanDetails == nil { + return nil + } + + putHeight := func() error { + b, err := writeBigSize(backupID.CommitHeight) + if err != nil { + return err + } + + return chanDetails.Put( + cChanMaxCommitmentHeight, b, + ) + } + + // Get current height. + heightBytes := chanDetails.Get(cChanMaxCommitmentHeight) + + // The height might have not been set yet, in which case + // we can just write the new height. + if len(heightBytes) == 0 { + return putHeight() + } + + // Otherwise, read in the current max commitment height for the channel. + currentHeight, err := readBigSize(heightBytes) + if err != nil { + return err + } + + // If the new height is not larger than the current persisted height, + // then there is nothing left for us to do. + if backupID.CommitHeight <= currentHeight { + return nil + } + + return putHeight() +} + func getRealSessionID(sessIDIndexBkt kvdb.RBucket, dbID uint64) (*SessionID, error) { diff --git a/watchtower/wtdb/queue.go b/watchtower/wtdb/queue.go index 372765e7cd..674743a1e7 100644 --- a/watchtower/wtdb/queue.go +++ b/watchtower/wtdb/queue.go @@ -80,6 +80,7 @@ type DiskQueueDB[T Serializable] struct { db kvdb.Backend topLevelBkt []byte constructor func() T + onItemWrite func(tx kvdb.RwTx, item T) error } // A compile-time check to ensure that DiskQueueDB implements the Queue @@ -89,12 +90,14 @@ var _ Queue[Serializable] = (*DiskQueueDB[Serializable])(nil) // NewQueueDB constructs a new DiskQueueDB. A queueBktName must be provided so // that the DiskQueueDB can create its own namespace in the bolt db. func NewQueueDB[T Serializable](db kvdb.Backend, queueBktName []byte, - constructor func() T) Queue[T] { + constructor func() T, + onItemWrite func(tx kvdb.RwTx, item T) error) Queue[T] { return &DiskQueueDB[T]{ db: db, topLevelBkt: queueBktName, constructor: constructor, + onItemWrite: onItemWrite, } } @@ -279,6 +282,13 @@ func (d *DiskQueueDB[T]) addItem(tx kvdb.RwTx, queueName []byte, item T) error { return err } + if d.onItemWrite != nil { + err = d.onItemWrite(tx, item) + if err != nil { + return err + } + } + // Find the index to use for placing this new item at the back of the // queue. var nextIndex uint64 From 073e1863f43fa394faa8dfb363fce9d3d031fcce Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Thu, 23 Nov 2023 15:45:42 +0200 Subject: [PATCH 2/3] watchtower: start using the new channel max heights --- watchtower/wtclient/client.go | 87 ++++++++------------------ watchtower/wtclient/interface.go | 6 +- watchtower/wtdb/client_chan_summary.go | 15 +++++ watchtower/wtdb/client_db.go | 37 +++++++---- watchtower/wtdb/client_db_test.go | 10 +-- 5 files changed, 74 insertions(+), 81 deletions(-) diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 412412c1e3..f9036e87fc 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -17,6 +17,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channelnotifier" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" @@ -295,9 +296,8 @@ type TowerClient struct { closableSessionQueue *sessionCloseMinHeap - backupMu sync.Mutex - summaries wtdb.ChannelSummaries - chanCommitHeights map[lnwire.ChannelID]uint64 + backupMu sync.Mutex + chanInfos wtdb.ChannelInfos statTicker *time.Ticker stats *ClientStats @@ -339,9 +339,7 @@ func New(config *Config) (*TowerClient, error) { plog := build.NewPrefixLog(prefix, log) - // Load the sweep pkscripts that have been generated for all previously - // registered channels. - chanSummaries, err := cfg.DB.FetchChanSummaries() + chanInfos, err := cfg.DB.FetchChanInfos() if err != nil { return nil, err } @@ -358,9 +356,8 @@ func New(config *Config) (*TowerClient, error) { cfg: cfg, log: plog, pipeline: queue, - chanCommitHeights: make(map[lnwire.ChannelID]uint64), activeSessions: newSessionQueueSet(), - summaries: chanSummaries, + chanInfos: chanInfos, closableSessionQueue: newSessionCloseMinHeap(), statTicker: time.NewTicker(DefaultStatInterval), stats: new(ClientStats), @@ -369,44 +366,6 @@ func New(config *Config) (*TowerClient, error) { quit: make(chan struct{}), } - // perUpdate is a callback function that will be used to inspect the - // full set of candidate client sessions loaded from disk, and to - // determine the highest known commit height for each channel. This - // allows the client to reject backups that it has already processed for - // its active policy. - perUpdate := func(policy wtpolicy.Policy, chanID lnwire.ChannelID, - commitHeight uint64) { - - // We only want to consider accepted updates that have been - // accepted under an identical policy to the client's current - // policy. - if policy != c.cfg.Policy { - return - } - - c.backupMu.Lock() - defer c.backupMu.Unlock() - - // Take the highest commit height found in the session's acked - // updates. - height, ok := c.chanCommitHeights[chanID] - if !ok || commitHeight > height { - c.chanCommitHeights[chanID] = commitHeight - } - } - - perMaxHeight := func(s *wtdb.ClientSession, chanID lnwire.ChannelID, - height uint64) { - - perUpdate(s.Policy, chanID, height) - } - - perCommittedUpdate := func(s *wtdb.ClientSession, - u *wtdb.CommittedUpdate) { - - perUpdate(s.Policy, u.BackupID.ChanID, u.BackupID.CommitHeight) - } - candidateTowers := newTowerListIterator() perActiveTower := func(tower *Tower) { // If the tower has already been marked as active, then there is @@ -429,8 +388,6 @@ func New(config *Config) (*TowerClient, error) { candidateSessions, err := getTowerAndSessionCandidates( cfg.DB, cfg.SecretKeyRing, perActiveTower, wtdb.WithPreEvalFilterFn(c.genSessionFilter(true)), - wtdb.WithPerMaxHeight(perMaxHeight), - wtdb.WithPerCommittedUpdate(perCommittedUpdate), wtdb.WithPostEvalFilterFn(ExhaustedSessionFilter()), ) if err != nil { @@ -594,7 +551,7 @@ func (c *TowerClient) Start() error { // Iterate over the list of registered channels and check if // any of them can be marked as closed. - for id := range c.summaries { + for id := range c.chanInfos { isClosed, closedHeight, err := c.isChannelClosed(id) if err != nil { returnErr = err @@ -615,7 +572,7 @@ func (c *TowerClient) Start() error { // Since the channel has been marked as closed, we can // also remove it from the channel summaries map. - delete(c.summaries, id) + delete(c.chanInfos, id) } // Load all closable sessions. @@ -732,7 +689,7 @@ func (c *TowerClient) RegisterChannel(chanID lnwire.ChannelID) error { // If a pkscript for this channel already exists, the channel has been // previously registered. - if _, ok := c.summaries[chanID]; ok { + if _, ok := c.chanInfos[chanID]; ok { return nil } @@ -752,8 +709,10 @@ func (c *TowerClient) RegisterChannel(chanID lnwire.ChannelID) error { // Finally, cache the pkscript in our in-memory cache to avoid db // lookups for the remainder of the daemon's execution. - c.summaries[chanID] = wtdb.ClientChanSummary{ - SweepPkScript: pkScript, + c.chanInfos[chanID] = &wtdb.ChannelInfo{ + ClientChanSummary: wtdb.ClientChanSummary{ + SweepPkScript: pkScript, + }, } return nil @@ -770,16 +729,23 @@ func (c *TowerClient) BackupState(chanID *lnwire.ChannelID, // Make sure that this channel is registered with the tower client. c.backupMu.Lock() - if _, ok := c.summaries[*chanID]; !ok { + info, ok := c.chanInfos[*chanID] + if !ok { c.backupMu.Unlock() return ErrUnregisteredChannel } // Ignore backups that have already been presented to the client. - height, ok := c.chanCommitHeights[*chanID] - if ok && stateNum <= height { + var duplicate bool + info.MaxHeight.WhenSome(func(maxHeight uint64) { + if stateNum <= maxHeight { + duplicate = true + } + }) + if duplicate { c.backupMu.Unlock() + c.log.Debugf("Ignoring duplicate backup for chanid=%v at "+ "height=%d", chanID, stateNum) @@ -789,7 +755,7 @@ func (c *TowerClient) BackupState(chanID *lnwire.ChannelID, // This backup has a higher commit height than any known backup for this // channel. We'll update our tip so that we won't accept it again if the // link flaps. - c.chanCommitHeights[*chanID] = stateNum + c.chanInfos[*chanID].MaxHeight = fn.Some(stateNum) c.backupMu.Unlock() id := &wtdb.BackupID{ @@ -899,7 +865,7 @@ func (c *TowerClient) handleClosedChannel(chanID lnwire.ChannelID, defer c.backupMu.Unlock() // We only care about channels registered with the tower client. - if _, ok := c.summaries[chanID]; !ok { + if _, ok := c.chanInfos[chanID]; !ok { return nil } @@ -924,8 +890,7 @@ func (c *TowerClient) handleClosedChannel(chanID lnwire.ChannelID, return fmt.Errorf("could not track closable sessions: %w", err) } - delete(c.summaries, chanID) - delete(c.chanCommitHeights, chanID) + delete(c.chanInfos, chanID) return nil } @@ -1332,7 +1297,7 @@ func (c *TowerClient) backupDispatcher() { // the prevTask, and should be reprocessed after obtaining a new sessionQueue. func (c *TowerClient) processTask(task *wtdb.BackupID) { c.backupMu.Lock() - summary, ok := c.summaries[task.ChanID] + summary, ok := c.chanInfos[task.ChanID] if !ok { c.backupMu.Unlock() diff --git a/watchtower/wtclient/interface.go b/watchtower/wtclient/interface.go index 0f7f1b5391..e691552881 100644 --- a/watchtower/wtclient/interface.go +++ b/watchtower/wtclient/interface.go @@ -81,10 +81,10 @@ type DB interface { // successfully backed up using the given session. NumAckedUpdates(id *wtdb.SessionID) (uint64, error) - // FetchChanSummaries loads a mapping from all registered channels to - // their channel summaries. Only the channels that have not yet been + // FetchChanInfos loads a mapping from all registered channels to + // their wtdb.ChannelInfo. Only the channels that have not yet been // marked as closed will be loaded. - FetchChanSummaries() (wtdb.ChannelSummaries, error) + FetchChanInfos() (wtdb.ChannelInfos, error) // MarkChannelClosed will mark a registered channel as closed by setting // its closed-height as the given block height. It returns a list of diff --git a/watchtower/wtdb/client_chan_summary.go b/watchtower/wtdb/client_chan_summary.go index d4b3c3c388..061249a51b 100644 --- a/watchtower/wtdb/client_chan_summary.go +++ b/watchtower/wtdb/client_chan_summary.go @@ -3,9 +3,24 @@ package wtdb import ( "io" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/lnwire" ) +// ChannelInfos is a map for a given channel id to it's ChannelInfo. +type ChannelInfos map[lnwire.ChannelID]*ChannelInfo + +// ChannelInfo contains various useful things about a registered channel. +type ChannelInfo struct { + ClientChanSummary + + // MaxHeight is the highest commitment height that the tower has been + // handed for this channel. An Option type is used to store this since + // a commitment height of zero is valid, and we need a way of knowing if + // we have seen a new height yet or not. + MaxHeight fn.Option[uint64] +} + // ChannelSummaries is a map for a given channel id to it's ClientChanSummary. type ChannelSummaries map[lnwire.ChannelID]ClientChanSummary diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index 2c111eb1c3..981140e2a0 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -9,6 +9,7 @@ import ( "sync" "github.com/btcsuite/btcd/btcec/v2" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/tlv" @@ -1308,51 +1309,63 @@ func (c *ClientDB) NumAckedUpdates(id *SessionID) (uint64, error) { return numAcked, nil } -// FetchChanSummaries loads a mapping from all registered channels to their -// channel summaries. Only the channels that have not yet been marked as closed -// will be loaded. -func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) { - var summaries map[lnwire.ChannelID]ClientChanSummary +// FetchChanInfos loads a mapping from all registered channels to their +// ChannelInfo. Only the channels that have not yet been marked as closed will +// be loaded. +func (c *ClientDB) FetchChanInfos() (ChannelInfos, error) { + var infos ChannelInfos err := kvdb.View(c.db, func(tx kvdb.RTx) error { chanDetailsBkt := tx.ReadBucket(cChanDetailsBkt) if chanDetailsBkt == nil { return ErrUninitializedDB } - return chanDetailsBkt.ForEach(func(k, _ []byte) error { chanDetails := chanDetailsBkt.NestedReadBucket(k) if chanDetails == nil { return ErrCorruptChanDetails } - // If this channel has already been marked as closed, // then its summary does not need to be loaded. closedHeight := chanDetails.Get(cChanClosedHeight) if len(closedHeight) > 0 { return nil } - var chanID lnwire.ChannelID copy(chanID[:], k) - summary, err := getChanSummary(chanDetails) if err != nil { return err } - summaries[chanID] = *summary + info := &ChannelInfo{ + ClientChanSummary: *summary, + } + + maxHeightBytes := chanDetails.Get( + cChanMaxCommitmentHeight, + ) + if len(maxHeightBytes) != 0 { + height, err := readBigSize(maxHeightBytes) + if err != nil { + return err + } + + info.MaxHeight = fn.Some(height) + } + + infos[chanID] = info return nil }) }, func() { - summaries = make(map[lnwire.ChannelID]ClientChanSummary) + infos = make(ChannelInfos) }) if err != nil { return nil, err } - return summaries, nil + return infos, nil } // RegisterChannel registers a channel for use within the client database. For diff --git a/watchtower/wtdb/client_db_test.go b/watchtower/wtdb/client_db_test.go index 6d11a69728..5c12d9a56b 100644 --- a/watchtower/wtdb/client_db_test.go +++ b/watchtower/wtdb/client_db_test.go @@ -156,13 +156,13 @@ func (h *clientDBHarness) loadTowerByID(id wtdb.TowerID, return tower } -func (h *clientDBHarness) fetchChanSummaries() map[lnwire.ChannelID]wtdb.ClientChanSummary { +func (h *clientDBHarness) fetchChanInfos() wtdb.ChannelInfos { h.t.Helper() - summaries, err := h.db.FetchChanSummaries() + infos, err := h.db.FetchChanInfos() require.NoError(h.t, err) - return summaries + return infos } func (h *clientDBHarness) registerChan(chanID lnwire.ChannelID, @@ -552,7 +552,7 @@ func testRemoveTower(h *clientDBHarness) { func testChanSummaries(h *clientDBHarness) { // First, assert that this channel is not already registered. var chanID lnwire.ChannelID - _, ok := h.fetchChanSummaries()[chanID] + _, ok := h.fetchChanInfos()[chanID] require.Falsef(h.t, ok, "pkscript for channel %x should not exist yet", chanID) @@ -565,7 +565,7 @@ func testChanSummaries(h *clientDBHarness) { // Assert that the channel exists and that its sweep pkscript matches // the one we registered. - summary, ok := h.fetchChanSummaries()[chanID] + summary, ok := h.fetchChanInfos()[chanID] require.Truef(h.t, ok, "pkscript for channel %x should not exist yet", chanID) require.Equal(h.t, expPkScript, summary.SweepPkScript) From 087ab2ed82a76c05fab3c71290ac82bdc6ed233e Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Thu, 23 Nov 2023 15:49:16 +0200 Subject: [PATCH 3/3] wtdb/migration8: migrate channel max heights --- watchtower/wtdb/log.go | 2 + watchtower/wtdb/migration8/codec.go | 234 +++++++ watchtower/wtdb/migration8/log.go | 14 + watchtower/wtdb/migration8/migration.go | 224 +++++++ watchtower/wtdb/migration8/migration_test.go | 209 +++++++ watchtower/wtdb/migration8/range_index.go | 619 +++++++++++++++++++ watchtower/wtdb/version.go | 4 + 7 files changed, 1306 insertions(+) create mode 100644 watchtower/wtdb/migration8/codec.go create mode 100644 watchtower/wtdb/migration8/log.go create mode 100644 watchtower/wtdb/migration8/migration.go create mode 100644 watchtower/wtdb/migration8/migration_test.go create mode 100644 watchtower/wtdb/migration8/range_index.go diff --git a/watchtower/wtdb/log.go b/watchtower/wtdb/log.go index 639030631f..ed453383f9 100644 --- a/watchtower/wtdb/log.go +++ b/watchtower/wtdb/log.go @@ -10,6 +10,7 @@ import ( "github.com/lightningnetwork/lnd/watchtower/wtdb/migration5" "github.com/lightningnetwork/lnd/watchtower/wtdb/migration6" "github.com/lightningnetwork/lnd/watchtower/wtdb/migration7" + "github.com/lightningnetwork/lnd/watchtower/wtdb/migration8" ) // log is a logger that is initialized with no output filters. This @@ -40,6 +41,7 @@ func UseLogger(logger btclog.Logger) { migration5.UseLogger(logger) migration6.UseLogger(logger) migration7.UseLogger(logger) + migration8.UseLogger(logger) } // logClosure is used to provide a closure over expensive logging operations so diff --git a/watchtower/wtdb/migration8/codec.go b/watchtower/wtdb/migration8/codec.go new file mode 100644 index 0000000000..9c8dca1a36 --- /dev/null +++ b/watchtower/wtdb/migration8/codec.go @@ -0,0 +1,234 @@ +package migration8 + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + "fmt" + "io" + + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/tlv" +) + +// BreachHintSize is the length of the identifier used to detect remote +// commitment broadcasts. +const BreachHintSize = 16 + +// BreachHint is the first 16-bytes of SHA256(txid), which is used to identify +// the breach transaction. +type BreachHint [BreachHintSize]byte + +// ChannelID is a series of 32-bytes that uniquely identifies all channels +// within the network. The ChannelID is computed using the outpoint of the +// funding transaction (the txid, and output index). Given a funding output the +// ChannelID can be calculated by XOR'ing the big-endian serialization of the +// txid and the big-endian serialization of the output index, truncated to +// 2 bytes. +type ChannelID [32]byte + +// writeBigSize will encode the given uint64 as a BigSize byte slice. +func writeBigSize(i uint64) ([]byte, error) { + var b bytes.Buffer + err := tlv.WriteVarInt(&b, i, &[8]byte{}) + if err != nil { + return nil, err + } + + return b.Bytes(), nil +} + +// readBigSize converts the given byte slice into a uint64 and assumes that the +// bytes slice is using BigSize encoding. +func readBigSize(b []byte) (uint64, error) { + r := bytes.NewReader(b) + i, err := tlv.ReadVarInt(r, &[8]byte{}) + if err != nil { + return 0, err + } + + return i, nil +} + +// CommittedUpdate holds a state update sent by a client along with its +// allocated sequence number and the exact remote commitment the encrypted +// justice transaction can rectify. +type CommittedUpdate struct { + // SeqNum is the unique sequence number allocated by the session to this + // update. + SeqNum uint16 + + CommittedUpdateBody +} + +// BackupID identifies a particular revoked, remote commitment by channel id and +// commitment height. +type BackupID struct { + // ChanID is the channel id of the revoked commitment. + ChanID ChannelID + + // CommitHeight is the commitment height of the revoked commitment. + CommitHeight uint64 +} + +// Encode writes the BackupID from the passed io.Writer. +func (b *BackupID) Encode(w io.Writer) error { + return WriteElements(w, + b.ChanID, + b.CommitHeight, + ) +} + +// Decode reads a BackupID from the passed io.Reader. +func (b *BackupID) Decode(r io.Reader) error { + return ReadElements(r, + &b.ChanID, + &b.CommitHeight, + ) +} + +// String returns a human-readable encoding of a BackupID. +func (b BackupID) String() string { + return fmt.Sprintf("backup(%v, %d)", b.ChanID, b.CommitHeight) +} + +// WriteElements serializes a variadic list of elements into the given +// io.Writer. +func WriteElements(w io.Writer, elements ...interface{}) error { + for _, element := range elements { + if err := WriteElement(w, element); err != nil { + return err + } + } + + return nil +} + +// ReadElements deserializes the provided io.Reader into a variadic list of +// target elements. +func ReadElements(r io.Reader, elements ...interface{}) error { + for _, element := range elements { + if err := ReadElement(r, element); err != nil { + return err + } + } + + return nil +} + +// WriteElement serializes a single element into the provided io.Writer. +func WriteElement(w io.Writer, element interface{}) error { + switch e := element.(type) { + case ChannelID: + if _, err := w.Write(e[:]); err != nil { + return err + } + + case uint64: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case BreachHint: + if _, err := w.Write(e[:]); err != nil { + return err + } + + case []byte: + if err := wire.WriteVarBytes(w, 0, e); err != nil { + return err + } + + default: + return fmt.Errorf("unexpected type") + } + + return nil +} + +// ReadElement deserializes a single element from the provided io.Reader. +func ReadElement(r io.Reader, element interface{}) error { + switch e := element.(type) { + case *ChannelID: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + + case *uint64: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *BreachHint: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + + case *[]byte: + bytes, err := wire.ReadVarBytes(r, 0, 66000, "[]byte") + if err != nil { + return err + } + + *e = bytes + + default: + return fmt.Errorf("unexpected type") + } + + return nil +} + +// CommittedUpdateBody represents the primary components of a CommittedUpdate. +// On disk, this is stored under the sequence number, which acts as its key. +type CommittedUpdateBody struct { + // BackupID identifies the breached commitment that the encrypted blob + // can spend from. + BackupID BackupID + + // Hint is the 16-byte prefix of the revoked commitment transaction ID. + Hint BreachHint + + // EncryptedBlob is a ciphertext containing the sweep information for + // exacting justice if the commitment transaction matching the breach + // hint is broadcast. + EncryptedBlob []byte +} + +// Encode writes the CommittedUpdateBody to the passed io.Writer. +func (u *CommittedUpdateBody) Encode(w io.Writer) error { + err := u.BackupID.Encode(w) + if err != nil { + return err + } + + return WriteElements(w, + u.Hint, + u.EncryptedBlob, + ) +} + +// Decode reads a CommittedUpdateBody from the passed io.Reader. +func (u *CommittedUpdateBody) Decode(r io.Reader) error { + err := u.BackupID.Decode(r) + if err != nil { + return err + } + + return ReadElements(r, + &u.Hint, + &u.EncryptedBlob, + ) +} + +// SessionIDSize is 33-bytes; it is a serialized, compressed public key. +const SessionIDSize = 33 + +// SessionID is created from the remote public key of a client, and serves as a +// unique identifier and authentication for sending state updates. +type SessionID [SessionIDSize]byte + +// String returns a hex encoding of the session id. +func (s SessionID) String() string { + return hex.EncodeToString(s[:]) +} diff --git a/watchtower/wtdb/migration8/log.go b/watchtower/wtdb/migration8/log.go new file mode 100644 index 0000000000..ab35682c5a --- /dev/null +++ b/watchtower/wtdb/migration8/log.go @@ -0,0 +1,14 @@ +package migration8 + +import ( + "github.com/btcsuite/btclog" +) + +// log is a logger that is initialized as disabled. This means the package will +// not perform any logging by default until a logger is set. +var log = btclog.Disabled + +// UseLogger uses a specified Logger to output package logging info. +func UseLogger(logger btclog.Logger) { + log = logger +} diff --git a/watchtower/wtdb/migration8/migration.go b/watchtower/wtdb/migration8/migration.go new file mode 100644 index 0000000000..956f718574 --- /dev/null +++ b/watchtower/wtdb/migration8/migration.go @@ -0,0 +1,224 @@ +package migration8 + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + + "github.com/lightningnetwork/lnd/kvdb" +) + +var ( + // cSessionBkt is a top-level bucket storing: + // session-id => cSessionBody -> encoded ClientSessionBody + // => cSessionDBID -> db-assigned-id + // => cSessionCommits => seqnum -> encoded CommittedUpdate + // => cSessionAckRangeIndex => db-chan-id => start -> end + // => cSessionRogueUpdateCount -> count + cSessionBkt = []byte("client-session-bucket") + + // cChanIDIndexBkt is a top-level bucket storing: + // db-assigned-id -> channel-ID + cChanIDIndexBkt = []byte("client-channel-id-index") + + // cSessionAckRangeIndex is a sub-bucket of cSessionBkt storing + // chan-id => start -> end + cSessionAckRangeIndex = []byte("client-session-ack-range-index") + + // cSessionBody is a sub-bucket of cSessionBkt storing: + // seqnum -> encoded CommittedUpdate. + cSessionCommits = []byte("client-session-commits") + + // cChanDetailsBkt is a top-level bucket storing: + // channel-id => cChannelSummary -> encoded ClientChanSummary. + // => cChanDBID -> db-assigned-id + // => cChanSessions => db-session-id -> 1 + // => cChanClosedHeight -> block-height + // => cChanMaxCommitmentHeight -> commitment-height + cChanDetailsBkt = []byte("client-channel-detail-bucket") + + cChanMaxCommitmentHeight = []byte( + "client-channel-max-commitment-height", + ) + + // ErrUninitializedDB signals that top-level buckets for the database + // have not been initialized. + ErrUninitializedDB = errors.New("db not initialized") + + byteOrder = binary.BigEndian +) + +// MigrateChannelMaxHeights migrates the tower client db by collecting all the +// max commitment heights that have been backed up for each channel and then +// storing those heights alongside the channel info. +func MigrateChannelMaxHeights(tx kvdb.RwTx) error { + log.Infof("Migrating the tower client DB for quick channel max " + + "commitment height lookup") + + heights, err := collectChanMaxHeights(tx) + if err != nil { + return err + } + + return writeChanMaxHeights(tx, heights) +} + +// writeChanMaxHeights iterates over the given channel ID to height map and +// writes an entry under the cChanMaxCommitmentHeight key for each channel. +func writeChanMaxHeights(tx kvdb.RwTx, heights map[ChannelID]uint64) error { + chanDetailsBkt := tx.ReadWriteBucket(cChanDetailsBkt) + if chanDetailsBkt == nil { + return ErrUninitializedDB + } + + for chanID, maxHeight := range heights { + chanDetails := chanDetailsBkt.NestedReadWriteBucket(chanID[:]) + + // If the details bucket for this channel ID does not exist, + // it is probably a channel that has been closed and deleted + // already. So we can skip this height. + if chanDetails == nil { + continue + } + + b, err := writeBigSize(maxHeight) + if err != nil { + return err + } + + err = chanDetails.Put(cChanMaxCommitmentHeight, b) + if err != nil { + return err + } + } + + return nil +} + +// collectChanMaxHeights iterates over all the sessions in the DB. For each +// session, it iterates over all the Acked updates and the committed updates +// to collect the maximum commitment height for each channel. +func collectChanMaxHeights(tx kvdb.RwTx) (map[ChannelID]uint64, error) { + sessionsBkt := tx.ReadBucket(cSessionBkt) + if sessionsBkt == nil { + return nil, ErrUninitializedDB + } + + chanIDIndexBkt := tx.ReadBucket(cChanIDIndexBkt) + if chanIDIndexBkt == nil { + return nil, ErrUninitializedDB + } + + heights := make(map[ChannelID]uint64) + + // For each update we consider, we will only update the heights map if + // the commitment height for the channel is larger than the current + // max height stored for the channel. + cb := func(chanID ChannelID, commitHeight uint64) { + height, ok := heights[chanID] + if !ok || commitHeight > height { + heights[chanID] = commitHeight + } + } + + err := sessionsBkt.ForEach(func(sessIDBytes, _ []byte) error { + sessBkt := sessionsBkt.NestedReadBucket(sessIDBytes) + if sessBkt == nil { + return fmt.Errorf("bucket not found for session %x", + sessIDBytes) + } + + err := forEachCommittedUpdate(sessBkt, cb) + if err != nil { + return err + } + + return forEachAckedUpdate(sessBkt, chanIDIndexBkt, cb) + }) + if err != nil { + return nil, err + } + + return heights, nil +} + +// forEachCommittedUpdate iterates over all the given session's committed +// updates and calls the call-back for each. +func forEachCommittedUpdate(sessBkt kvdb.RBucket, + cb func(chanID ChannelID, commitHeight uint64)) error { + + sessionCommits := sessBkt.NestedReadBucket(cSessionCommits) + if sessionCommits == nil { + return nil + } + + return sessionCommits.ForEach(func(k, v []byte) error { + var update CommittedUpdate + err := update.Decode(bytes.NewReader(v)) + if err != nil { + return err + } + + cb(update.BackupID.ChanID, update.BackupID.CommitHeight) + + return nil + }) +} + +// forEachAckedUpdate iterates over all the given session's acked update range +// indices and calls the call-back for each. +func forEachAckedUpdate(sessBkt, chanIDIndexBkt kvdb.RBucket, + cb func(chanID ChannelID, commitHeight uint64)) error { + + sessionAcksRanges := sessBkt.NestedReadBucket(cSessionAckRangeIndex) + if sessionAcksRanges == nil { + return nil + } + + return sessionAcksRanges.ForEach(func(dbChanID, _ []byte) error { + rangeBkt := sessionAcksRanges.NestedReadBucket(dbChanID) + if rangeBkt == nil { + return nil + } + + index, err := readRangeIndex(rangeBkt) + if err != nil { + return err + } + + chanIDBytes := chanIDIndexBkt.Get(dbChanID) + var chanID ChannelID + copy(chanID[:], chanIDBytes) + + cb(chanID, index.MaxHeight()) + + return nil + }) +} + +// readRangeIndex reads a persisted RangeIndex from the passed bucket and into +// a new in-memory RangeIndex. +func readRangeIndex(rangesBkt kvdb.RBucket) (*RangeIndex, error) { + ranges := make(map[uint64]uint64) + err := rangesBkt.ForEach(func(k, v []byte) error { + start, err := readBigSize(k) + if err != nil { + return err + } + + end, err := readBigSize(v) + if err != nil { + return err + } + + ranges[start] = end + + return nil + }) + if err != nil { + return nil, err + } + + return NewRangeIndex(ranges, WithSerializeUint64Fn(writeBigSize)) +} diff --git a/watchtower/wtdb/migration8/migration_test.go b/watchtower/wtdb/migration8/migration_test.go new file mode 100644 index 0000000000..9a9a589147 --- /dev/null +++ b/watchtower/wtdb/migration8/migration_test.go @@ -0,0 +1,209 @@ +package migration8 + +import ( + "bytes" + "testing" + + "github.com/lightningnetwork/lnd/channeldb/migtest" + "github.com/lightningnetwork/lnd/kvdb" + "github.com/stretchr/testify/require" +) + +const ( + chan1ID = 10 + chan2ID = 20 + chan3ID = 30 + chan4ID = 40 + + chan1DBID = 111 + chan2DBID = 222 + chan3DBID = 333 +) + +var ( + // preDetails is the expected data of the channel details bucket before + // the migration. + preDetails = map[string]interface{}{ + channelIDString(chan1ID): map[string]interface{}{}, + channelIDString(chan2ID): map[string]interface{}{}, + channelIDString(chan3ID): map[string]interface{}{}, + } + + // channelIDIndex is the data in the channelID index that is used to + // find the mapping between the db-assigned channel ID and the real + // channel ID. + channelIDIndex = map[string]interface{}{ + uint64ToStr(chan1DBID): channelIDString(chan1ID), + uint64ToStr(chan2DBID): channelIDString(chan2ID), + uint64ToStr(chan3DBID): channelIDString(chan3ID), + } + + // postDetails is the expected data in the channel details bucket after + // the migration. + postDetails = map[string]interface{}{ + channelIDString(chan1ID): map[string]interface{}{ + string(cChanMaxCommitmentHeight): uint64ToStr(105), + }, + channelIDString(chan2ID): map[string]interface{}{ + string(cChanMaxCommitmentHeight): uint64ToStr(205), + }, + channelIDString(chan3ID): map[string]interface{}{ + string(cChanMaxCommitmentHeight): uint64ToStr(304), + }, + } +) + +// TestMigrateChannelToSessionIndex tests that the MigrateChannelToSessionIndex +// function correctly builds the new channel-to-sessionID index to the tower +// client DB. +func TestMigrateChannelToSessionIndex(t *testing.T) { + t.Parallel() + + update1 := &CommittedUpdate{ + SeqNum: 1, + CommittedUpdateBody: CommittedUpdateBody{ + BackupID: BackupID{ + ChanID: intToChannelID(chan1ID), + CommitHeight: 105, + }, + }, + } + var update1B bytes.Buffer + require.NoError(t, update1.Encode(&update1B)) + + update3 := &CommittedUpdate{ + SeqNum: 1, + CommittedUpdateBody: CommittedUpdateBody{ + BackupID: BackupID{ + ChanID: intToChannelID(chan3ID), + CommitHeight: 304, + }, + }, + } + var update3B bytes.Buffer + require.NoError(t, update3.Encode(&update3B)) + + update4 := &CommittedUpdate{ + SeqNum: 1, + CommittedUpdateBody: CommittedUpdateBody{ + BackupID: BackupID{ + ChanID: intToChannelID(chan4ID), + CommitHeight: 400, + }, + }, + } + var update4B bytes.Buffer + require.NoError(t, update4.Encode(&update4B)) + + // sessions is the expected data in the sessions bucket before and + // after the migration. + sessions := map[string]interface{}{ + sessionIDString("1"): map[string]interface{}{ + string(cSessionAckRangeIndex): map[string]interface{}{ + // This range index gives channel 1 a max height + // of 104. + uint64ToStr(chan1DBID): map[string]interface{}{ + uint64ToStr(100): uint64ToStr(101), + uint64ToStr(104): uint64ToStr(104), + }, + // This range index gives channel 2 a max height + // of 200. + uint64ToStr(chan2DBID): map[string]interface{}{ + uint64ToStr(200): uint64ToStr(200), + }, + }, + string(cSessionCommits): map[string]interface{}{ + // This committed update gives channel 1 a max + // height of 105 and so it overrides the heights + // from the range index. + uint64ToStr(1): update1B.String(), + }, + }, + sessionIDString("2"): map[string]interface{}{ + string(cSessionAckRangeIndex): map[string]interface{}{ + // This range index gives channel 2 a max height + // of 205. + uint64ToStr(chan2DBID): map[string]interface{}{ + uint64ToStr(201): uint64ToStr(205), + }, + }, + }, + sessionIDString("3"): map[string]interface{}{ + string(cSessionCommits): map[string]interface{}{ + // This committed update gives channel 3 a max + // height of 304. + uint64ToStr(1): update3B.String(), + }, + }, + // This session only contains heights for channel 4 which has + // been closed and so this should have no effect. + sessionIDString("4"): map[string]interface{}{ + string(cSessionAckRangeIndex): map[string]interface{}{ + uint64ToStr(444): map[string]interface{}{ + uint64ToStr(400): uint64ToStr(402), + uint64ToStr(403): uint64ToStr(405), + }, + }, + string(cSessionCommits): map[string]interface{}{ + uint64ToStr(1): update4B.String(), + }, + }, + } + + // Before the migration we have a channel details + // bucket, a sessions bucket, a session ID index bucket + // and a channel ID index bucket. + before := func(tx kvdb.RwTx) error { + err := migtest.RestoreDB(tx, cChanDetailsBkt, preDetails) + if err != nil { + return err + } + + err = migtest.RestoreDB(tx, cSessionBkt, sessions) + if err != nil { + return err + } + + return migtest.RestoreDB(tx, cChanIDIndexBkt, channelIDIndex) + } + + after := func(tx kvdb.RwTx) error { + err := migtest.VerifyDB(tx, cSessionBkt, sessions) + if err != nil { + return err + } + + return migtest.VerifyDB(tx, cChanDetailsBkt, postDetails) + } + + migtest.ApplyMigration( + t, before, after, MigrateChannelMaxHeights, false, + ) +} + +func sessionIDString(id string) string { + var sessID SessionID + copy(sessID[:], id) + return sessID.String() +} + +func channelIDString(id uint64) string { + var chanID ChannelID + byteOrder.PutUint64(chanID[:], id) + return string(chanID[:]) +} + +func uint64ToStr(id uint64) string { + b, err := writeBigSize(id) + if err != nil { + panic(err) + } + + return string(b) +} + +func intToChannelID(id uint64) ChannelID { + var chanID ChannelID + byteOrder.PutUint64(chanID[:], id) + return chanID +} diff --git a/watchtower/wtdb/migration8/range_index.go b/watchtower/wtdb/migration8/range_index.go new file mode 100644 index 0000000000..94f0e20300 --- /dev/null +++ b/watchtower/wtdb/migration8/range_index.go @@ -0,0 +1,619 @@ +package migration8 + +import ( + "fmt" + "sync" +) + +// rangeItem represents the start and end values of a range. +type rangeItem struct { + start uint64 + end uint64 +} + +// RangeIndexOption describes the signature of a functional option that can be +// used to modify the behaviour of a RangeIndex. +type RangeIndexOption func(*RangeIndex) + +// WithSerializeUint64Fn is a functional option that can be used to set the +// function to be used to do the serialization of a uint64 into a byte slice. +func WithSerializeUint64Fn(fn func(uint64) ([]byte, error)) RangeIndexOption { + return func(index *RangeIndex) { + index.serializeUint64 = fn + } +} + +// RangeIndex can be used to keep track of which numbers have been added to a +// set. It does so by keeping track of a sorted list of rangeItems. Each +// rangeItem has a start and end value of a range where all values in-between +// have been added to the set. It works well in situations where it is expected +// numbers in the set are not sparse. +type RangeIndex struct { + // set is a sorted list of rangeItem. + set []rangeItem + + // mu is used to ensure safe access to set. + mu sync.Mutex + + // serializeUint64 is the function that can be used to convert a uint64 + // to a byte slice. + serializeUint64 func(uint64) ([]byte, error) +} + +// NewRangeIndex constructs a new RangeIndex. An initial set of ranges may be +// passed to the function in the form of a map. +func NewRangeIndex(ranges map[uint64]uint64, + opts ...RangeIndexOption) (*RangeIndex, error) { + + index := &RangeIndex{ + serializeUint64: defaultSerializeUint64, + set: make([]rangeItem, 0), + } + + // Apply any functional options. + for _, o := range opts { + o(index) + } + + for s, e := range ranges { + if err := index.addRange(s, e); err != nil { + return nil, err + } + } + + return index, nil +} + +// addRange can be used to add an entire new range to the set. This method +// should only ever be called by NewRangeIndex to initialise the in-memory +// structure and so the RangeIndex mutex is not held during this method. +func (a *RangeIndex) addRange(start, end uint64) error { + // Check that the given range is valid. + if start > end { + return fmt.Errorf("invalid range. Start height %d is larger "+ + "than end height %d", start, end) + } + + // min is a helper closure that will return the minimum of two uint64s. + min := func(a, b uint64) uint64 { + if a < b { + return a + } + + return b + } + + // max is a helper closure that will return the maximum of two uint64s. + max := func(a, b uint64) uint64 { + if a > b { + return a + } + + return b + } + + // Collect the ranges that fall before and after the new range along + // with the start and end values of the new range. + var before, after []rangeItem + for _, x := range a.set { + // If the new start value can't extend the current ranges end + // value, then the two cannot be merged. The range is added to + // the group of ranges that fall before the new range. + if x.end+1 < start { + before = append(before, x) + continue + } + + // If the current ranges start value does not follow on directly + // from the new end value, then the two cannot be merged. The + // range is added to the group of ranges that fall after the new + // range. + if end+1 < x.start { + after = append(after, x) + continue + } + + // Otherwise, there is an overlap and so the two can be merged. + start = min(start, x.start) + end = max(end, x.end) + } + + // Re-construct the range index set. + a.set = append(append(before, rangeItem{ + start: start, + end: end, + }), after...) + + return nil +} + +// IsInIndex returns true if the given number is in the range set. +func (a *RangeIndex) IsInIndex(n uint64) bool { + a.mu.Lock() + defer a.mu.Unlock() + + _, isCovered := a.lowerBoundIndex(n) + + return isCovered +} + +// NumInSet returns the number of items covered by the range set. +func (a *RangeIndex) NumInSet() uint64 { + a.mu.Lock() + defer a.mu.Unlock() + + var numItems uint64 + for _, r := range a.set { + numItems += r.end - r.start + 1 + } + + return numItems +} + +// MaxHeight returns the highest number covered in the range. +func (a *RangeIndex) MaxHeight() uint64 { + a.mu.Lock() + defer a.mu.Unlock() + + if len(a.set) == 0 { + return 0 + } + + return a.set[len(a.set)-1].end +} + +// GetAllRanges returns a copy of the range set in the form of a map. +func (a *RangeIndex) GetAllRanges() map[uint64]uint64 { + a.mu.Lock() + defer a.mu.Unlock() + + cp := make(map[uint64]uint64, len(a.set)) + for _, item := range a.set { + cp[item.start] = item.end + } + + return cp +} + +// lowerBoundIndex returns the index of the RangeIndex that is most appropriate +// for the new value, n. In other words, it returns the index of the rangeItem +// set of the range where the start value is the highest start value in the set +// that is still lower than or equal to the given number, n. The returned +// boolean is true if the given number is already covered in the RangeIndex. +// A returned index of -1 indicates that no lower bound range exists in the set. +// Since the most likely case is that the new number will just extend the +// highest range, a check is first done to see if this is the case which will +// make the methods' computational complexity O(1). Otherwise, a binary search +// is done which brings the computational complexity to O(log N). +func (a *RangeIndex) lowerBoundIndex(n uint64) (int, bool) { + // If the set is empty, then there is no such index and the value + // definitely is not in the set. + if len(a.set) == 0 { + return -1, false + } + + // In most cases, the last index item will be the one we want. So just + // do a quick check on that index first to avoid doing the binary + // search. + lastIndex := len(a.set) - 1 + lastRange := a.set[lastIndex] + if lastRange.start <= n { + return lastIndex, lastRange.end >= n + } + + // Otherwise, do a binary search to find the index of interest. + var ( + low = 0 + high = len(a.set) - 1 + rangeIndex = -1 + ) + for { + mid := (low + high) / 2 + currentRange := a.set[mid] + + switch { + case currentRange.start > n: + // If the start of the range is greater than n, we can + // completely cut out that entire part of the array. + high = mid + + case currentRange.start < n: + // If the range already includes the given height, we + // can stop searching now. + if currentRange.end >= n { + return mid, true + } + + // If the start of the range is smaller than n, we can + // store this as the new best index to return. + rangeIndex = mid + + // If low and mid are already equal, then increment low + // by 1. Exit if this means that low is now greater than + // high. + if low == mid { + low = mid + 1 + if low > high { + return rangeIndex, false + } + } else { + low = mid + } + + continue + + default: + // If the height is equal to the start value of the + // current range that mid is pointing to, then the + // height is already covered. + return mid, true + } + + // Exit if we have checked all the ranges. + if low == high { + break + } + } + + return rangeIndex, false +} + +// KVStore is an interface representing a key-value store. +type KVStore interface { + // Put saves the specified key/value pair to the store. Keys that do not + // already exist are added and keys that already exist are overwritten. + Put(key, value []byte) error + + // Delete removes the specified key from the bucket. Deleting a key that + // does not exist does not return an error. + Delete(key []byte) error +} + +// Add adds a single number to the range set. It first attempts to apply the +// necessary changes to the passed KV store and then only if this succeeds, will +// the changes be applied to the in-memory structure. +func (a *RangeIndex) Add(newHeight uint64, kv KVStore) error { + a.mu.Lock() + defer a.mu.Unlock() + + // Compute the changes that will need to be applied to both the sorted + // rangeItem array representation and the key-value store representation + // of the range index. + arrayChanges, kvStoreChanges := a.getChanges(newHeight) + + // First attempt to apply the KV store changes. Only if this succeeds + // will we apply the changes to our in-memory range index structure. + err := a.applyKVChanges(kv, kvStoreChanges) + if err != nil { + return err + } + + // Since the DB changes were successful, we can now commit the + // changes to our in-memory representation of the range set. + a.applyArrayChanges(arrayChanges) + + return nil +} + +// applyKVChanges applies the given set of kvChanges to a KV store. It is +// assumed that a transaction is being held on the kv store so that if any +// of the actions of the function fails, the changes will be reverted. +func (a *RangeIndex) applyKVChanges(kv KVStore, changes *kvChanges) error { + // Exit early if there are no changes to apply. + if kv == nil || changes == nil { + return nil + } + + // Check if any range pair needs to be deleted. + if changes.deleteKVKey != nil { + del, err := a.serializeUint64(*changes.deleteKVKey) + if err != nil { + return err + } + + if err := kv.Delete(del); err != nil { + return err + } + } + + start, err := a.serializeUint64(changes.key) + if err != nil { + return err + } + + end, err := a.serializeUint64(changes.value) + if err != nil { + return err + } + + return kv.Put(start, end) +} + +// applyArrayChanges applies the given arrayChanges to the in-memory RangeIndex +// itself. This should only be done once the persisted kv store changes have +// already been applied. +func (a *RangeIndex) applyArrayChanges(changes *arrayChanges) { + if changes == nil { + return + } + + if changes.indexToDelete != nil { + a.set = append( + a.set[:*changes.indexToDelete], + a.set[*changes.indexToDelete+1:]..., + ) + } + + if changes.newIndex != nil { + switch { + case *changes.newIndex == 0: + a.set = append([]rangeItem{{ + start: changes.start, + end: changes.end, + }}, a.set...) + + case *changes.newIndex == len(a.set): + a.set = append(a.set, rangeItem{ + start: changes.start, + end: changes.end, + }) + + default: + a.set = append( + a.set[:*changes.newIndex+1], + a.set[*changes.newIndex:]..., + ) + a.set[*changes.newIndex] = rangeItem{ + start: changes.start, + end: changes.end, + } + } + + return + } + + if changes.indexToEdit != nil { + a.set[*changes.indexToEdit] = rangeItem{ + start: changes.start, + end: changes.end, + } + } +} + +// arrayChanges encompasses the diff to apply to the sorted rangeItem array +// representation of a range index. Such a diff will either include adding a +// new range or editing an existing range. If an existing range is edited, then +// the diff might also include deleting an index (this will be the case if the +// editing of the one range results in the merge of another range). +type arrayChanges struct { + start uint64 + end uint64 + + // newIndex, if set, is the index of the in-memory range array where a + // new range, [start:end], should be added. newIndex should never be + // set at the same time as indexToEdit or indexToDelete. + newIndex *int + + // indexToDelete, if set, is the index of the sorted rangeItem array + // that should be deleted. This should be applied before reading the + // index value of indexToEdit. This should not be set at the same time + // as newIndex. + indexToDelete *int + + // indexToEdit is the index of the in-memory range array that should be + // edited. The range at this index will be changed to [start:end]. This + // should only be read after indexToDelete index has been deleted. + indexToEdit *int +} + +// kvChanges encompasses the diff to apply to a KV-store representation of a +// range index. A kv-store diff for the addition of a single number to the range +// index will include either a brand new key-value pair or the altering of the +// value of an existing key. Optionally, the diff may also include the deletion +// of an existing key. A deletion will be required if the addition of the new +// number results in the merge of two ranges. +type kvChanges struct { + key uint64 + value uint64 + + // deleteKVKey, if set, is the key of the kv store representation that + // should be deleted. + deleteKVKey *uint64 +} + +// getChanges will calculate and return the changes that need to be applied to +// both the sorted-rangeItem-array representation and the key-value store +// representation of the range index. +func (a *RangeIndex) getChanges(n uint64) (*arrayChanges, *kvChanges) { + // If the set is empty then a new range item is added. + if len(a.set) == 0 { + // For the array representation, a new range [n:n] is added to + // the first index of the array. + firstIndex := 0 + ac := &arrayChanges{ + newIndex: &firstIndex, + start: n, + end: n, + } + + // For the KV representation, a new [n:n] pair is added. + kvc := &kvChanges{ + key: n, + value: n, + } + + return ac, kvc + } + + // Find the index of the lower bound range to the new number. + indexOfRangeBelow, alreadyCovered := a.lowerBoundIndex(n) + + switch { + // The new number is already covered by the range index. No changes are + // required. + case alreadyCovered: + return nil, nil + + // No lower bound index exists. + case indexOfRangeBelow < 0: + // Check if the very first range can be merged into this new + // one. + if n+1 == a.set[0].start { + // If so, the two ranges can be merged and so the start + // value of the range is n and the end value is the end + // of the existing first range. + start := n + end := a.set[0].end + + // For the array representation, we can just edit the + // first entry of the array + editIndex := 0 + ac := &arrayChanges{ + indexToEdit: &editIndex, + start: start, + end: end, + } + + // For the KV store representation, we add a new kv pair + // and delete the range with the key equal to the start + // value of the range we are merging. + kvKeyToDelete := a.set[0].start + kvc := &kvChanges{ + key: start, + value: end, + deleteKVKey: &kvKeyToDelete, + } + + return ac, kvc + } + + // Otherwise, we add a new index. + + // For the array representation, a new range [n:n] is added to + // the first index of the array. + newIndex := 0 + ac := &arrayChanges{ + newIndex: &newIndex, + start: n, + end: n, + } + + // For the KV representation, a new [n:n] pair is added. + kvc := &kvChanges{ + key: n, + value: n, + } + + return ac, kvc + + // A lower range does exist, and it can be extended to include this new + // number. + case a.set[indexOfRangeBelow].end+1 == n: + start := a.set[indexOfRangeBelow].start + end := n + indexToChange := indexOfRangeBelow + + // If there are no intervals above this one or if there are, but + // they can't be merged into this one then we just need to edit + // this interval. + if indexOfRangeBelow == len(a.set)-1 || + a.set[indexOfRangeBelow+1].start != n+1 { + + // For the array representation, we just edit the index. + ac := &arrayChanges{ + indexToEdit: &indexToChange, + start: start, + end: end, + } + + // For the key-value representation, we just overwrite + // the end value at the existing start key. + kvc := &kvChanges{ + key: start, + value: end, + } + + return ac, kvc + } + + // There is a range above this one that we need to merge into + // this one. + delIndex := indexOfRangeBelow + 1 + end = a.set[delIndex].end + + // For the array representation, we delete the range above this + // one and edit this range to include the end value of the range + // above. + ac := &arrayChanges{ + indexToDelete: &delIndex, + indexToEdit: &indexToChange, + start: start, + end: end, + } + + // For the kv representation, we tweak the end value of an + // existing key and delete the key of the range we are deleting. + deleteKey := a.set[delIndex].start + kvc := &kvChanges{ + key: start, + value: end, + deleteKVKey: &deleteKey, + } + + return ac, kvc + + // A lower range does exist, but it can't be extended to include this + // new number, and so we need to add a new range after the lower bound + // range. + default: + newIndex := indexOfRangeBelow + 1 + + // If there are no ranges above this new one or if there are, + // but they can't be merged into this new one, then we can just + // add the new one as is. + if newIndex == len(a.set) || a.set[newIndex].start != n+1 { + ac := &arrayChanges{ + newIndex: &newIndex, + start: n, + end: n, + } + + kvc := &kvChanges{ + key: n, + value: n, + } + + return ac, kvc + } + + // Else, we merge the above index. + start := n + end := a.set[newIndex].end + toEdit := newIndex + + // For the array representation, we edit the range above to + // include the new start value. + ac := &arrayChanges{ + indexToEdit: &toEdit, + start: start, + end: end, + } + + // For the kv representation, we insert the new start-end key + // value pair and delete the key using the old start value. + delKey := a.set[newIndex].start + kvc := &kvChanges{ + key: start, + value: end, + deleteKVKey: &delKey, + } + + return ac, kvc + } +} + +func defaultSerializeUint64(i uint64) ([]byte, error) { + var b [8]byte + byteOrder.PutUint64(b[:], i) + return b[:], nil +} diff --git a/watchtower/wtdb/version.go b/watchtower/wtdb/version.go index b44ed80eb3..dd9c554723 100644 --- a/watchtower/wtdb/version.go +++ b/watchtower/wtdb/version.go @@ -12,6 +12,7 @@ import ( "github.com/lightningnetwork/lnd/watchtower/wtdb/migration5" "github.com/lightningnetwork/lnd/watchtower/wtdb/migration6" "github.com/lightningnetwork/lnd/watchtower/wtdb/migration7" + "github.com/lightningnetwork/lnd/watchtower/wtdb/migration8" ) // txMigration is a function which takes a prior outdated version of the @@ -67,6 +68,9 @@ var clientDBVersions = []version{ { txMigration: migration7.MigrateChannelToSessionIndex, }, + { + txMigration: migration8.MigrateChannelMaxHeights, + }, } // getLatestDBVersion returns the last known database version.