diff --git a/lntypes/channel_party.go b/lntypes/channel_party.go index 5848becee6..82cbd1045e 100644 --- a/lntypes/channel_party.go +++ b/lntypes/channel_party.go @@ -117,3 +117,5 @@ func MapDual[A, B any](d Dual[A], f func(A) B) Dual[B] { Remote: f(d.Remote), } } + +var BothParties []ChannelParty = []ChannelParty{Local, Remote} diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 6b6f444313..8af684c794 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -2890,16 +2890,14 @@ func fundingTxIn(chanState *channeldb.OpenChannel) wire.TxIn { // returned reflects the current state of HTLCs within the remote or local // commitment chain, and the current commitment fee rate. // -// If mutateState is set to true, then the add height of all added HTLCs -// will be set to nextHeight, and the remove height of all removed HTLCs -// will be set to nextHeight. This should therefore only be set to true -// once for each height, and only in concert with signing a new commitment. -// TODO(halseth): return htlcs to mutate instead of mutating inside -// method. +// The return values of this function are as follows: +// 1. The new htlcView reflecting the current channel state. +// 2. A Dual of the updates which have not yet been committed in +// 'whoseCommitChain's commitment chain. func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, theirBalance *lnwire.MilliSatoshi, nextHeight uint64, - whoseCommitChain lntypes.ChannelParty, mutateState bool) (*HtlcView, - error) { + whoseCommitChain lntypes.ChannelParty) (*HtlcView, + lntypes.Dual[[]*paymentDescriptor], error) { // We initialize the view's fee rate to the fee rate of the unfiltered // view. If any fee updates are found when evaluating the view, it will @@ -2917,8 +2915,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, skipThem := make(map[uint64]struct{}) // First we run through non-add entries in both logs, populating the - // skip sets and mutating the current chain state (crediting balances, - // etc) to reflect the settle/timeout entry encountered. + // skip sets. for _, entry := range view.OurUpdates { switch entry.EntryType { // Skip adds for now. They will be processed below. @@ -2938,53 +2935,31 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, newView.FeePerKw = chainfee.SatPerKWeight( entry.Amount.ToSatoshis(), ) - - if mutateState { - entry.addCommitHeights.SetForParty( - whoseCommitChain, nextHeight, - ) - - entry.removeCommitHeights.SetForParty( - whoseCommitChain, nextHeight, - ) - } } - continue - } - - // If we're settling an inbound HTLC, and it hasn't been - // processed yet, then increment our state tracking the total - // number of satoshis we've received within the channel. - if mutateState && entry.EntryType == Settle && - whoseCommitChain.IsLocal() && - entry.removeCommitHeights.Local == 0 { - lc.channelState.TotalMSatReceived += entry.Amount + continue } addEntry, err := lc.fetchParent( entry, whoseCommitChain, lntypes.Remote, ) if err != nil { - return nil, err + return nil, lntypes.Dual[[]*paymentDescriptor]{}, err } skipThem[addEntry.HtlcIndex] = struct{}{} - rmvHeights := &entry.removeCommitHeights - rmvHeight := rmvHeights.GetForParty(whoseCommitChain) + rmvHeight := entry.removeCommitHeights.GetForParty( + whoseCommitChain, + ) if rmvHeight == 0 { processRemoveEntry( entry, ourBalance, theirBalance, true, ) - - if mutateState { - rmvHeights.SetForParty( - whoseCommitChain, nextHeight, - ) - } } } + + // Do the same for our peer's updates. for _, entry := range view.TheirUpdates { switch entry.EntryType { // Skip adds for now. They will be processed below. @@ -3004,53 +2979,27 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, newView.FeePerKw = chainfee.SatPerKWeight( entry.Amount.ToSatoshis(), ) - - - if mutateState { - entry.addCommitHeights.SetForParty( - whoseCommitChain, nextHeight, - ) - - entry.removeCommitHeights.SetForParty( - whoseCommitChain, nextHeight, - ) - } } - continue - } - // If the remote party is settling one of our outbound HTLC's, - // and it hasn't been processed, yet, the increment our state - // tracking the total number of satoshis we've sent within the - // channel. - if mutateState && entry.EntryType == Settle && - whoseCommitChain.IsLocal() && - entry.removeCommitHeights.Local == 0 { - - lc.channelState.TotalMSatSent += entry.Amount + continue } addEntry, err := lc.fetchParent( entry, whoseCommitChain, lntypes.Local, ) if err != nil { - return nil, err + return nil, lntypes.Dual[[]*paymentDescriptor]{}, err } skipUs[addEntry.HtlcIndex] = struct{}{} - rmvHeights := &entry.removeCommitHeights - rmvHeight := rmvHeights.GetForParty(whoseCommitChain) + rmvHeight := entry.removeCommitHeights.GetForParty( + whoseCommitChain, + ) if rmvHeight == 0 { processRemoveEntry( entry, ourBalance, theirBalance, false, ) - - if mutateState { - rmvHeights.SetForParty( - whoseCommitChain, nextHeight, - ) - } } } @@ -3065,25 +3014,19 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, // Skip the entries that have already had their add commit // height set for this commit chain. - addHeights := &entry.addCommitHeights - addHeight := addHeights.GetForParty(whoseCommitChain) + addHeight := entry.addCommitHeights.GetForParty( + whoseCommitChain, + ) if addHeight == 0 { processAddEntry( entry, ourBalance, theirBalance, false, ) - - // If we are mutating the state, then set the add - // height for the appropriate commitment chain to the - // next height. - if mutateState { - addHeights.SetForParty( - whoseCommitChain, nextHeight, - ) - } } newView.OurUpdates = append(newView.OurUpdates, entry) } + + // Again, we do the same for our peer's updates. for _, entry := range view.TheirUpdates { isAdd := entry.EntryType == Add if _, ok := skipThem[entry.HtlcIndex]; !isAdd || ok { @@ -3092,27 +3035,51 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, // Skip the entries that have already had their add commit // height set for this commit chain. - addHeights := &entry.addCommitHeights - addHeight := addHeights.GetForParty(whoseCommitChain) + addHeight := entry.addCommitHeights.GetForParty( + whoseCommitChain, + ) if addHeight == 0 { processAddEntry( entry, ourBalance, theirBalance, true, ) - - // If we are mutating the state, then set the add - // height for the appropriate commitment chain to the - // next height. - if mutateState { - addHeights.SetForParty( - whoseCommitChain, nextHeight, - ) - } } newView.TheirUpdates = append(newView.TheirUpdates, entry) } - return newView, nil + // Create a function that is capable of identifying whether or not the + // paymentDescriptor has been committed in the commitment chain + // corresponding to whoseCommitmentChain. + isUncommitted := func(update *paymentDescriptor) bool { + switch update.EntryType { + case Add: + return update.addCommitHeights.GetForParty( + whoseCommitChain, + ) == 0 + + case FeeUpdate: + return update.addCommitHeights.GetForParty( + whoseCommitChain, + ) == 0 + + case Settle, Fail, MalformedFail: + return update.removeCommitHeights.GetForParty( + whoseCommitChain, + ) == 0 + + default: + panic("invalid paymentDescriptor EntryType") + } + } + + // Collect all of the updates that haven't had their commit heights set + // for the commitment chain corresponding to whoseCommitmentChain. + uncommittedUpdates := lntypes.Dual[[]*paymentDescriptor]{ + Local: fn.Filter(isUncommitted, view.OurUpdates), + Remote: fn.Filter(isUncommitted, view.TheirUpdates), + } + + return newView, uncommittedUpdates, nil } // fetchParent is a helper that looks up update log parent entries in the @@ -4683,13 +4650,27 @@ func (lc *LightningChannel) computeView(view *HtlcView, // channel constraints to the final commitment state. If any fee // updates are found in the logs, the commitment fee rate should be // changed, so we'll also set the feePerKw to this new value. - filteredHTLCView, err := lc.evaluateHTLCView( + filteredHTLCView, uncommitted, err := lc.evaluateHTLCView( view, &ourBalance, &theirBalance, nextHeight, whoseCommitChain, - updateState, ) if err != nil { return 0, 0, 0, nil, err } + + if updateState { + for _, party := range lntypes.BothParties { + for _, u := range uncommitted.GetForParty(party) { + u.setCommitHeight(whoseCommitChain, nextHeight) + + if whoseCommitChain == lntypes.Local && + u.EntryType == Settle { + + lc.recordSettlement(party, u.Amount) + } + } + } + } + feePerKw := filteredHTLCView.FeePerKw // Here we override the view's fee-rate if a dry-run fee-rate was @@ -4742,6 +4723,18 @@ func (lc *LightningChannel) computeView(view *HtlcView, return ourBalance, theirBalance, totalCommitWeight, filteredHTLCView, nil } +// recordSettlement updates the lifetime payment flow values in persistent state +// of the LightningChannel, adding amt to the total received by the redeemer. +func (lc *LightningChannel) recordSettlement( + redeemer lntypes.ChannelParty, amt lnwire.MilliSatoshi) { + + if redeemer == lntypes.Local { + lc.channelState.TotalMSatReceived += amt + } else { + lc.channelState.TotalMSatSent += amt + } +} + // genHtlcSigValidationJobs generates a series of signatures verification jobs // meant to verify all the signatures for HTLC's attached to a newly created // commitment state. The jobs generated are fully populated, and can be sent diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 50eb24663e..0372e4d166 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -8956,14 +8956,40 @@ func TestEvaluateView(t *testing.T) { ) // Evaluate the htlc view, mutate as test expects. - result, err := lc.evaluateHTLCView( + result, uncommitted, err := lc.evaluateHTLCView( view, &ourBalance, &theirBalance, nextHeight, - test.whoseCommitChain, test.mutateState, + test.whoseCommitChain, ) + if err != nil { t.Fatalf("unexpected error: %v", err) } + // TODO(proofofkeags): This block is here because we + // extracted this code from a previous implementation + // of evaluateHTLCView, due to a reduced scope of + // responsibility of that function. Consider removing + // it from the test altogether. + if test.mutateState { + for _, party := range lntypes.BothParties { + us := uncommitted.GetForParty(party) + for _, u := range us { + u.setCommitHeight( + test.whoseCommitChain, + nextHeight, + ) + if test.whoseCommitChain == + lntypes.Local && + u.EntryType == Settle { + + lc.recordSettlement( + party, u.Amount, + ) + } + } + } + } + if result.FeePerKw != test.expectedFee { t.Fatalf("expected fee: %v, got: %v", test.expectedFee, result.FeePerKw) diff --git a/lnwallet/payment_descriptor.go b/lnwallet/payment_descriptor.go index ffa4cc8ce1..a8edb1e7e6 100644 --- a/lnwallet/payment_descriptor.go +++ b/lnwallet/payment_descriptor.go @@ -283,3 +283,31 @@ func (pd *paymentDescriptor) toLogUpdate() channeldb.LogUpdate { UpdateMsg: msg, } } + +// setCommitHeight updates the appropriate addCommitHeight and/or +// removeCommitHeight for whoseCommitChain and locks it in at nextHeight. +func (pd *paymentDescriptor) setCommitHeight( + whoseCommitChain lntypes.ChannelParty, nextHeight uint64) { + + switch pd.EntryType { + case Add: + pd.addCommitHeights.SetForParty( + whoseCommitChain, nextHeight, + ) + case Settle, Fail, MalformedFail: + pd.removeCommitHeights.SetForParty( + whoseCommitChain, nextHeight, + ) + case FeeUpdate: + // Fee updates are applied for all commitments + // after they are sent/received, so we consider + // them being added and removed at the same + // height. + pd.addCommitHeights.SetForParty( + whoseCommitChain, nextHeight, + ) + pd.removeCommitHeights.SetForParty( + whoseCommitChain, nextHeight, + ) + } +}