diff --git a/cost_migration.go b/cost_migration.go new file mode 100644 index 000000000..247f63fad --- /dev/null +++ b/cost_migration.go @@ -0,0 +1,177 @@ +package loop + +import ( + "context" + "fmt" + "time" + + "github.com/btcsuite/btcd/chaincfg" + "github.com/lightninglabs/lndclient" + "github.com/lightninglabs/loop/loopdb" + "github.com/lightninglabs/loop/swap" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwire" +) + +const ( + costMigrationID = "cost_migration" +) + +// CalculateLoopOutCost calculates the total cost of a loop out swap. It will +// correctly account for the on-chain and off-chain fees that were paid and +// make sure that all costs are positive. +func CalculateLoopOutCost(params *chaincfg.Params, loopOutSwap *loopdb.LoopOut, + paymentFees map[lntypes.Hash]lnwire.MilliSatoshi) (loopdb.SwapCost, + error) { + + // First make sure that this swap is actually finished. + if loopOutSwap.State().State.IsPending() { + return loopdb.SwapCost{}, fmt.Errorf("swap is not yet finished") + } + + // We first need to decode the prepay invoice to get the prepay hash and + // the prepay amount. + _, _, hash, prepayAmount, err := swap.DecodeInvoice( + params, loopOutSwap.Contract.PrepayInvoice, + ) + if err != nil { + return loopdb.SwapCost{}, fmt.Errorf("unable to decode the "+ + "prepay invoice: %v", err) + } + + // The swap hash is given and we don't need to get it from the + // swap invoice, however we'll decode it anyway to get the invoice amount + // that was paid in case we don't have the payment anymore. + _, _, swapHash, swapPaymentAmount, err := swap.DecodeInvoice( + params, loopOutSwap.Contract.SwapInvoice, + ) + if err != nil { + return loopdb.SwapCost{}, fmt.Errorf("unable to decode the "+ + "swap invoice: %v", err) + } + + var ( + cost loopdb.SwapCost + swapPaid, prepayPaid bool + ) + + // Now that we have the prepay and swap amount, we can calculate the + // total cost of the swap. Note that we only need to account for the + // server cost in case the swap was successful or if the sweep timed + // out. Otherwise the server didn't pull the off-chain htlc nor the + // prepay. + switch loopOutSwap.State().State { + case loopdb.StateSuccess: + cost.Server = swapPaymentAmount + prepayAmount - + loopOutSwap.Contract.AmountRequested + + swapPaid = true + prepayPaid = true + + case loopdb.StateFailSweepTimeout: + cost.Server = prepayAmount + + prepayPaid = true + + default: + cost.Server = 0 + } + + // Now attempt to look up the actual payments so we can calculate the + // total routing costs. + prepayPaymentFee, ok := paymentFees[hash] + if prepayPaid && ok { + cost.Offchain += prepayPaymentFee.ToSatoshis() + } else { + log.Debugf("Prepay payment %s is missing, won't account for "+ + "routing fees", hash) + } + + swapPaymentFee, ok := paymentFees[swapHash] + if swapPaid && ok { + cost.Offchain += swapPaymentFee.ToSatoshis() + } else { + log.Debugf("Swap payment %s is missing, won't account for "+ + "routing fees", swapHash) + } + + // For the on-chain cost, just make sure that the cost is positive. + cost.Onchain = loopOutSwap.State().Cost.Onchain + if cost.Onchain < 0 { + cost.Onchain *= -1 + } + + return cost, nil +} + +// MigrateLoopOutCosts will calculate the correct cost for all loop out swaps +// and override the cost values of the last update in the database. +func MigrateLoopOutCosts(ctx context.Context, lnd lndclient.LndServices, + db loopdb.SwapStore) error { + + migrationDone, err := db.HasMigration(ctx, costMigrationID) + if err != nil { + return err + } + if migrationDone { + log.Infof("Cost cleanup migration already done, skipping") + + return nil + } + + log.Infof("Starting cost cleanup migration") + startTs := time.Now() + defer func() { + log.Infof("Finished cost cleanup migration in %v", + time.Since(startTs)) + }() + + // First we'll fetch all loop out swaps from the database. + loopOutSwaps, err := db.FetchLoopOutSwaps(ctx) + if err != nil { + return err + } + + // Next we fetch all payments from LND. + payments, err := lnd.Client.ListPayments( + ctx, lndclient.ListPaymentsRequest{}, + ) + if err != nil { + return err + } + + // Gather payment fees to a map for easier lookup. + paymentFees := make(map[lntypes.Hash]lnwire.MilliSatoshi) + for _, payment := range payments.Payments { + paymentFees[payment.Hash] = payment.Fee + } + + // Now we'll calculate the cost for each swap and finally update the + // costs in the database. + updatedCosts := make(map[lntypes.Hash]loopdb.SwapCost) + for _, loopOutSwap := range loopOutSwaps { + cost, err := CalculateLoopOutCost( + lnd.ChainParams, loopOutSwap, paymentFees, + ) + if err != nil { + return err + } + + _, ok := updatedCosts[loopOutSwap.Hash] + if ok { + return fmt.Errorf("found a duplicate swap %v while "+ + "updating costs", loopOutSwap.Hash) + } + + updatedCosts[loopOutSwap.Hash] = cost + } + + log.Infof("Updating costs for %d loop out swaps", len(updatedCosts)) + err = db.BatchUpdateLoopOutSwapCosts(ctx, updatedCosts) + if err != nil { + return err + } + + // Finally mark the migration as done. + return db.SetMigration(ctx, costMigrationID) +} diff --git a/cost_migration_test.go b/cost_migration_test.go new file mode 100644 index 000000000..ee5525da5 --- /dev/null +++ b/cost_migration_test.go @@ -0,0 +1,184 @@ +package loop + +import ( + "context" + "testing" + "time" + + "github.com/btcsuite/btcd/btcutil" + "github.com/lightninglabs/lndclient" + "github.com/lightninglabs/loop/loopdb" + "github.com/lightninglabs/loop/test" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/stretchr/testify/require" +) + +// TestCalculateLoopOutCost tests the CalculateLoopOutCost function. +func TestCalculateLoopOutCost(t *testing.T) { + // Set up test context objects. + lnd := test.NewMockLnd() + server := newServerMock(lnd) + store := loopdb.NewStoreMock(t) + + cfg := &swapConfig{ + lnd: &lnd.LndServices, + store: store, + server: server, + } + + height := int32(600) + req := *testRequest + initResult, err := newLoopOutSwap( + context.Background(), cfg, height, &req, + ) + require.NoError(t, err) + swap, err := store.FetchLoopOutSwap( + context.Background(), initResult.swap.hash, + ) + require.NoError(t, err) + + // Override the chain cost so it's negative. + const expectedChainCost = btcutil.Amount(1000) + + // Now we have the swap and prepay invoices so let's calculate the + // costs without providing the payments first, so we don't account for + // any routing fees. + paymentFees := make(map[lntypes.Hash]lnwire.MilliSatoshi) + _, err = CalculateLoopOutCost(lnd.ChainParams, swap, paymentFees) + + // We expect that the call fails as the swap isn't finished yet. + require.Error(t, err) + + // Override the swap state to make it look like the swap is finished + // and make the chain cost negative too, so we can test that it'll be + // corrected to be positive in the cost calculation. + swap.Events = append( + swap.Events, &loopdb.LoopEvent{ + SwapStateData: loopdb.SwapStateData{ + State: loopdb.StateSuccess, + Cost: loopdb.SwapCost{ + Onchain: -expectedChainCost, + }, + }, + }, + ) + costs, err := CalculateLoopOutCost(lnd.ChainParams, swap, paymentFees) + require.NoError(t, err) + + expectedServerCost := server.swapInvoiceAmt + server.prepayInvoiceAmt - + swap.Contract.AmountRequested + require.Equal(t, expectedServerCost, costs.Server) + require.Equal(t, btcutil.Amount(0), costs.Offchain) + require.Equal(t, expectedChainCost, costs.Onchain) + + // Now add the two payments to the payments map and calculate the costs + // again. We expect that the routng fees are now accounted for. + paymentFees[server.swapHash] = lnwire.NewMSatFromSatoshis(44) + paymentFees[server.prepayHash] = lnwire.NewMSatFromSatoshis(11) + + costs, err = CalculateLoopOutCost(lnd.ChainParams, swap, paymentFees) + require.NoError(t, err) + + expectedOffchainCost := btcutil.Amount(44 + 11) + require.Equal(t, expectedServerCost, costs.Server) + require.Equal(t, expectedOffchainCost, costs.Offchain) + require.Equal(t, expectedChainCost, costs.Onchain) + + // Now override the last update to make the swap timed out at the HTLC + // sweep. We expect that the chain cost won't change, and only the + // prepay will be accounted for. + swap.Events[0] = &loopdb.LoopEvent{ + SwapStateData: loopdb.SwapStateData{ + State: loopdb.StateFailSweepTimeout, + Cost: loopdb.SwapCost{ + Onchain: 0, + }, + }, + } + + costs, err = CalculateLoopOutCost(lnd.ChainParams, swap, paymentFees) + require.NoError(t, err) + + expectedServerCost = server.prepayInvoiceAmt + expectedOffchainCost = btcutil.Amount(11) + require.Equal(t, expectedServerCost, costs.Server) + require.Equal(t, expectedOffchainCost, costs.Offchain) + require.Equal(t, btcutil.Amount(0), costs.Onchain) +} + +// TestCostMigration tests the cost migration for loop out swaps. +func TestCostMigration(t *testing.T) { + // Set up test context objects. + lnd := test.NewMockLnd() + server := newServerMock(lnd) + store := loopdb.NewStoreMock(t) + + cfg := &swapConfig{ + lnd: &lnd.LndServices, + store: store, + server: server, + } + + height := int32(600) + req := *testRequest + initResult, err := newLoopOutSwap( + context.Background(), cfg, height, &req, + ) + require.NoError(t, err) + + // Override the chain cost so it's negative. + const expectedChainCost = btcutil.Amount(1000) + + // Override the swap state to make it look like the swap is finished + // and make the chain cost negative too, so we can test that it'll be + // corrected to be positive in the cost calculation. + err = store.UpdateLoopOut( + context.Background(), initResult.swap.hash, time.Now(), + loopdb.SwapStateData{ + State: loopdb.StateSuccess, + Cost: loopdb.SwapCost{ + Onchain: -expectedChainCost, + }, + }, + ) + require.NoError(t, err) + + // Add the two mocked payment to LND. Note that we only care about the + // fees here, so we don't need to provide the full payment details. + lnd.Payments = []lndclient.Payment{ + { + Hash: server.swapHash, + Fee: lnwire.NewMSatFromSatoshis(44), + }, + { + Hash: server.prepayHash, + Fee: lnwire.NewMSatFromSatoshis(11), + }, + } + + // Now we can run the migration. + err = MigrateLoopOutCosts(context.Background(), lnd.LndServices, store) + require.NoError(t, err) + + // Finally check that the swap cost has been updated correctly. + swap, err := store.FetchLoopOutSwap( + context.Background(), initResult.swap.hash, + ) + require.NoError(t, err) + + expectedServerCost := server.swapInvoiceAmt + server.prepayInvoiceAmt - + swap.Contract.AmountRequested + + costs := swap.Events[0].Cost + expectedOffchainCost := btcutil.Amount(44 + 11) + require.Equal(t, expectedServerCost, costs.Server) + require.Equal(t, expectedOffchainCost, costs.Offchain) + require.Equal(t, expectedChainCost, costs.Onchain) + + // Now run the migration again to make sure it doesn't fail. This also + // indicates that the migration did not run the second time as + // otherwise the store mocks SetMigration function would fail. + err = MigrateLoopOutCosts(context.Background(), lnd.LndServices, store) + require.NoError(t, err) +} diff --git a/loopdb/store_mock.go b/loopdb/store_mock.go index 268bba6b9..955ae5c4d 100644 --- a/loopdb/store_mock.go +++ b/loopdb/store_mock.go @@ -24,6 +24,8 @@ type StoreMock struct { loopInStoreChan chan LoopInContract loopInUpdateChan chan SwapStateData + migrations map[string]struct{} + t *testing.T } @@ -39,6 +41,7 @@ func NewStoreMock(t *testing.T) *StoreMock { loopInUpdateChan: make(chan SwapStateData, 1), LoopInSwaps: make(map[lntypes.Hash]*LoopInContract), LoopInUpdates: make(map[lntypes.Hash][]SwapStateData), + migrations: make(map[string]struct{}), t: t, } } @@ -364,12 +367,20 @@ func (s *StoreMock) BatchUpdateLoopOutSwapCosts(ctx context.Context, func (s *StoreMock) HasMigration(ctx context.Context, migrationID string) ( bool, error) { - return false, errUnimplemented + _, ok := s.migrations[migrationID] + + return ok, nil } // SetMigration marks the migration with the given ID as done. func (s *StoreMock) SetMigration(ctx context.Context, migrationID string) error { - return errUnimplemented + if _, ok := s.migrations[migrationID]; ok { + return errors.New("migration already done") + } + + s.migrations[migrationID] = struct{}{} + + return nil } diff --git a/server_mock_test.go b/server_mock_test.go index 9f52a2c07..f46a49774 100644 --- a/server_mock_test.go +++ b/server_mock_test.go @@ -42,6 +42,7 @@ type serverMock struct { swapInvoice string swapHash lntypes.Hash + prepayHash lntypes.Hash // preimagePush is a channel that preimage pushes are sent into. preimagePush chan lntypes.Preimage @@ -81,13 +82,17 @@ func (s *serverMock) NewLoopOutSwap(_ context.Context, swapHash lntypes.Hash, return nil, errors.New("unexpected test swap amount") } + s.swapHash = swapHash swapPayReqString, err := getInvoice(swapHash, s.swapInvoiceAmt, swapInvoiceDesc) if err != nil { return nil, err } - prePayReqString, err := getInvoice(swapHash, s.prepayInvoiceAmt, + // Set the prepay hash to be different from the swap hash. + s.prepayHash = swapHash + s.prepayHash[0] ^= 1 + prePayReqString, err := getInvoice(s.prepayHash, s.prepayInvoiceAmt, prepayInvoiceDesc) if err != nil { return nil, err