diff --git a/x/tier/keeper/lockup.go b/x/tier/keeper/lockup.go index 67885c8..65154f7 100644 --- a/x/tier/keeper/lockup.go +++ b/x/tier/keeper/lockup.go @@ -150,12 +150,19 @@ func (k Keeper) AddLockup(ctx context.Context, delAddr sdk.AccAddress, valAddr s func (k Keeper) SubtractLockup(ctx context.Context, delAddr sdk.AccAddress, valAddr sdk.ValAddress, amt math.Int) error { lockedAmt := k.GetLockupAmount(ctx, delAddr, valAddr) - lockedAmt, err := lockedAmt.SafeSub(amt) + // remove lockup record completely if subtracted amt is equal to lockedAmt + if amt.Equal(lockedAmt) { + k.removeLockup(ctx, delAddr, valAddr) + return nil + } + + // subtract amt from the lockedAmt othwerwise + newAmt, err := lockedAmt.SafeSub(amt) if err != nil { return errorsmod.Wrapf(err, "subtract %s from locked amount %s", amt, lockedAmt) } - k.SetLockup(ctx, false, delAddr, valAddr, lockedAmt, sdk.UnwrapSDKContext(ctx).BlockHeight(), nil, nil) + k.SetLockup(ctx, false, delAddr, valAddr, newAmt, sdk.UnwrapSDKContext(ctx).BlockHeight(), nil, nil) return nil } diff --git a/x/tier/keeper/lockup_test.go b/x/tier/keeper/lockup_test.go index d2db795..ca726b4 100644 --- a/x/tier/keeper/lockup_test.go +++ b/x/tier/keeper/lockup_test.go @@ -74,8 +74,14 @@ func TestSubtractLockup(t *testing.T) { err = k.SubtractLockup(ctx, delAddr, valAddr, math.NewInt(500)) require.NoError(t, err) - lockup := k.GetLockupAmount(ctx, delAddr, valAddr) - require.Equal(t, math.NewInt(500), lockup) + lockupAmt := k.GetLockupAmount(ctx, delAddr, valAddr) + require.Equal(t, math.NewInt(500), lockupAmt) + + err = k.SubtractLockup(ctx, delAddr, valAddr, math.NewInt(500)) + require.NoError(t, err) + + lockupAmt = k.GetLockupAmount(ctx, delAddr, valAddr) + require.True(t, lockupAmt.IsZero()) } func TestGetAllLockups(t *testing.T) { @@ -231,3 +237,49 @@ func TestTotalAmountByAddr(t *testing.T) { totalDel3 := k.TotalAmountByAddr(ctx, delAddr3) require.True(t, totalDel3.IsZero(), "delAddr3 should have no lockups") } + +func TestHasLockup(t *testing.T) { + k, ctx := testutil.SetupKeeper(t) + + delAddr, err := sdk.AccAddressFromBech32("source1m4f5a896t7fzd9vc7pfgmc3fxkj8n24s68fcw9") + require.NoError(t, err) + valAddr, err := sdk.ValAddressFromBech32("sourcevaloper1cy0p47z24ejzvq55pu3lesxwf73xnrnd0pzkqm") + require.NoError(t, err) + + require.False(t, k.HasLockup(ctx, delAddr, valAddr)) + + k.AddLockup(ctx, delAddr, valAddr, math.NewInt(100)) + require.True(t, k.HasLockup(ctx, delAddr, valAddr)) + + err = k.SubtractLockup(ctx, delAddr, valAddr, math.NewInt(100)) + require.NoError(t, err) + require.False(t, k.HasLockup(ctx, delAddr, valAddr), "lockup should no longer exist after removing the entire amount") +} + +func TestGetUnlockingLockup(t *testing.T) { + k, ctx := testutil.SetupKeeper(t) + + delAddr, err := sdk.AccAddressFromBech32("source1m4f5a896t7fzd9vc7pfgmc3fxkj8n24s68fcw9") + require.NoError(t, err) + valAddr, err := sdk.ValAddressFromBech32("sourcevaloper1cy0p47z24ejzvq55pu3lesxwf73xnrnd0pzkqm") + require.NoError(t, err) + + creationHeight := int64(42) + amount := math.NewInt(300) + unbondTime := time.Now().Add(24 * time.Hour).UTC() + unlockTime := time.Now().Add(48 * time.Hour).UTC() + + k.SetLockup(ctx, true, delAddr, valAddr, amount, creationHeight, &unbondTime, &unlockTime) + + found, amt, gotUnbondTime, gotUnlockTime := k.GetUnlockingLockup(ctx, delAddr, valAddr, creationHeight) + require.True(t, found, "unlocking lockup should be found") + require.Equal(t, amount, amt, "amount should match the one set") + require.Equal(t, unbondTime, gotUnbondTime, "unbondTime should match the one set") + require.Equal(t, unlockTime, gotUnlockTime, "unlockTime should match the one set") + + found, amt, gotUnbondTime, gotUnlockTime = k.GetUnlockingLockup(ctx, delAddr, valAddr, creationHeight+1) + require.False(t, found, "this unlocking lockup does not exist") + require.True(t, amt.IsZero(), "amount should be zero") + require.True(t, gotUnbondTime.IsZero(), "unbond time should be zero") + require.True(t, gotUnlockTime.IsZero(), "unlock time should be zero") +}