-
Notifications
You must be signed in to change notification settings - Fork 118
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #764 from bhandras/costs-cleanup-migration
loop: add migration to fix stored loop out costs
- Loading branch information
Showing
18 changed files
with
805 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
package loop | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"time" | ||
|
||
"github.com/btcsuite/btcd/chaincfg" | ||
"github.com/lightninglabs/lndclient" | ||
"github.com/lightninglabs/loop/loopdb" | ||
"github.com/lightninglabs/loop/swap" | ||
"github.com/lightningnetwork/lnd/lntypes" | ||
"github.com/lightningnetwork/lnd/lnwire" | ||
) | ||
|
||
const ( | ||
costMigrationID = "cost_migration" | ||
) | ||
|
||
// CalculateLoopOutCost calculates the total cost of a loop out swap. It will | ||
// correctly account for the on-chain and off-chain fees that were paid and | ||
// make sure that all costs are positive. | ||
func CalculateLoopOutCost(params *chaincfg.Params, loopOutSwap *loopdb.LoopOut, | ||
paymentFees map[lntypes.Hash]lnwire.MilliSatoshi) (loopdb.SwapCost, | ||
error) { | ||
|
||
// First make sure that this swap is actually finished. | ||
if loopOutSwap.State().State.IsPending() { | ||
return loopdb.SwapCost{}, fmt.Errorf("swap is not yet finished") | ||
} | ||
|
||
// We first need to decode the prepay invoice to get the prepay hash and | ||
// the prepay amount. | ||
_, _, hash, prepayAmount, err := swap.DecodeInvoice( | ||
params, loopOutSwap.Contract.PrepayInvoice, | ||
) | ||
if err != nil { | ||
return loopdb.SwapCost{}, fmt.Errorf("unable to decode the "+ | ||
"prepay invoice: %v", err) | ||
} | ||
|
||
// The swap hash is given and we don't need to get it from the | ||
// swap invoice, however we'll decode it anyway to get the invoice amount | ||
// that was paid in case we don't have the payment anymore. | ||
_, _, swapHash, swapPaymentAmount, err := swap.DecodeInvoice( | ||
params, loopOutSwap.Contract.SwapInvoice, | ||
) | ||
if err != nil { | ||
return loopdb.SwapCost{}, fmt.Errorf("unable to decode the "+ | ||
"swap invoice: %v", err) | ||
} | ||
|
||
var ( | ||
cost loopdb.SwapCost | ||
swapPaid, prepayPaid bool | ||
) | ||
|
||
// Now that we have the prepay and swap amount, we can calculate the | ||
// total cost of the swap. Note that we only need to account for the | ||
// server cost in case the swap was successful or if the sweep timed | ||
// out. Otherwise the server didn't pull the off-chain htlc nor the | ||
// prepay. | ||
switch loopOutSwap.State().State { | ||
case loopdb.StateSuccess: | ||
cost.Server = swapPaymentAmount + prepayAmount - | ||
loopOutSwap.Contract.AmountRequested | ||
|
||
swapPaid = true | ||
prepayPaid = true | ||
|
||
case loopdb.StateFailSweepTimeout: | ||
cost.Server = prepayAmount | ||
|
||
prepayPaid = true | ||
|
||
default: | ||
cost.Server = 0 | ||
} | ||
|
||
// Now attempt to look up the actual payments so we can calculate the | ||
// total routing costs. | ||
prepayPaymentFee, ok := paymentFees[hash] | ||
if prepayPaid && ok { | ||
cost.Offchain += prepayPaymentFee.ToSatoshis() | ||
} else { | ||
log.Debugf("Prepay payment %s is missing, won't account for "+ | ||
"routing fees", hash) | ||
} | ||
|
||
swapPaymentFee, ok := paymentFees[swapHash] | ||
if swapPaid && ok { | ||
cost.Offchain += swapPaymentFee.ToSatoshis() | ||
} else { | ||
log.Debugf("Swap payment %s is missing, won't account for "+ | ||
"routing fees", swapHash) | ||
} | ||
|
||
// For the on-chain cost, just make sure that the cost is positive. | ||
cost.Onchain = loopOutSwap.State().Cost.Onchain | ||
if cost.Onchain < 0 { | ||
cost.Onchain *= -1 | ||
} | ||
|
||
return cost, nil | ||
} | ||
|
||
// MigrateLoopOutCosts will calculate the correct cost for all loop out swaps | ||
// and override the cost values of the last update in the database. | ||
func MigrateLoopOutCosts(ctx context.Context, lnd lndclient.LndServices, | ||
db loopdb.SwapStore) error { | ||
|
||
migrationDone, err := db.HasMigration(ctx, costMigrationID) | ||
if err != nil { | ||
return err | ||
} | ||
if migrationDone { | ||
log.Infof("Cost cleanup migration already done, skipping") | ||
|
||
return nil | ||
} | ||
|
||
log.Infof("Starting cost cleanup migration") | ||
startTs := time.Now() | ||
defer func() { | ||
log.Infof("Finished cost cleanup migration in %v", | ||
time.Since(startTs)) | ||
}() | ||
|
||
// First we'll fetch all loop out swaps from the database. | ||
loopOutSwaps, err := db.FetchLoopOutSwaps(ctx) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
// Next we fetch all payments from LND. | ||
payments, err := lnd.Client.ListPayments( | ||
ctx, lndclient.ListPaymentsRequest{}, | ||
) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
// Gather payment fees to a map for easier lookup. | ||
paymentFees := make(map[lntypes.Hash]lnwire.MilliSatoshi) | ||
for _, payment := range payments.Payments { | ||
paymentFees[payment.Hash] = payment.Fee | ||
} | ||
|
||
// Now we'll calculate the cost for each swap and finally update the | ||
// costs in the database. | ||
updatedCosts := make(map[lntypes.Hash]loopdb.SwapCost) | ||
for _, loopOutSwap := range loopOutSwaps { | ||
cost, err := CalculateLoopOutCost( | ||
lnd.ChainParams, loopOutSwap, paymentFees, | ||
) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
_, ok := updatedCosts[loopOutSwap.Hash] | ||
if ok { | ||
return fmt.Errorf("found a duplicate swap %v while "+ | ||
"updating costs", loopOutSwap.Hash) | ||
} | ||
|
||
updatedCosts[loopOutSwap.Hash] = cost | ||
} | ||
|
||
log.Infof("Updating costs for %d loop out swaps", len(updatedCosts)) | ||
err = db.BatchUpdateLoopOutSwapCosts(ctx, updatedCosts) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
// Finally mark the migration as done. | ||
return db.SetMigration(ctx, costMigrationID) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
package loop | ||
|
||
import ( | ||
"context" | ||
"testing" | ||
"time" | ||
|
||
"github.com/btcsuite/btcd/btcutil" | ||
"github.com/lightninglabs/lndclient" | ||
"github.com/lightninglabs/loop/loopdb" | ||
"github.com/lightninglabs/loop/test" | ||
"github.com/lightningnetwork/lnd/lntypes" | ||
"github.com/lightningnetwork/lnd/lnwire" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
// TestCalculateLoopOutCost tests the CalculateLoopOutCost function. | ||
func TestCalculateLoopOutCost(t *testing.T) { | ||
// Set up test context objects. | ||
lnd := test.NewMockLnd() | ||
server := newServerMock(lnd) | ||
store := loopdb.NewStoreMock(t) | ||
|
||
cfg := &swapConfig{ | ||
lnd: &lnd.LndServices, | ||
store: store, | ||
server: server, | ||
} | ||
|
||
height := int32(600) | ||
req := *testRequest | ||
initResult, err := newLoopOutSwap( | ||
context.Background(), cfg, height, &req, | ||
) | ||
require.NoError(t, err) | ||
swap, err := store.FetchLoopOutSwap( | ||
context.Background(), initResult.swap.hash, | ||
) | ||
require.NoError(t, err) | ||
|
||
// Override the chain cost so it's negative. | ||
const expectedChainCost = btcutil.Amount(1000) | ||
|
||
// Now we have the swap and prepay invoices so let's calculate the | ||
// costs without providing the payments first, so we don't account for | ||
// any routing fees. | ||
paymentFees := make(map[lntypes.Hash]lnwire.MilliSatoshi) | ||
_, err = CalculateLoopOutCost(lnd.ChainParams, swap, paymentFees) | ||
|
||
// We expect that the call fails as the swap isn't finished yet. | ||
require.Error(t, err) | ||
|
||
// Override the swap state to make it look like the swap is finished | ||
// and make the chain cost negative too, so we can test that it'll be | ||
// corrected to be positive in the cost calculation. | ||
swap.Events = append( | ||
swap.Events, &loopdb.LoopEvent{ | ||
SwapStateData: loopdb.SwapStateData{ | ||
State: loopdb.StateSuccess, | ||
Cost: loopdb.SwapCost{ | ||
Onchain: -expectedChainCost, | ||
}, | ||
}, | ||
}, | ||
) | ||
costs, err := CalculateLoopOutCost(lnd.ChainParams, swap, paymentFees) | ||
require.NoError(t, err) | ||
|
||
expectedServerCost := server.swapInvoiceAmt + server.prepayInvoiceAmt - | ||
swap.Contract.AmountRequested | ||
require.Equal(t, expectedServerCost, costs.Server) | ||
require.Equal(t, btcutil.Amount(0), costs.Offchain) | ||
require.Equal(t, expectedChainCost, costs.Onchain) | ||
|
||
// Now add the two payments to the payments map and calculate the costs | ||
// again. We expect that the routng fees are now accounted for. | ||
paymentFees[server.swapHash] = lnwire.NewMSatFromSatoshis(44) | ||
paymentFees[server.prepayHash] = lnwire.NewMSatFromSatoshis(11) | ||
|
||
costs, err = CalculateLoopOutCost(lnd.ChainParams, swap, paymentFees) | ||
require.NoError(t, err) | ||
|
||
expectedOffchainCost := btcutil.Amount(44 + 11) | ||
require.Equal(t, expectedServerCost, costs.Server) | ||
require.Equal(t, expectedOffchainCost, costs.Offchain) | ||
require.Equal(t, expectedChainCost, costs.Onchain) | ||
|
||
// Now override the last update to make the swap timed out at the HTLC | ||
// sweep. We expect that the chain cost won't change, and only the | ||
// prepay will be accounted for. | ||
swap.Events[0] = &loopdb.LoopEvent{ | ||
SwapStateData: loopdb.SwapStateData{ | ||
State: loopdb.StateFailSweepTimeout, | ||
Cost: loopdb.SwapCost{ | ||
Onchain: 0, | ||
}, | ||
}, | ||
} | ||
|
||
costs, err = CalculateLoopOutCost(lnd.ChainParams, swap, paymentFees) | ||
require.NoError(t, err) | ||
|
||
expectedServerCost = server.prepayInvoiceAmt | ||
expectedOffchainCost = btcutil.Amount(11) | ||
require.Equal(t, expectedServerCost, costs.Server) | ||
require.Equal(t, expectedOffchainCost, costs.Offchain) | ||
require.Equal(t, btcutil.Amount(0), costs.Onchain) | ||
} | ||
|
||
// TestCostMigration tests the cost migration for loop out swaps. | ||
func TestCostMigration(t *testing.T) { | ||
// Set up test context objects. | ||
lnd := test.NewMockLnd() | ||
server := newServerMock(lnd) | ||
store := loopdb.NewStoreMock(t) | ||
|
||
cfg := &swapConfig{ | ||
lnd: &lnd.LndServices, | ||
store: store, | ||
server: server, | ||
} | ||
|
||
height := int32(600) | ||
req := *testRequest | ||
initResult, err := newLoopOutSwap( | ||
context.Background(), cfg, height, &req, | ||
) | ||
require.NoError(t, err) | ||
|
||
// Override the chain cost so it's negative. | ||
const expectedChainCost = btcutil.Amount(1000) | ||
|
||
// Override the swap state to make it look like the swap is finished | ||
// and make the chain cost negative too, so we can test that it'll be | ||
// corrected to be positive in the cost calculation. | ||
err = store.UpdateLoopOut( | ||
context.Background(), initResult.swap.hash, time.Now(), | ||
loopdb.SwapStateData{ | ||
State: loopdb.StateSuccess, | ||
Cost: loopdb.SwapCost{ | ||
Onchain: -expectedChainCost, | ||
}, | ||
}, | ||
) | ||
require.NoError(t, err) | ||
|
||
// Add the two mocked payment to LND. Note that we only care about the | ||
// fees here, so we don't need to provide the full payment details. | ||
lnd.Payments = []lndclient.Payment{ | ||
{ | ||
Hash: server.swapHash, | ||
Fee: lnwire.NewMSatFromSatoshis(44), | ||
}, | ||
{ | ||
Hash: server.prepayHash, | ||
Fee: lnwire.NewMSatFromSatoshis(11), | ||
}, | ||
} | ||
|
||
// Now we can run the migration. | ||
err = MigrateLoopOutCosts(context.Background(), lnd.LndServices, store) | ||
require.NoError(t, err) | ||
|
||
// Finally check that the swap cost has been updated correctly. | ||
swap, err := store.FetchLoopOutSwap( | ||
context.Background(), initResult.swap.hash, | ||
) | ||
require.NoError(t, err) | ||
|
||
expectedServerCost := server.swapInvoiceAmt + server.prepayInvoiceAmt - | ||
swap.Contract.AmountRequested | ||
|
||
costs := swap.Events[0].Cost | ||
expectedOffchainCost := btcutil.Amount(44 + 11) | ||
require.Equal(t, expectedServerCost, costs.Server) | ||
require.Equal(t, expectedOffchainCost, costs.Offchain) | ||
require.Equal(t, expectedChainCost, costs.Onchain) | ||
|
||
// Now run the migration again to make sure it doesn't fail. This also | ||
// indicates that the migration did not run the second time as | ||
// otherwise the store mocks SetMigration function would fail. | ||
err = MigrateLoopOutCosts(context.Background(), lnd.LndServices, store) | ||
require.NoError(t, err) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.