From f40ef193e906360a6a5e1eb89b4f78107cf5632f Mon Sep 17 00:00:00 2001 From: Andras Banki-Horvath Date: Thu, 30 May 2024 22:38:56 +0200 Subject: [PATCH] loop: add migration for loopout swaps to fix negative stored costs Previously we may have stored negative costs for some loop out swaps which this commit attempts to correct by fetching all completed swap, calculating the correct costs and overriding them in the database. --- cost_migration.go | 177 +++++++++++++++++++++++++++++++++++++++ cost_migration_test.go | 184 +++++++++++++++++++++++++++++++++++++++++ loopdb/store_mock.go | 15 +++- server_mock_test.go | 7 +- 4 files changed, 380 insertions(+), 3 deletions(-) create mode 100644 cost_migration.go create mode 100644 cost_migration_test.go diff --git a/cost_migration.go b/cost_migration.go new file mode 100644 index 000000000..247f63fad --- /dev/null +++ b/cost_migration.go @@ -0,0 +1,177 @@ +package loop + +import ( + "context" + "fmt" + "time" + + "github.com/btcsuite/btcd/chaincfg" + "github.com/lightninglabs/lndclient" + "github.com/lightninglabs/loop/loopdb" + "github.com/lightninglabs/loop/swap" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwire" +) + +const ( + costMigrationID = "cost_migration" +) + +// CalculateLoopOutCost calculates the total cost of a loop out swap. It will +// correctly account for the on-chain and off-chain fees that were paid and +// make sure that all costs are positive. +func CalculateLoopOutCost(params *chaincfg.Params, loopOutSwap *loopdb.LoopOut, + paymentFees map[lntypes.Hash]lnwire.MilliSatoshi) (loopdb.SwapCost, + error) { + + // First make sure that this swap is actually finished. + if loopOutSwap.State().State.IsPending() { + return loopdb.SwapCost{}, fmt.Errorf("swap is not yet finished") + } + + // We first need to decode the prepay invoice to get the prepay hash and + // the prepay amount. + _, _, hash, prepayAmount, err := swap.DecodeInvoice( + params, loopOutSwap.Contract.PrepayInvoice, + ) + if err != nil { + return loopdb.SwapCost{}, fmt.Errorf("unable to decode the "+ + "prepay invoice: %v", err) + } + + // The swap hash is given and we don't need to get it from the + // swap invoice, however we'll decode it anyway to get the invoice amount + // that was paid in case we don't have the payment anymore. + _, _, swapHash, swapPaymentAmount, err := swap.DecodeInvoice( + params, loopOutSwap.Contract.SwapInvoice, + ) + if err != nil { + return loopdb.SwapCost{}, fmt.Errorf("unable to decode the "+ + "swap invoice: %v", err) + } + + var ( + cost loopdb.SwapCost + swapPaid, prepayPaid bool + ) + + // Now that we have the prepay and swap amount, we can calculate the + // total cost of the swap. Note that we only need to account for the + // server cost in case the swap was successful or if the sweep timed + // out. Otherwise the server didn't pull the off-chain htlc nor the + // prepay. + switch loopOutSwap.State().State { + case loopdb.StateSuccess: + cost.Server = swapPaymentAmount + prepayAmount - + loopOutSwap.Contract.AmountRequested + + swapPaid = true + prepayPaid = true + + case loopdb.StateFailSweepTimeout: + cost.Server = prepayAmount + + prepayPaid = true + + default: + cost.Server = 0 + } + + // Now attempt to look up the actual payments so we can calculate the + // total routing costs. + prepayPaymentFee, ok := paymentFees[hash] + if prepayPaid && ok { + cost.Offchain += prepayPaymentFee.ToSatoshis() + } else { + log.Debugf("Prepay payment %s is missing, won't account for "+ + "routing fees", hash) + } + + swapPaymentFee, ok := paymentFees[swapHash] + if swapPaid && ok { + cost.Offchain += swapPaymentFee.ToSatoshis() + } else { + log.Debugf("Swap payment %s is missing, won't account for "+ + "routing fees", swapHash) + } + + // For the on-chain cost, just make sure that the cost is positive. + cost.Onchain = loopOutSwap.State().Cost.Onchain + if cost.Onchain < 0 { + cost.Onchain *= -1 + } + + return cost, nil +} + +// MigrateLoopOutCosts will calculate the correct cost for all loop out swaps +// and override the cost values of the last update in the database. +func MigrateLoopOutCosts(ctx context.Context, lnd lndclient.LndServices, + db loopdb.SwapStore) error { + + migrationDone, err := db.HasMigration(ctx, costMigrationID) + if err != nil { + return err + } + if migrationDone { + log.Infof("Cost cleanup migration already done, skipping") + + return nil + } + + log.Infof("Starting cost cleanup migration") + startTs := time.Now() + defer func() { + log.Infof("Finished cost cleanup migration in %v", + time.Since(startTs)) + }() + + // First we'll fetch all loop out swaps from the database. + loopOutSwaps, err := db.FetchLoopOutSwaps(ctx) + if err != nil { + return err + } + + // Next we fetch all payments from LND. + payments, err := lnd.Client.ListPayments( + ctx, lndclient.ListPaymentsRequest{}, + ) + if err != nil { + return err + } + + // Gather payment fees to a map for easier lookup. + paymentFees := make(map[lntypes.Hash]lnwire.MilliSatoshi) + for _, payment := range payments.Payments { + paymentFees[payment.Hash] = payment.Fee + } + + // Now we'll calculate the cost for each swap and finally update the + // costs in the database. + updatedCosts := make(map[lntypes.Hash]loopdb.SwapCost) + for _, loopOutSwap := range loopOutSwaps { + cost, err := CalculateLoopOutCost( + lnd.ChainParams, loopOutSwap, paymentFees, + ) + if err != nil { + return err + } + + _, ok := updatedCosts[loopOutSwap.Hash] + if ok { + return fmt.Errorf("found a duplicate swap %v while "+ + "updating costs", loopOutSwap.Hash) + } + + updatedCosts[loopOutSwap.Hash] = cost + } + + log.Infof("Updating costs for %d loop out swaps", len(updatedCosts)) + err = db.BatchUpdateLoopOutSwapCosts(ctx, updatedCosts) + if err != nil { + return err + } + + // Finally mark the migration as done. + return db.SetMigration(ctx, costMigrationID) +} diff --git a/cost_migration_test.go b/cost_migration_test.go new file mode 100644 index 000000000..ee5525da5 --- /dev/null +++ b/cost_migration_test.go @@ -0,0 +1,184 @@ +package loop + +import ( + "context" + "testing" + "time" + + "github.com/btcsuite/btcd/btcutil" + "github.com/lightninglabs/lndclient" + "github.com/lightninglabs/loop/loopdb" + "github.com/lightninglabs/loop/test" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/stretchr/testify/require" +) + +// TestCalculateLoopOutCost tests the CalculateLoopOutCost function. +func TestCalculateLoopOutCost(t *testing.T) { + // Set up test context objects. + lnd := test.NewMockLnd() + server := newServerMock(lnd) + store := loopdb.NewStoreMock(t) + + cfg := &swapConfig{ + lnd: &lnd.LndServices, + store: store, + server: server, + } + + height := int32(600) + req := *testRequest + initResult, err := newLoopOutSwap( + context.Background(), cfg, height, &req, + ) + require.NoError(t, err) + swap, err := store.FetchLoopOutSwap( + context.Background(), initResult.swap.hash, + ) + require.NoError(t, err) + + // Override the chain cost so it's negative. + const expectedChainCost = btcutil.Amount(1000) + + // Now we have the swap and prepay invoices so let's calculate the + // costs without providing the payments first, so we don't account for + // any routing fees. + paymentFees := make(map[lntypes.Hash]lnwire.MilliSatoshi) + _, err = CalculateLoopOutCost(lnd.ChainParams, swap, paymentFees) + + // We expect that the call fails as the swap isn't finished yet. + require.Error(t, err) + + // Override the swap state to make it look like the swap is finished + // and make the chain cost negative too, so we can test that it'll be + // corrected to be positive in the cost calculation. + swap.Events = append( + swap.Events, &loopdb.LoopEvent{ + SwapStateData: loopdb.SwapStateData{ + State: loopdb.StateSuccess, + Cost: loopdb.SwapCost{ + Onchain: -expectedChainCost, + }, + }, + }, + ) + costs, err := CalculateLoopOutCost(lnd.ChainParams, swap, paymentFees) + require.NoError(t, err) + + expectedServerCost := server.swapInvoiceAmt + server.prepayInvoiceAmt - + swap.Contract.AmountRequested + require.Equal(t, expectedServerCost, costs.Server) + require.Equal(t, btcutil.Amount(0), costs.Offchain) + require.Equal(t, expectedChainCost, costs.Onchain) + + // Now add the two payments to the payments map and calculate the costs + // again. We expect that the routng fees are now accounted for. + paymentFees[server.swapHash] = lnwire.NewMSatFromSatoshis(44) + paymentFees[server.prepayHash] = lnwire.NewMSatFromSatoshis(11) + + costs, err = CalculateLoopOutCost(lnd.ChainParams, swap, paymentFees) + require.NoError(t, err) + + expectedOffchainCost := btcutil.Amount(44 + 11) + require.Equal(t, expectedServerCost, costs.Server) + require.Equal(t, expectedOffchainCost, costs.Offchain) + require.Equal(t, expectedChainCost, costs.Onchain) + + // Now override the last update to make the swap timed out at the HTLC + // sweep. We expect that the chain cost won't change, and only the + // prepay will be accounted for. + swap.Events[0] = &loopdb.LoopEvent{ + SwapStateData: loopdb.SwapStateData{ + State: loopdb.StateFailSweepTimeout, + Cost: loopdb.SwapCost{ + Onchain: 0, + }, + }, + } + + costs, err = CalculateLoopOutCost(lnd.ChainParams, swap, paymentFees) + require.NoError(t, err) + + expectedServerCost = server.prepayInvoiceAmt + expectedOffchainCost = btcutil.Amount(11) + require.Equal(t, expectedServerCost, costs.Server) + require.Equal(t, expectedOffchainCost, costs.Offchain) + require.Equal(t, btcutil.Amount(0), costs.Onchain) +} + +// TestCostMigration tests the cost migration for loop out swaps. +func TestCostMigration(t *testing.T) { + // Set up test context objects. + lnd := test.NewMockLnd() + server := newServerMock(lnd) + store := loopdb.NewStoreMock(t) + + cfg := &swapConfig{ + lnd: &lnd.LndServices, + store: store, + server: server, + } + + height := int32(600) + req := *testRequest + initResult, err := newLoopOutSwap( + context.Background(), cfg, height, &req, + ) + require.NoError(t, err) + + // Override the chain cost so it's negative. + const expectedChainCost = btcutil.Amount(1000) + + // Override the swap state to make it look like the swap is finished + // and make the chain cost negative too, so we can test that it'll be + // corrected to be positive in the cost calculation. + err = store.UpdateLoopOut( + context.Background(), initResult.swap.hash, time.Now(), + loopdb.SwapStateData{ + State: loopdb.StateSuccess, + Cost: loopdb.SwapCost{ + Onchain: -expectedChainCost, + }, + }, + ) + require.NoError(t, err) + + // Add the two mocked payment to LND. Note that we only care about the + // fees here, so we don't need to provide the full payment details. + lnd.Payments = []lndclient.Payment{ + { + Hash: server.swapHash, + Fee: lnwire.NewMSatFromSatoshis(44), + }, + { + Hash: server.prepayHash, + Fee: lnwire.NewMSatFromSatoshis(11), + }, + } + + // Now we can run the migration. + err = MigrateLoopOutCosts(context.Background(), lnd.LndServices, store) + require.NoError(t, err) + + // Finally check that the swap cost has been updated correctly. + swap, err := store.FetchLoopOutSwap( + context.Background(), initResult.swap.hash, + ) + require.NoError(t, err) + + expectedServerCost := server.swapInvoiceAmt + server.prepayInvoiceAmt - + swap.Contract.AmountRequested + + costs := swap.Events[0].Cost + expectedOffchainCost := btcutil.Amount(44 + 11) + require.Equal(t, expectedServerCost, costs.Server) + require.Equal(t, expectedOffchainCost, costs.Offchain) + require.Equal(t, expectedChainCost, costs.Onchain) + + // Now run the migration again to make sure it doesn't fail. This also + // indicates that the migration did not run the second time as + // otherwise the store mocks SetMigration function would fail. + err = MigrateLoopOutCosts(context.Background(), lnd.LndServices, store) + require.NoError(t, err) +} diff --git a/loopdb/store_mock.go b/loopdb/store_mock.go index 268bba6b9..955ae5c4d 100644 --- a/loopdb/store_mock.go +++ b/loopdb/store_mock.go @@ -24,6 +24,8 @@ type StoreMock struct { loopInStoreChan chan LoopInContract loopInUpdateChan chan SwapStateData + migrations map[string]struct{} + t *testing.T } @@ -39,6 +41,7 @@ func NewStoreMock(t *testing.T) *StoreMock { loopInUpdateChan: make(chan SwapStateData, 1), LoopInSwaps: make(map[lntypes.Hash]*LoopInContract), LoopInUpdates: make(map[lntypes.Hash][]SwapStateData), + migrations: make(map[string]struct{}), t: t, } } @@ -364,12 +367,20 @@ func (s *StoreMock) BatchUpdateLoopOutSwapCosts(ctx context.Context, func (s *StoreMock) HasMigration(ctx context.Context, migrationID string) ( bool, error) { - return false, errUnimplemented + _, ok := s.migrations[migrationID] + + return ok, nil } // SetMigration marks the migration with the given ID as done. func (s *StoreMock) SetMigration(ctx context.Context, migrationID string) error { - return errUnimplemented + if _, ok := s.migrations[migrationID]; ok { + return errors.New("migration already done") + } + + s.migrations[migrationID] = struct{}{} + + return nil } diff --git a/server_mock_test.go b/server_mock_test.go index 9f52a2c07..f46a49774 100644 --- a/server_mock_test.go +++ b/server_mock_test.go @@ -42,6 +42,7 @@ type serverMock struct { swapInvoice string swapHash lntypes.Hash + prepayHash lntypes.Hash // preimagePush is a channel that preimage pushes are sent into. preimagePush chan lntypes.Preimage @@ -81,13 +82,17 @@ func (s *serverMock) NewLoopOutSwap(_ context.Context, swapHash lntypes.Hash, return nil, errors.New("unexpected test swap amount") } + s.swapHash = swapHash swapPayReqString, err := getInvoice(swapHash, s.swapInvoiceAmt, swapInvoiceDesc) if err != nil { return nil, err } - prePayReqString, err := getInvoice(swapHash, s.prepayInvoiceAmt, + // Set the prepay hash to be different from the swap hash. + s.prepayHash = swapHash + s.prepayHash[0] ^= 1 + prePayReqString, err := getInvoice(s.prepayHash, s.prepayInvoiceAmt, prepayInvoiceDesc) if err != nil { return nil, err