From e4849693370968b2ec1bbc78323bce0eb2bf475a Mon Sep 17 00:00:00 2001 From: Ivan Date: Thu, 19 Dec 2024 05:00:58 +0800 Subject: [PATCH] fix: Tier module fixes and improvements (part 3) (#10) * Add tier module events and attributes * Emit corresponding events on Lock, Unlock, Redelegate, and CompleteUnlocking * Improve CancelUnlocking to search for unbonding delegation entry by creationHeight * Add IterateValidators and TotalBondedTokens to expected staking keeper, define corresponding mocks * Fix calculation bug in TotalAmountByAddr caused by delAddr shadowing, add TestTotalAmountByAddr * Fix SubtractLockup to remove lockup if subtracting the whole locked amount, add TestHasLockup and TestGetUnlockingLockup * Refactor calculateCredit and related tests * Add SaveLockup, refactor SetLockup, fix GetLockup and removeUnlockingLockup * Add TestGetLockup and TestGetLockups, fix/improve existing tests * Use SaveLockup instead of SetLockup in InitGenesis * Update grpc query tests * Improve CancelUnlocking to support partial unlocks and fix bug in existing logic * Improve TestCancelUnlocking to verify subsequent and partial unlocks, fix/update existing tests * Improve SubtractLockup and SubtractUnlockingLockup logic to handle invalid amounts * Add TestSubtractUnlockingLockup, improve TestSubtractLockup * Improve Lock, Unlock, and Redelegate to handle invalid amounts, fix/improve related tests * Add ErrInvalidAmount * Fix typo in ErrUnauthorized, add setupMsgServer and basic TestMsgServer * Add TestMsgUpdateParams * Add TestMsgLock * Add TestMsgUnlock * Add TestMsgRedelegate * Update TestMsgLock, TestMsgUnlock, and TestMsgUpdateParams, fix error messages * Add TestMsgCancelUnlocking, minor fixes --- testutil/mocks.go | 25 ++ x/tier/keeper/calculate_credit.go | 43 ++++ x/tier/keeper/calculate_credit_test.go | 99 ++++++++ x/tier/keeper/credit.go | 41 +--- x/tier/keeper/credit_test.go | 92 ------- x/tier/keeper/grpc_query_test.go | 45 ++-- x/tier/keeper/keeper.go | 124 +++++++--- x/tier/keeper/keeper_mock_test.go | 11 +- x/tier/keeper/keeper_test.go | 147 +++++++++--- x/tier/keeper/lockup.go | 131 ++++++++-- x/tier/keeper/lockup_test.go | 264 +++++++++++++++++++-- x/tier/keeper/msg_cancel_unlocking_test.go | 181 ++++++++++++++ x/tier/keeper/msg_lock_test.go | 128 ++++++++++ x/tier/keeper/msg_redelegate_test.go | 158 ++++++++++++ x/tier/keeper/msg_server.go | 4 +- x/tier/keeper/msg_server_test.go | 24 ++ x/tier/keeper/msg_unlock_test.go | 153 ++++++++++++ x/tier/keeper/msg_update_params_test.go | 64 +++++ x/tier/module/genesis.go | 2 +- x/tier/types/errors.go | 3 +- x/tier/types/events.go | 13 +- x/tier/types/expected_keepers.go | 2 + x/tier/types/messages.go | 2 +- 23 files changed, 1495 insertions(+), 261 deletions(-) create mode 100644 x/tier/keeper/calculate_credit.go create mode 100644 x/tier/keeper/calculate_credit_test.go create mode 100644 x/tier/keeper/msg_cancel_unlocking_test.go create mode 100644 x/tier/keeper/msg_lock_test.go create mode 100644 x/tier/keeper/msg_redelegate_test.go create mode 100644 x/tier/keeper/msg_server_test.go create mode 100644 x/tier/keeper/msg_unlock_test.go create mode 100644 x/tier/keeper/msg_update_params_test.go diff --git a/testutil/mocks.go b/testutil/mocks.go index d94116b1..d2c9e2b0 100644 --- a/testutil/mocks.go +++ b/testutil/mocks.go @@ -233,6 +233,31 @@ func (mr *MockStakingKeeperRecorder) GetValidator(ctx, addr interface{}) *gomock return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidator", reflect.TypeOf((*MockStakingKeeper)(nil).GetValidator), ctx, addr) } +func (m *MockStakingKeeper) IterateValidators(ctx context.Context, fn func(index int64, validator stakingtypes.ValidatorI) (stop bool)) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IterateValidators", ctx, fn) + ret0, _ := ret[0].(error) + return ret0 +} + +func (mr *MockStakingKeeperRecorder) IterateValidators(ctx, fn interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IterateValidators", reflect.TypeOf((*MockStakingKeeper)(nil).IterateValidators), ctx, fn) +} + +func (m *MockStakingKeeper) TotalBondedTokens(ctx context.Context) (math.Int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TotalBondedTokens", ctx) + ret0, _ := ret[0].(math.Int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +func (mr *MockStakingKeeperRecorder) TotalBondedTokens(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TotalBondedTokens", reflect.TypeOf((*MockStakingKeeper)(nil).TotalBondedTokens), ctx) +} + func (m *MockStakingKeeper) Delegate(ctx context.Context, delAddr sdk.AccAddress, bondAmt math.Int, tokenSrc stakingtypes.BondStatus, validator stakingtypes.Validator, subtractAccount bool) (math.LegacyDec, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Delegate", ctx, delAddr, bondAmt, tokenSrc, validator, subtractAccount) diff --git a/x/tier/keeper/calculate_credit.go b/x/tier/keeper/calculate_credit.go new file mode 100644 index 00000000..0d15fd6c --- /dev/null +++ b/x/tier/keeper/calculate_credit.go @@ -0,0 +1,43 @@ +package keeper + +import ( + "cosmossdk.io/math" + + "github.com/sourcenetwork/sourcehub/x/tier/types" +) + +// calculateCredit calculates the reward earned on the lockingAmt. +// lockingAmt is stacked up on top of the lockedAmt to earn at the highest eligible reward. +func calculateCredit(rateList []types.Rate, lockedAmt, lockingAmt math.Int) math.Int { + credit := math.ZeroInt() + stakedAmt := lockedAmt.Add(lockingAmt) + + // Iterate from the highest reward rate to the lowest. + for _, r := range rateList { + // Continue if the total lock does not reach the current rate requirement. + if stakedAmt.LT(r.Amount) { + continue + } + + lower := math.MaxInt(r.Amount, lockedAmt) + diff := stakedAmt.Sub(lower) + + diffDec := math.LegacyNewDecFromInt(diff) + rateDec := math.LegacyNewDec(r.Rate) + + // rateDec MUST have 2 decimals of precision for the calculation to be correct. + amt := diffDec.Mul(rateDec).Quo(math.LegacyNewDec(100)) + credit = credit.Add(amt.TruncateInt()) + + // Subtract the lock that has been rewarded. + stakedAmt = stakedAmt.Sub(diff) + lockingAmt = lockingAmt.Sub(diff) + + // Break if all the new lock has been rewarded. + if lockingAmt.IsZero() { + break + } + } + + return credit +} diff --git a/x/tier/keeper/calculate_credit_test.go b/x/tier/keeper/calculate_credit_test.go new file mode 100644 index 00000000..bf859ad8 --- /dev/null +++ b/x/tier/keeper/calculate_credit_test.go @@ -0,0 +1,99 @@ +package keeper + +import ( + "fmt" + "reflect" + "testing" + + "cosmossdk.io/math" + + "github.com/sourcenetwork/sourcehub/x/tier/types" +) + +func Test_CalculateCredit(t *testing.T) { + rateList := []types.Rate{ + {Amount: math.NewInt(300), Rate: 150}, + {Amount: math.NewInt(200), Rate: 120}, + {Amount: math.NewInt(100), Rate: 110}, + {Amount: math.NewInt(0), Rate: 100}, + } + + tests := []struct { + lockedAmt int64 + lockingAmt int64 + want int64 + }{ + { + lockedAmt: 100, + lockingAmt: 0, + want: 0, + }, + { + lockedAmt: 250, + lockingAmt: 0, + want: 0, + }, + { + lockedAmt: 0, + lockingAmt: 100, + want: 100, + }, + { + lockedAmt: 0, + lockingAmt: 200, + want: (100 * 1.0) + (100 * 1.1), + }, + { + lockedAmt: 0, + lockingAmt: 250, + want: (100 * 1.0) + (100 * 1.1) + (50 * 1.2), + }, + { + lockedAmt: 0, + lockingAmt: 300, + want: (100 * 1.0) + (100 * 1.1) + (100 * 1.2), + }, + { + lockedAmt: 0, + lockingAmt: 350, + want: (100 * 1.0) + (100 * 1.1) + (100 * 1.2) + (50 * 1.5), + }, + { + lockedAmt: 0, + lockingAmt: 600, + want: (100 * 1.0) + (100 * 1.1) + (100 * 1.2) + (300 * 1.5), + }, + { + lockedAmt: 100, + lockingAmt: 100, + want: (100 * 1.1), + }, + { + lockedAmt: 200, + lockingAmt: 100, + want: (100 * 1.2), + }, + { + lockedAmt: 150, + lockingAmt: 150, + want: (50 * 1.1) + (100 * 1.2), + }, + { + lockedAmt: 50, + lockingAmt: 400, + want: (50 * 1.0) + (100 * 1.1) + (100 * 1.2) + (150 * 1.5), + }, + } + for _, tt := range tests { + name := fmt.Sprintf("%d adds %d", tt.lockedAmt, tt.lockingAmt) + oldLock := math.NewInt(tt.lockedAmt) + newLock := math.NewInt(tt.lockingAmt) + want := math.NewInt(tt.want) + + t.Run(name, func(t *testing.T) { + if got := calculateCredit(rateList, oldLock, newLock); !reflect.DeepEqual(got, want) { + t.Errorf("calculateCredit() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/x/tier/keeper/credit.go b/x/tier/keeper/credit.go index 2c349e96..8a43854f 100644 --- a/x/tier/keeper/credit.go +++ b/x/tier/keeper/credit.go @@ -42,7 +42,7 @@ func (k Keeper) proratedCredit(ctx context.Context, delAddr sdk.AccAddress, lock // Calculate the reward credits earned on the new lock. rates := k.GetParams(ctx).RewardRates lockedAmt := k.TotalAmountByAddr(ctx, delAddr) - credit := CalculateCredit(rates, lockedAmt, lockingAmt) + credit := calculateCredit(rates, lockedAmt, lockingAmt) // Pro-rate the credit based on the time elapsed in the current epoch. epochInfo := k.epochsKeeper.GetEpochInfo(ctx, types.EpochIdentifier) @@ -116,7 +116,7 @@ func (k Keeper) resetAllCredits(ctx context.Context) error { for delStrAddr, amt := range lockedAmts { delAddr := sdk.MustAccAddressFromBech32(delStrAddr) - credit := CalculateCredit(rates, math.ZeroInt(), amt) + credit := calculateCredit(rates, math.ZeroInt(), amt) err := k.MintCredit(ctx, delAddr, credit) if err != nil { return errorsmod.Wrapf(err, "mint %s to %s", credit, delAddr) @@ -125,40 +125,3 @@ func (k Keeper) resetAllCredits(ctx context.Context) error { return nil } - -// CalculateCredit calculates the reward earned on the lockingAmt. -// lockingAmt is stacked up on top of the lockedAmt to earn at the -// highest eligible reward. -func CalculateCredit(rateList []types.Rate, lockedAmt, lockingAmt math.Int) math.Int { - credit := math.ZeroInt() - stakedAmt := lockedAmt.Add(lockingAmt) - - // Iterate from the highest reward rate to the lowest. - for _, r := range rateList { - // Continue if the total lock does not reach the current rate requirement. - if stakedAmt.LT(r.Amount) { - continue - } - - lower := math.MaxInt(r.Amount, lockedAmt) - diff := stakedAmt.Sub(lower) - - diffDec := math.LegacyNewDecFromInt(diff) - rateDec := math.LegacyNewDec(r.Rate) - - // rateDec MUST have 2 decimals of precision for the calculation to be correct. - amt := diffDec.Mul(rateDec).Quo(math.LegacyNewDec(100)) - credit = credit.Add(amt.TruncateInt()) - - // Subtract the lock that has been rewarded. - stakedAmt = stakedAmt.Sub(diff) - lockingAmt = lockingAmt.Sub(diff) - - // Break if all the new lock has been rewarded. - if lockingAmt.IsZero() { - break - } - } - - return credit -} diff --git a/x/tier/keeper/credit_test.go b/x/tier/keeper/credit_test.go index 5c102180..cb40d2ad 100644 --- a/x/tier/keeper/credit_test.go +++ b/x/tier/keeper/credit_test.go @@ -1,106 +1,14 @@ package keeper_test import ( - "fmt" - "reflect" "testing" "cosmossdk.io/math" sdk "github.com/cosmos/cosmos-sdk/types" testutil "github.com/sourcenetwork/sourcehub/testutil" - "github.com/sourcenetwork/sourcehub/x/tier/keeper" - "github.com/sourcenetwork/sourcehub/x/tier/types" ) -func Test_CalculateCredit(t *testing.T) { - rateList := []types.Rate{ - {Amount: math.NewInt(300), Rate: 150}, - {Amount: math.NewInt(200), Rate: 120}, - {Amount: math.NewInt(100), Rate: 110}, - {Amount: math.NewInt(0), Rate: 100}, - } - - tests := []struct { - lockedAmt int64 - lockingAmt int64 - want int64 - }{ - { - lockedAmt: 100, - lockingAmt: 0, - want: 0, - }, - { - lockedAmt: 250, - lockingAmt: 0, - want: 0, - }, - { - lockedAmt: 0, - lockingAmt: 100, - want: 100, - }, - { - lockedAmt: 0, - lockingAmt: 200, - want: (100 * 1.0) + (100 * 1.1), - }, - { - lockedAmt: 0, - lockingAmt: 250, - want: (100 * 1.0) + (100 * 1.1) + (50 * 1.2), - }, - { - lockedAmt: 0, - lockingAmt: 300, - want: (100 * 1.0) + (100 * 1.1) + (100 * 1.2), - }, - { - lockedAmt: 0, - lockingAmt: 350, - want: (100 * 1.0) + (100 * 1.1) + (100 * 1.2) + (50 * 1.5), - }, - { - lockedAmt: 0, - lockingAmt: 600, - want: (100 * 1.0) + (100 * 1.1) + (100 * 1.2) + (300 * 1.5), - }, - { - lockedAmt: 100, - lockingAmt: 100, - want: (100 * 1.1), - }, - { - lockedAmt: 200, - lockingAmt: 100, - want: (100 * 1.2), - }, - { - lockedAmt: 150, - lockingAmt: 150, - want: (50 * 1.1) + (100 * 1.2), - }, - { - lockedAmt: 50, - lockingAmt: 400, - want: (50 * 1.0) + (100 * 1.1) + (100 * 1.2) + (150 * 1.5), - }, - } - for _, tt := range tests { - name := fmt.Sprintf("%d adds %d", tt.lockedAmt, tt.lockingAmt) - oldLock := math.NewInt(tt.lockedAmt) - newLock := math.NewInt(tt.lockingAmt) - want := math.NewInt(tt.want) - - t.Run(name, func(t *testing.T) { - if got := keeper.CalculateCredit(rateList, oldLock, newLock); !reflect.DeepEqual(got, want) { - t.Errorf("CalculateCredit() = %v, want %v", got, tt.want) - } - }) - } -} - func Test_MintCredit(t *testing.T) { tests := []struct { name string diff --git a/x/tier/keeper/grpc_query_test.go b/x/tier/keeper/grpc_query_test.go index 0d278ddb..9429f8f9 100644 --- a/x/tier/keeper/grpc_query_test.go +++ b/x/tier/keeper/grpc_query_test.go @@ -27,15 +27,13 @@ func TestParamsQuery(t *testing.T) { func TestLockupQuery(t *testing.T) { keeper, ctx := keepertest.TierKeeper(t) + amount := math.NewInt(1000) delAddr, err := sdk.AccAddressFromBech32("source1wjj5v5rlf57kayyeskncpu4hwev25ty645p2et") require.NoError(t, err) - valAddr, err := sdk.ValAddressFromBech32("sourcevaloper1cy0p47z24ejzvq55pu3lesxwf73xnrnd0pzkqm") require.NoError(t, err) - amount := math.NewInt(1000) - keeper.AddLockup(ctx, delAddr, valAddr, amount) querier := tierkeeper.NewQuerier(keeper) @@ -56,16 +54,14 @@ func TestLockupQuery(t *testing.T) { func TestLockupsQuery(t *testing.T) { keeper, ctx := keepertest.TierKeeper(t) + amount1 := math.NewInt(1000) + amount2 := math.NewInt(500) delAddr, err := sdk.AccAddressFromBech32("source1wjj5v5rlf57kayyeskncpu4hwev25ty645p2et") require.NoError(t, err) - valAddr, err := sdk.ValAddressFromBech32("sourcevaloper1cy0p47z24ejzvq55pu3lesxwf73xnrnd0pzkqm") require.NoError(t, err) - amount1 := math.NewInt(1000) - amount2 := math.NewInt(500) - keeper.AddLockup(ctx, delAddr, valAddr, amount1) keeper.AddLockup(ctx, delAddr, valAddr, amount2) @@ -81,20 +77,21 @@ func TestLockupsQuery(t *testing.T) { func TestUnlockingLockupQuery(t *testing.T) { keeper, ctx := keepertest.TierKeeper(t) + params := keeper.GetParams(ctx) + epochDuration := *params.EpochDuration + amount := math.NewInt(1000) delAddr, err := sdk.AccAddressFromBech32("source1wjj5v5rlf57kayyeskncpu4hwev25ty645p2et") require.NoError(t, err) - valAddr, err := sdk.ValAddressFromBech32("sourcevaloper1cy0p47z24ejzvq55pu3lesxwf73xnrnd0pzkqm") require.NoError(t, err) - amount := math.NewInt(1000) + ctx = ctx.WithBlockHeight(1).WithBlockTime(time.Now()) - // normalize time to UTC within SetLockup() so unbondTime/unlockTime can be local - unbondTime := time.Now().Add(24 * time.Hour) - unlockTime := time.Now().Add(48 * time.Hour) + keeper.SetLockup(ctx, true, delAddr, valAddr, amount, nil) - keeper.SetLockup(ctx, true, delAddr, valAddr, amount, 1, &unbondTime, &unlockTime) + unbondTime := ctx.BlockTime().Add(epochDuration * time.Duration(params.UnlockingEpochs)) + unlockTime := unbondTime querier := tierkeeper.NewQuerier(keeper) response, err := querier.UnlockingLockup(ctx, &types.UnlockingLockupRequest{ @@ -121,23 +118,27 @@ func TestUnlockingLockupQuery(t *testing.T) { func TestUnlockingLockupsQuery(t *testing.T) { keeper, ctx := keepertest.TierKeeper(t) + params := keeper.GetParams(ctx) + epochDuration := *params.EpochDuration + amount1 := math.NewInt(1000) + amount2 := math.NewInt(500) delAddr, err := sdk.AccAddressFromBech32("source1wjj5v5rlf57kayyeskncpu4hwev25ty645p2et") require.NoError(t, err) - valAddr, err := sdk.ValAddressFromBech32("sourcevaloper1cy0p47z24ejzvq55pu3lesxwf73xnrnd0pzkqm") require.NoError(t, err) - amount1 := math.NewInt(1000) - amount2 := math.NewInt(500) + ctx = ctx.WithBlockHeight(1).WithBlockTime(time.Now()) + keeper.SetLockup(ctx, true, delAddr, valAddr, amount1, nil) + + unbondTime1 := ctx.BlockTime().Add(epochDuration * time.Duration(params.UnlockingEpochs)) + unlockTime1 := unbondTime1 - unbondTime1 := time.Now().Add(24 * time.Hour).UTC() - unlockTime1 := time.Now().Add(48 * time.Hour).UTC() - unbondTime2 := time.Now().Add(36 * time.Hour).UTC() - unlockTime2 := time.Now().Add(72 * time.Hour).UTC() + ctx = ctx.WithBlockHeight(2).WithBlockTime(unbondTime1) + keeper.SetLockup(ctx, true, delAddr, valAddr, amount2, nil) - keeper.SetLockup(ctx, true, delAddr, valAddr, amount1, 1, &unbondTime1, &unlockTime1) - keeper.SetLockup(ctx, true, delAddr, valAddr, amount2, 2, &unbondTime2, &unlockTime2) + unbondTime2 := ctx.BlockTime().Add(epochDuration * time.Duration(params.UnlockingEpochs)) + unlockTime2 := unbondTime2 querier := tierkeeper.NewQuerier(keeper) response, err := querier.UnlockingLockups(ctx, &types.UnlockingLockupsRequest{ diff --git a/x/tier/keeper/keeper.go b/x/tier/keeper/keeper.go index 6d8f2394..69636f52 100644 --- a/x/tier/keeper/keeper.go +++ b/x/tier/keeper/keeper.go @@ -89,6 +89,8 @@ func (k Keeper) Logger() log.Logger { // CompleteUnlocking completes the unlocking process for all lockups that have reached their unlock time. // It is called at the end of each Epoch. func (k Keeper) CompleteUnlocking(ctx context.Context) error { + sdkCtx := sdk.UnwrapSDKContext(ctx) + cb := func(delAddr sdk.AccAddress, valAddr sdk.ValAddress, creationHeight int64, lockup types.Lockup) error { if sdk.UnwrapSDKContext(ctx).BlockTime().Before(*lockup.UnlockTime) { fmt.Printf("Unlock time not reached for %s/%s\n", delAddr, valAddr) @@ -110,7 +112,17 @@ func (k Keeper) CompleteUnlocking(ctx context.Context) error { return errorsmod.Wrapf(err, "undelegate coins to %s for amount %s", delAddr, stake) } - k.removeUnlockingLockup(ctx, delAddr, valAddr) + k.removeUnlockingLockup(ctx, delAddr, valAddr, creationHeight) + + sdkCtx.EventManager().EmitEvent( + sdk.NewEvent( + types.EventTypeCompleteUnlocking, + sdk.NewAttribute(stakingtypes.AttributeKeyDelegator, delAddr.String()), + sdk.NewAttribute(stakingtypes.AttributeKeyValidator, valAddr.String()), + sdk.NewAttribute(sdk.AttributeKeyAmount, lockup.Amount.String()), + sdk.NewAttribute(types.AttributeKeyCreationHeight, fmt.Sprintf("%d", creationHeight)), + ), + ) return nil } @@ -125,6 +137,14 @@ func (k Keeper) CompleteUnlocking(ctx context.Context) error { // Lock locks the stake of a delegator to a validator. func (k Keeper) Lock(ctx context.Context, delAddr sdk.AccAddress, valAddr sdk.ValAddress, amt math.Int) error { + // specified amt must be a positive integer + if !amt.IsPositive() { + return types.ErrInvalidAmount.Wrap("invalid amount") + } + + sdkCtx := sdk.UnwrapSDKContext(ctx) + modAddr := authtypes.NewModuleAddress(types.ModuleName) + validator, err := k.stakingKeeper.GetValidator(ctx, valAddr) if err != nil { return types.ErrInvalidAddress.Wrapf("validator address %s: %s", valAddr, err) @@ -139,7 +159,6 @@ func (k Keeper) Lock(ctx context.Context, delAddr sdk.AccAddress, valAddr sdk.Va } // Delegate the stake to the validator. - modAddr := authtypes.NewModuleAddress(types.ModuleName) _, err = k.stakingKeeper.Delegate(ctx, modAddr, stake.Amount, stakingtypes.Unbonded, validator, true) if err != nil { return errorsmod.Wrapf(err, "delegate %s", stake) @@ -155,6 +174,15 @@ func (k Keeper) Lock(ctx context.Context, delAddr sdk.AccAddress, valAddr sdk.Va return errorsmod.Wrap(err, "mint credit") } + sdkCtx.EventManager().EmitEvent( + sdk.NewEvent( + types.EventTypeLock, + sdk.NewAttribute(stakingtypes.AttributeKeyDelegator, delAddr.String()), + sdk.NewAttribute(stakingtypes.AttributeKeyValidator, valAddr.String()), + sdk.NewAttribute(sdk.AttributeKeyAmount, amt.String()), + ), + ) + return nil } @@ -163,25 +191,24 @@ func (k Keeper) Lock(ctx context.Context, delAddr sdk.AccAddress, valAddr sdk.Va func (k Keeper) Unlock(ctx context.Context, delAddr sdk.AccAddress, valAddr sdk.ValAddress, amt math.Int) ( unbondTime time.Time, unlockTime time.Time, creationHeight int64, err error) { + // specified amt must be a positive integer + if !amt.IsPositive() { + return time.Time{}, time.Time{}, 0, types.ErrInvalidAmount.Wrap("invalid amount") + } + + sdkCtx := sdk.UnwrapSDKContext(ctx) + modAddr := authtypes.NewModuleAddress(types.ModuleName) + validator, err := k.stakingKeeper.GetValidator(ctx, valAddr) if err != nil { return time.Time{}, time.Time{}, 0, types.ErrInvalidAddress.Wrapf("validator address %s: %s", valAddr, err) } - sdkCtx := sdk.UnwrapSDKContext(ctx) - creationHeight = sdkCtx.BlockHeight() - err = k.SubtractLockup(ctx, delAddr, valAddr, amt) if err != nil { return time.Time{}, time.Time{}, 0, errorsmod.Wrap(err, "subtract lockup") } - params := k.GetParams(ctx) - epochDuration := params.EpochDuration - unlockingDuration := time.Duration(params.UnlockingEpochs) * *epochDuration - unlockTime = sdkCtx.BlockTime().Add(unlockingDuration) - modAddr := authtypes.NewModuleAddress(types.ModuleName) - shares, err := k.stakingKeeper.ValidateUnbondAmount(ctx, modAddr, valAddr, amt) if err != nil { return time.Time{}, time.Time{}, 0, errorsmod.Wrap(err, "validate unbond amount") @@ -202,9 +229,19 @@ func (k Keeper) Unlock(ctx context.Context, delAddr sdk.AccAddress, valAddr sdk. return time.Time{}, time.Time{}, 0, errorsmod.Wrap(err, "undelegate") } - k.removeLockup(ctx, delAddr, valAddr) + creationHeight, _, unlockTime = k.SetLockup(ctx, true, delAddr, valAddr, amt, &unbondTime) - k.SetLockup(ctx, true, delAddr, valAddr, amt, creationHeight, &unbondTime, &unlockTime) + sdkCtx.EventManager().EmitEvent( + sdk.NewEvent( + types.EventTypeUnlock, + sdk.NewAttribute(stakingtypes.AttributeKeyDelegator, delAddr.String()), + sdk.NewAttribute(stakingtypes.AttributeKeyValidator, valAddr.String()), + sdk.NewAttribute(sdk.AttributeKeyAmount, amt.String()), + sdk.NewAttribute(types.AttributeKeyUnbondTime, unbondTime.String()), + sdk.NewAttribute(types.AttributeKeyUnlockTime, unlockTime.String()), + sdk.NewAttribute(types.AttributeKeyCreationHeight, fmt.Sprintf("%d", creationHeight)), + ), + ) return unbondTime, unlockTime, creationHeight, nil } @@ -214,6 +251,14 @@ func (k Keeper) Unlock(ctx context.Context, delAddr sdk.AccAddress, valAddr sdk. func (k Keeper) Redelegate(ctx context.Context, delAddr sdk.AccAddress, srcValAddr, dstValAddr sdk.ValAddress, amt math.Int) ( completionTime time.Time, err error) { + // specified amt must be a positive integer + if !amt.IsPositive() { + return time.Time{}, types.ErrInvalidAmount.Wrap("invalid amount") + } + + sdkCtx := sdk.UnwrapSDKContext(ctx) + modAddr := authtypes.NewModuleAddress(types.ModuleName) + err = k.SubtractLockup(ctx, delAddr, srcValAddr, amt) if err != nil { return time.Time{}, errorsmod.Wrap(err, "subtract locked stake from source validator") @@ -221,8 +266,6 @@ func (k Keeper) Redelegate(ctx context.Context, delAddr sdk.AccAddress, srcValAd k.AddLockup(ctx, delAddr, dstValAddr, amt) - modAddr := authtypes.NewModuleAddress(types.ModuleName) - shares, err := k.stakingKeeper.ValidateUnbondAmount(ctx, modAddr, srcValAddr, amt) if err != nil { return time.Time{}, errorsmod.Wrap(err, "validate unbond amount") @@ -233,14 +276,25 @@ func (k Keeper) Redelegate(ctx context.Context, delAddr sdk.AccAddress, srcValAd return time.Time{}, errorsmod.Wrap(err, "begin redelegation") } + sdkCtx.EventManager().EmitEvent( + sdk.NewEvent( + types.EventTypeRedelegate, + sdk.NewAttribute(stakingtypes.AttributeKeyDelegator, delAddr.String()), + sdk.NewAttribute(types.AttributeKeySourceValidator, srcValAddr.String()), + sdk.NewAttribute(types.AttributeKeyDestinationValidator, dstValAddr.String()), + sdk.NewAttribute(sdk.AttributeKeyAmount, amt.String()), + sdk.NewAttribute(types.AttributeKeyCompletionTime, completionTime.String()), + ), + ) + return completionTime, nil } -// CancelUnlocking effectively cancels the pending unlocking lockup. -func (k Keeper) CancelUnlocking(ctx context.Context, delAddr sdk.AccAddress, valAddr sdk.ValAddress, amt math.Int) error { +// CancelUnlocking effectively cancels the pending unlocking lockup partially or in full. +// Reverts the specified amt if a valid value is provided (e.g. amt != nil && 0 < amt < unbondEntry.Balance). +// Otherwise, cancels unlocking lockup record in full (e.g. unbondEntry.Balance). +func (k Keeper) CancelUnlocking(ctx context.Context, delAddr sdk.AccAddress, valAddr sdk.ValAddress, creationHeight int64, amt *math.Int) error { sdkCtx := sdk.UnwrapSDKContext(ctx) - - // use the module account address when interacting with unbonding delegations modAddr := authtypes.NewModuleAddress(types.ModuleName) validator, err := k.stakingKeeper.GetValidator(ctx, valAddr) @@ -253,14 +307,15 @@ func (k Keeper) CancelUnlocking(ctx context.Context, delAddr sdk.AccAddress, val return errorsmod.Wrapf(err, "unbonding delegation not found for delegator %s and validator %s", modAddr, valAddr) } - // search for the unbonding delegation entry that matches the amount and ensure it's valid + // find unbonding delegation entry by CreationHeight + // TODO: handle edge case with 2+ messages at the same height var ( unbondEntryIndex int64 = -1 unbondEntry stakingtypes.UnbondingDelegationEntry ) for i, entry := range ubd.Entries { - if entry.Balance.GTE(amt) && entry.CompletionTime.After(sdkCtx.BlockTime()) { + if entry.CreationHeight == creationHeight && entry.CompletionTime.After(sdkCtx.BlockTime()) { unbondEntryIndex = int64(i) unbondEntry = entry break @@ -268,21 +323,31 @@ func (k Keeper) CancelUnlocking(ctx context.Context, delAddr sdk.AccAddress, val } if unbondEntryIndex == -1 { - return errorsmod.Wrapf(stakingtypes.ErrNoUnbondingDelegation, "no valid unbonding entry found for amount %s", amt) + return errorsmod.Wrapf( + stakingtypes.ErrNoUnbondingDelegation, + "no valid unbonding entry found for creation height %d", + creationHeight, + ) + } + + // revert the specified amt if set and is positive, otherwise revert the entire UnbondingDelegationEntry + restoreAmount := unbondEntry.Balance + if amt != nil && amt.IsPositive() && amt.LT(unbondEntry.Balance) { + restoreAmount = *amt } - _, err = k.stakingKeeper.Delegate(ctx, modAddr, amt, stakingtypes.Unbonding, validator, false) + _, err = k.stakingKeeper.Delegate(ctx, modAddr, restoreAmount, stakingtypes.Unbonding, validator, false) if err != nil { return errorsmod.Wrap(err, "failed to delegate tokens back to validator") } // update or remove the unbonding delegation entry - remainingBalance := unbondEntry.Balance.Sub(amt) + remainingBalance := unbondEntry.Balance.Sub(restoreAmount) if remainingBalance.IsZero() { ubd.RemoveEntry(unbondEntryIndex) } else { unbondEntry.Balance = remainingBalance - unbondEntry.InitialBalance = unbondEntry.InitialBalance.Sub(amt) + unbondEntry.InitialBalance = unbondEntry.InitialBalance.Sub(restoreAmount) ubd.Entries[unbondEntryIndex] = unbondEntry } @@ -296,14 +361,19 @@ func (k Keeper) CancelUnlocking(ctx context.Context, delAddr sdk.AccAddress, val return errorsmod.Wrap(err, "failed to update unbonding delegation") } - k.AddLockup(ctx, delAddr, valAddr, amt) + // remove unlocking lockup if no amt was specified (e.g. no partial unlocking lockup cancelation) + k.SubtractUnlockingLockup(ctx, delAddr, valAddr, creationHeight, restoreAmount) + + // add restoreAmount back to the lockup (without modifying the unlock/unbond times) + k.AddLockup(ctx, delAddr, valAddr, restoreAmount) sdkCtx.EventManager().EmitEvent( sdk.NewEvent( types.EventTypeCancelUnlocking, sdk.NewAttribute(stakingtypes.AttributeKeyDelegator, delAddr.String()), sdk.NewAttribute(stakingtypes.AttributeKeyValidator, valAddr.String()), - sdk.NewAttribute(sdk.AttributeKeyAmount, amt.String()), + sdk.NewAttribute(sdk.AttributeKeyAmount, restoreAmount.String()), + sdk.NewAttribute(types.AttributeKeyCreationHeight, fmt.Sprintf("%d", creationHeight)), ), ) diff --git a/x/tier/keeper/keeper_mock_test.go b/x/tier/keeper/keeper_mock_test.go index 4a129e7f..5f246d25 100644 --- a/x/tier/keeper/keeper_mock_test.go +++ b/x/tier/keeper/keeper_mock_test.go @@ -81,7 +81,6 @@ func (suite *KeeperTestSuite) TestLock() { delAddr, err := sdk.AccAddressFromBech32("source1wjj5v5rlf57kayyeskncpu4hwev25ty645p2et") suite.Require().NoError(err) - valAddr, err := sdk.ValAddressFromBech32("sourcevaloper1cy0p47z24ejzvq55pu3lesxwf73xnrnd0pzkqm") suite.Require().NoError(err) @@ -138,7 +137,6 @@ func (suite *KeeperTestSuite) TestUnlock() { delAddr, err := sdk.AccAddressFromBech32("source1wjj5v5rlf57kayyeskncpu4hwev25ty645p2et") suite.Require().NoError(err) - valAddr, err := sdk.ValAddressFromBech32("sourcevaloper1cy0p47z24ejzvq55pu3lesxwf73xnrnd0pzkqm") suite.Require().NoError(err) @@ -186,6 +184,11 @@ func (suite *KeeperTestSuite) TestUnlock() { suite.tierKeeper.SetParams(suite.ctx, params) + // add a lockup and verify that it exists before trying to unlock + suite.tierKeeper.AddLockup(suite.ctx, delAddr, valAddr, amount) + lockedAmt := suite.tierKeeper.GetLockupAmount(suite.ctx, delAddr, valAddr) + suite.Require().Equal(amount, lockedAmt, "expected lockup amount to be set") + // perform unlock and verify that unlocking lockup is set correctly unbondTime, unlockTime, creationHeight, err := suite.tierKeeper.Unlock(suite.ctx, delAddr, valAddr, amount) suite.Require().NoError(err) @@ -200,14 +203,13 @@ func (suite *KeeperTestSuite) TestUnlock() { // TestRedelegate is using mock keepers to verify that required function calls are made as expected on Redelegate(). func (suite *KeeperTestSuite) TestRedelegate() { amount := math.NewInt(1000) + completionTime := suite.ctx.BlockTime() shares := math.LegacyNewDecFromInt(amount) delAddr, err := sdk.AccAddressFromBech32("source1wjj5v5rlf57kayyeskncpu4hwev25ty645p2et") suite.Require().NoError(err) - srcValAddr, err := sdk.ValAddressFromBech32("sourcevaloper1cy0p47z24ejzvq55pu3lesxwf73xnrnd0pzkqm") suite.Require().NoError(err) - dstValAddr, err := sdk.ValAddressFromBech32("sourcevaloper13fj7t2yptf9k6ad6fv38434znzay4s4pjk0r4f") suite.Require().NoError(err) @@ -218,7 +220,6 @@ func (suite *KeeperTestSuite) TestRedelegate() { ValidateUnbondAmount(gomock.Any(), authtypes.NewModuleAddress(types.ModuleName), srcValAddr, amount). Return(shares, nil).Times(1) - completionTime := suite.ctx.BlockTime().Add(24 * time.Hour) suite.stakingKeeper.EXPECT(). BeginRedelegation(gomock.Any(), authtypes.NewModuleAddress(types.ModuleName), srcValAddr, dstValAddr, shares). Return(completionTime, nil).Times(1) diff --git a/x/tier/keeper/keeper_test.go b/x/tier/keeper/keeper_test.go index 673e8274..9dbb80f6 100644 --- a/x/tier/keeper/keeper_test.go +++ b/x/tier/keeper/keeper_test.go @@ -43,19 +43,29 @@ func initializeDelegator(t *testing.T, k *tierkeeper.Keeper, ctx sdk.Context, de func TestLock(t *testing.T) { k, ctx := testutil.SetupKeeper(t) + amount := math.NewInt(1000) + invalidAmount := math.NewInt(-100) + delAddr, err := sdk.AccAddressFromBech32("source1wjj5v5rlf57kayyeskncpu4hwev25ty645p2et") require.NoError(t, err) - valAddr, err := sdk.ValAddressFromBech32("sourcevaloper1cy0p47z24ejzvq55pu3lesxwf73xnrnd0pzkqm") require.NoError(t, err) initialDelegatorBalance := math.NewInt(2000) initializeDelegator(t, k, ctx, delAddr, initialDelegatorBalance) - initialValidatorBalance := math.NewInt(1000) initializeValidator(t, k.GetStakingKeeper().(*keeper.Keeper), ctx, valAddr, initialValidatorBalance) - amount := math.NewInt(1000) + // set initial block height and time + ctx = ctx.WithBlockHeight(1).WithBlockTime(time.Now()) + + // locking invalid amounts should fail + err = k.Lock(ctx, delAddr, valAddr, invalidAmount) + require.Error(t, err) + err = k.Lock(ctx, delAddr, valAddr, math.ZeroInt()) + require.Error(t, err) + + // lock valid amount err = k.Lock(ctx, delAddr, valAddr, amount) require.NoError(t, err) @@ -68,37 +78,47 @@ func TestLock(t *testing.T) { func TestUnlock(t *testing.T) { k, ctx := testutil.SetupKeeper(t) + lockAmount := math.NewInt(1000) + unlockAmount := math.NewInt(500) + invalidUnlockAmount := math.NewInt(-500) + delAddr, err := sdk.AccAddressFromBech32("source1wjj5v5rlf57kayyeskncpu4hwev25ty645p2et") require.NoError(t, err) - valAddr, err := sdk.ValAddressFromBech32("sourcevaloper1cy0p47z24ejzvq55pu3lesxwf73xnrnd0pzkqm") require.NoError(t, err) initialDelegatorBalance := math.NewInt(2000) initializeDelegator(t, k, ctx, delAddr, initialDelegatorBalance) - initialValidatorBalance := math.NewInt(1000) initializeValidator(t, k.GetStakingKeeper().(*keeper.Keeper), ctx, valAddr, initialValidatorBalance) - amount := math.NewInt(1000) - err = k.Lock(ctx, delAddr, valAddr, amount) + // set initial block height and time + ctx = ctx.WithBlockHeight(1).WithBlockTime(time.Now()) + + err = k.Lock(ctx, delAddr, valAddr, lockAmount) require.NoError(t, err) // verify that lockup was added lockedAmt := k.GetLockupAmount(ctx, delAddr, valAddr) - require.Equal(t, amount, lockedAmt) + require.Equal(t, lockAmount, lockedAmt) + + // unlocking invalid amounts should fail + _, _, _, err = k.Unlock(ctx, delAddr, valAddr, invalidUnlockAmount) + require.Error(t, err) + _, _, _, err = k.Unlock(ctx, delAddr, valAddr, math.ZeroInt()) + require.Error(t, err) - unbondTime, unlockTime, creationHeight, err := k.Unlock(ctx, delAddr, valAddr, math.NewInt(500)) + unbondTime, unlockTime, creationHeight, err := k.Unlock(ctx, delAddr, valAddr, unlockAmount) require.NoError(t, err) // verify that lockup was updated lockedAmt = k.GetLockupAmount(ctx, delAddr, valAddr) - require.Equal(t, math.ZeroInt(), lockedAmt) + require.Equal(t, lockAmount.Sub(unlockAmount), lockedAmt) // check the unlocking entry found, amt, unbTime, unlTime := k.GetUnlockingLockup(ctx, delAddr, valAddr, creationHeight) require.True(t, found) - require.Equal(t, math.NewInt(500), amt) + require.Equal(t, unlockAmount, amt) require.Equal(t, unbondTime, unbTime) require.Equal(t, unlockTime, unlTime) } @@ -107,26 +127,34 @@ func TestUnlock(t *testing.T) { func TestRedelegate(t *testing.T) { k, ctx := testutil.SetupKeeper(t) + amount := math.NewInt(1000) + invalidAmount := math.NewInt(-100) + delAddr, err := sdk.AccAddressFromBech32("source1wjj5v5rlf57kayyeskncpu4hwev25ty645p2et") require.NoError(t, err) - srcValAddr, err := sdk.ValAddressFromBech32("sourcevaloper1cy0p47z24ejzvq55pu3lesxwf73xnrnd0pzkqm") require.NoError(t, err) - dstValAddr, err := sdk.ValAddressFromBech32("sourcevaloper13fj7t2yptf9k6ad6fv38434znzay4s4pjk0r4f") require.NoError(t, err) initialDelegatorBalance := math.NewInt(2000) initializeDelegator(t, k, ctx, delAddr, initialDelegatorBalance) - initialValidatorBalance := math.NewInt(1000) initializeValidator(t, k.GetStakingKeeper().(*keeper.Keeper), ctx, srcValAddr, initialValidatorBalance) initializeValidator(t, k.GetStakingKeeper().(*keeper.Keeper), ctx, dstValAddr, initialValidatorBalance) + // set initial block height and time + ctx = ctx.WithBlockHeight(1).WithBlockTime(time.Now()) + // lock tokens with the source validator - amount := math.NewInt(1000) require.NoError(t, k.Lock(ctx, delAddr, srcValAddr, amount)) + // redelegating invalid amounts should fail + _, err = k.Redelegate(ctx, delAddr, srcValAddr, dstValAddr, invalidAmount) + require.Error(t, err) + _, err = k.Redelegate(ctx, delAddr, srcValAddr, dstValAddr, math.ZeroInt()) + require.Error(t, err) + // redelegate from the source validator to the destination validator completionTime, err := k.Redelegate(ctx, delAddr, srcValAddr, dstValAddr, math.NewInt(500)) require.NoError(t, err) @@ -147,19 +175,22 @@ func TestRedelegate(t *testing.T) { func TestCompleteUnlocking(t *testing.T) { k, ctx := testutil.SetupKeeper(t) + lockAmount := math.NewInt(123_456) + unlockAmount := math.NewInt(123_456) + delAddr, err := sdk.AccAddressFromBech32("source1m4f5a896t7fzd9vc7pfgmc3fxkj8n24s68fcw9") require.NoError(t, err) - valAddr, err := sdk.ValAddressFromBech32("sourcevaloper1cy0p47z24ejzvq55pu3lesxwf73xnrnd0pzkqm") require.NoError(t, err) initialDelegatorBalance := math.NewInt(200_000) initializeDelegator(t, k, ctx, delAddr, initialDelegatorBalance) - initialValidatorBalance := math.NewInt(1_000_000) initializeValidator(t, k.GetStakingKeeper().(*keeper.Keeper), ctx, valAddr, initialValidatorBalance) - lockAmount := math.NewInt(123_456) + // set initial block height and time + ctx = ctx.WithBlockHeight(1).WithBlockTime(time.Now()) + err = k.Lock(ctx, delAddr, valAddr, lockAmount) require.NoError(t, err) @@ -169,7 +200,6 @@ func TestCompleteUnlocking(t *testing.T) { balance := k.GetBankKeeper().GetBalance(ctx, delAddr, appparams.DefaultBondDenom) require.Equal(t, initialDelegatorBalance.Sub(lockAmount), balance.Amount) - unlockAmount := math.NewInt(123_456) adjustedUnlockAmount := unlockAmount.Sub(math.OneInt()) // unlock tokens @@ -215,21 +245,26 @@ func TestCompleteUnlocking(t *testing.T) { func TestCancelUnlocking(t *testing.T) { k, ctx := testutil.SetupKeeper(t) + initialAmount := math.NewInt(1000) + unlockAmount := math.NewInt(500) + partialUnlockAmount := math.NewInt(200) + adjustedInitialAmount := initialAmount.Sub(math.OneInt()) // 999 + adjustedUnlockAmount := unlockAmount.Sub(math.OneInt()) // 499 + newLockAmount := initialAmount.Sub(unlockAmount).Add(partialUnlockAmount) + adjustedUnlockAmountFinal := initialAmount.Sub(unlockAmount).Sub(partialUnlockAmount).Sub(math.OneInt()) + delAddr, err := sdk.AccAddressFromBech32("source1m4f5a896t7fzd9vc7pfgmc3fxkj8n24s68fcw9") require.NoError(t, err) - valAddr, err := sdk.ValAddressFromBech32("sourcevaloper1cy0p47z24ejzvq55pu3lesxwf73xnrnd0pzkqm") require.NoError(t, err) initialDelegatorBalance := math.NewInt(200_000) initializeDelegator(t, k, ctx, delAddr, initialDelegatorBalance) - initialValidatorBalance := math.NewInt(10_000_000) initializeValidator(t, k.GetStakingKeeper().(*keeper.Keeper), ctx, valAddr, initialValidatorBalance) - initialAmount := math.NewInt(1000) - unlockAmount := math.NewInt(500) - adjustedUnlockAmount := unlockAmount.Sub(math.OneInt()) + // set initial block height and time + ctx = ctx.WithBlockHeight(1).WithBlockTime(time.Now()) // lock the initialAmount err = k.Lock(ctx, delAddr, valAddr, initialAmount) @@ -239,33 +274,83 @@ func TestCancelUnlocking(t *testing.T) { lockedAmt := k.GetLockupAmount(ctx, delAddr, valAddr) require.Equal(t, initialAmount, lockedAmt) - // unlock the unlockAmount + // unlock the unlockAmount (partial unlock) unbondTime, unlockTime, creationHeight, err := k.Unlock(ctx, delAddr, valAddr, unlockAmount) require.NoError(t, err) // verify that lockup was updated lockedAmt = k.GetLockupAmount(ctx, delAddr, valAddr) - require.Equal(t, math.ZeroInt(), lockedAmt) + require.Equal(t, initialAmount.Sub(unlockAmount), lockedAmt) // 500 // check the unlocking entry based on adjusted unlock amount found, amt, unbTime, unlTime := k.GetUnlockingLockup(ctx, delAddr, valAddr, creationHeight) require.True(t, found) - require.Equal(t, adjustedUnlockAmount, amt) + require.Equal(t, adjustedUnlockAmount, amt) // 499 require.Equal(t, unbondTime, unbTime) require.Equal(t, unlockTime, unlTime) - // cancel (remove) the unlocking lockup - err = k.CancelUnlocking(ctx, delAddr, valAddr, adjustedUnlockAmount) + // partially cancel the unlocking lockup + err = k.CancelUnlocking(ctx, delAddr, valAddr, creationHeight, &partialUnlockAmount) require.NoError(t, err) // verify that lockup was updated lockupAmount := k.GetLockupAmount(ctx, delAddr, valAddr) - require.Equal(t, adjustedUnlockAmount, lockupAmount) + require.Equal(t, newLockAmount, lockupAmount) // 700 // check the unlocking entry found, amt, unbTime, unlTime = k.GetUnlockingLockup(ctx, delAddr, valAddr, creationHeight) require.Equal(t, true, found) - require.Equal(t, adjustedUnlockAmount, amt) + require.Equal(t, adjustedUnlockAmountFinal, amt) // 299 require.Equal(t, unbondTime, unbTime) require.Equal(t, unlockTime, unlTime) + + // advance block height by 1 so that subsequent unlocking lockup is stored separately + // otherwise, existing unlocking lockup is overrirden (e.g. delAddr/valAddr/creationHeight/) + // TODO: handle edge case with 2+ messages at the same height + ctx = ctx.WithBlockHeight(2).WithBlockTime(ctx.BlockTime().Add(time.Minute)) + + // add new unlocking lockup record at height 2 to fully unlock the remaining adjustedUnlockAmountFinal + unbondTime2, unlockTime2, creationHeight2, err := k.Unlock(ctx, delAddr, valAddr, adjustedUnlockAmountFinal) + require.NoError(t, err) + + // verify that lockup was updated + lockedAmt = k.GetLockupAmount(ctx, delAddr, valAddr) + require.Equal(t, newLockAmount.Sub(adjustedUnlockAmountFinal), lockedAmt) // 401 + + // check the unlocking entry based on adjusted unlock amount + found, amt, unbTime, unlTime = k.GetUnlockingLockup(ctx, delAddr, valAddr, creationHeight2) + require.True(t, found) + require.Equal(t, adjustedUnlockAmountFinal.Sub(math.OneInt()), amt) // 298 + require.Equal(t, unbondTime2, unbTime) + require.Equal(t, unlockTime2, unlTime) + + // cancel (remove) the unlocking lockup at height 2 + err = k.CancelUnlocking(ctx, delAddr, valAddr, creationHeight2, nil) + require.NoError(t, err) + + // verify that lockup was updated + lockupAmount = k.GetLockupAmount(ctx, delAddr, valAddr) + require.Equal(t, newLockAmount.Sub(math.OneInt()), lockupAmount) // 699 + + // there is still a partial unlocking lockup at height 1 since we did not cancel it's whole amount + found, amt, unbTime, unlTime = k.GetUnlockingLockup(ctx, delAddr, valAddr, 1) + require.Equal(t, true, found) + require.Equal(t, adjustedUnlockAmountFinal, amt) // 299 + require.Equal(t, unbondTime, unbTime) + require.Equal(t, unlockTime, unlTime) + + // cancel (remove) the remaining unlocking lockup at height 1 + err = k.CancelUnlocking(ctx, delAddr, valAddr, 1, nil) + require.NoError(t, err) + + // verify that lockup was updated + lockupAmount = k.GetLockupAmount(ctx, delAddr, valAddr) + require.Equal(t, adjustedInitialAmount.Sub(math.OneInt()), lockupAmount) // 998 + + // confirm that unlocking lockup was removed if we cancel whole amount (e.g. use nil) + found, amt, unbTime, unlTime = k.GetUnlockingLockup(ctx, delAddr, valAddr, creationHeight) + require.Equal(t, false, found) + require.Equal(t, math.ZeroInt(), amt) + require.Equal(t, time.Time{}, unbTime) + require.Equal(t, time.Time{}, unlTime) } diff --git a/x/tier/keeper/lockup.go b/x/tier/keeper/lockup.go index 7fa28616..fe284a0a 100644 --- a/x/tier/keeper/lockup.go +++ b/x/tier/keeper/lockup.go @@ -33,9 +33,10 @@ func (k Keeper) GetAllLockups(ctx context.Context) []types.Lockup { return lockups } -// SetLockup stores or updates a lockup in the state based on the key from LockupKey/UnlockingLockupKey. -// We normalize lockup times to UTC before saving to the store for consistentcy. -func (k Keeper) SetLockup(ctx context.Context, unlocking bool, delAddr sdk.AccAddress, valAddr sdk.ValAddress, amt math.Int, +// SaveLockup stores lockup or unlocking lockup based on the specified params. +// It is used in SubtractUnlockingLockup to override the same record considering existing creationHeight, +// as well as for importing lockups from the GenesisState.Lockups as part of the InitGenesis(). +func (k Keeper) SaveLockup(ctx context.Context, unlocking bool, delAddr sdk.AccAddress, valAddr sdk.ValAddress, amt math.Int, creationHeight int64, unbondTime *time.Time, unlockTime *time.Time) { var unbTime, unlTime *time.Time @@ -69,7 +70,63 @@ func (k Keeper) SetLockup(ctx context.Context, unlocking bool, delAddr sdk.AccAd store.Set(key, b) } -// GetLockup returns existing lockup amount, or nil if not found. +// SetLockup stores or updates a lockup in the state based on the key from LockupKey/UnlockingLockupKey. +// We normalize lockup times to UTC before saving to the store for consistentcy. +func (k Keeper) SetLockup(ctx context.Context, unlocking bool, delAddr sdk.AccAddress, valAddr sdk.ValAddress, amt math.Int, unbondTime *time.Time) (int64, *time.Time, time.Time) { + sdkCtx := sdk.UnwrapSDKContext(ctx) + params := k.GetParams(ctx) + creationHeight := sdkCtx.BlockHeight() + epochDuration := *params.EpochDuration + + unlockTime := sdkCtx.BlockTime().Add(epochDuration * time.Duration(params.UnlockingEpochs)) + // use unbondTime from stakingKeeper.Undelegate() if present, set it to match unlockTime otherwise + var unbTime *time.Time + if unbondTime != nil { + utcTime := unbondTime.UTC() + unbTime = &utcTime + } else { + unbTime = &unlockTime + } + + lockup := &types.Lockup{ + DelegatorAddress: delAddr.String(), + ValidatorAddress: valAddr.String(), + Amount: amt, + CreationHeight: creationHeight, + UnbondTime: unbTime, + UnlockTime: &unlockTime, + } + + // use different key for unlocking lockups + var key []byte + if unlocking { + key = types.UnlockingLockupKey(delAddr, valAddr, creationHeight) + } else { + key = types.LockupKey(delAddr, valAddr) + } + + b := k.cdc.MustMarshal(lockup) + store := k.lockupStore(ctx, unlocking) + store.Set(key, b) + + return creationHeight, unbTime, unlockTime +} + +func (k Keeper) GetLockups(ctx context.Context, delAddr sdk.AccAddress) []types.Lockup { + var lockups []types.Lockup + + cb := func(d sdk.AccAddress, valAddr sdk.ValAddress, lockup types.Lockup) { + if d.Equals(delAddr) { + lockups = append(lockups, lockup) + } + } + + k.MustIterateLockups(ctx, cb) + + return lockups +} + +// GetLockup returns a pointer to existing lockup, or nil if not found. func (k Keeper) GetLockup(ctx context.Context, delAddr sdk.AccAddress, valAddr sdk.ValAddress) *types.Lockup { key := types.LockupKey(delAddr, valAddr) store := k.lockupStore(ctx, false) @@ -78,13 +135,13 @@ func (k Keeper) GetLockup(ctx context.Context, delAddr sdk.AccAddress, valAddr s return nil } - var lockup *types.Lockup - k.cdc.MustUnmarshal(b, lockup) + var lockup types.Lockup + k.cdc.MustUnmarshal(b, &lockup) - return lockup + return &lockup } -// GetLockup returns existing lockup amount, or math.ZeroInt() if not found. +// GetLockupAmount returns existing lockup amount, or math.ZeroInt() if not found. func (k Keeper) GetLockupAmount(ctx context.Context, delAddr sdk.AccAddress, valAddr sdk.ValAddress) math.Int { key := types.LockupKey(delAddr, valAddr) store := k.lockupStore(ctx, false) @@ -133,8 +190,8 @@ func (k Keeper) removeLockup(ctx context.Context, delAddr sdk.AccAddress, valAdd } // removeUnlockingLockup removes existing unlocking lockup (delAddr/valAddr/creationHeight/). -func (k Keeper) removeUnlockingLockup(ctx context.Context, delAddr sdk.AccAddress, valAddr sdk.ValAddress) { - key := types.LockupKey(delAddr, valAddr) +func (k Keeper) removeUnlockingLockup(ctx context.Context, delAddr sdk.AccAddress, valAddr sdk.ValAddress, creationHeight int64) { + key := types.UnlockingLockupKey(delAddr, valAddr, creationHeight) store := k.lockupStore(ctx, true) store.Delete(key) } @@ -143,19 +200,63 @@ func (k Keeper) removeUnlockingLockup(ctx context.Context, delAddr sdk.AccAddres func (k Keeper) AddLockup(ctx context.Context, delAddr sdk.AccAddress, valAddr sdk.ValAddress, amt math.Int) { lockedAmt := k.GetLockupAmount(ctx, delAddr, valAddr) amt = amt.Add(lockedAmt) - k.SetLockup(ctx, false, delAddr, valAddr, amt, sdk.UnwrapSDKContext(ctx).BlockHeight(), nil, nil) + k.SetLockup(ctx, false, delAddr, valAddr, amt, nil) } // SubtractLockup subtracts provided amt from the existing delAddr/valAddr lockup. func (k Keeper) SubtractLockup(ctx context.Context, delAddr sdk.AccAddress, valAddr sdk.ValAddress, amt math.Int) error { lockedAmt := k.GetLockupAmount(ctx, delAddr, valAddr) - lockedAmt, err := lockedAmt.SafeSub(amt) + // subtracted amt must not be larger than the lockedAmt + if amt.GT(lockedAmt) { + return types.ErrInvalidAmount.Wrap("invalid amount") + } + + // remove lockup record completely if subtracted amt is equal to lockedAmt + if amt.Equal(lockedAmt) { + k.removeLockup(ctx, delAddr, valAddr) + return nil + } + + // subtract amt from the lockedAmt othwerwise + newAmt, err := lockedAmt.SafeSub(amt) if err != nil { return errorsmod.Wrapf(err, "subtract %s from locked amount %s", amt, lockedAmt) } - k.SetLockup(ctx, false, delAddr, valAddr, lockedAmt, sdk.UnwrapSDKContext(ctx).BlockHeight(), nil, nil) + k.SetLockup(ctx, false, delAddr, valAddr, newAmt, nil) + + return nil +} + +// SubtractUnlockingLockup subtracts provided amt from the existing unlocking lockup (delAddr/valAddr/creationHeight/). +func (k Keeper) SubtractUnlockingLockup(ctx context.Context, delAddr sdk.AccAddress, valAddr sdk.ValAddress, creationHeight int64, amt math.Int) error { + // get full unlocking lockup record because we must pass valid time(s) to SaveLockup + found, lockedAmt, unbondTime, unlockTime := k.GetUnlockingLockup(ctx, delAddr, valAddr, creationHeight) + + // return early if not found + if !found { + return nil + } + + // subtracted amt must not be larger than the lockedAmt + if amt.GT(lockedAmt) { + return types.ErrInvalidAmount.Wrap("invalid amount") + } + + // remove lockup record completely if subtracted amt is equal to lockedAmt + if amt.Equal(lockedAmt) { + k.removeUnlockingLockup(ctx, delAddr, valAddr, creationHeight) + return nil + } + + // subtract amt from the lockedAmt othwerwise + newAmt, err := lockedAmt.SafeSub(amt) + if err != nil { + return errorsmod.Wrapf(err, "subtract %s from unlocking lockup locked amount %s", amt, lockedAmt) + } + + k.SaveLockup(ctx, true, delAddr, valAddr, newAmt, creationHeight, &unbondTime, &unlockTime) return nil } @@ -164,8 +265,8 @@ func (k Keeper) SubtractLockup(ctx context.Context, delAddr sdk.AccAddress, valA func (k Keeper) TotalAmountByAddr(ctx context.Context, delAddr sdk.AccAddress) math.Int { amt := math.ZeroInt() - cb := func(delAddr sdk.AccAddress, valAddr sdk.ValAddress, lockup types.Lockup) { - if delAddr.Equals(delAddr) { + cb := func(d sdk.AccAddress, valAddr sdk.ValAddress, lockup types.Lockup) { + if d.Equals(delAddr) { amt = amt.Add(lockup.Amount) } } diff --git a/x/tier/keeper/lockup_test.go b/x/tier/keeper/lockup_test.go index 10b3a9ac..589ab465 100644 --- a/x/tier/keeper/lockup_test.go +++ b/x/tier/keeper/lockup_test.go @@ -1,6 +1,7 @@ package keeper_test import ( + "errors" "testing" "time" @@ -19,17 +20,23 @@ func init() { func TestSetAndGetLockup(t *testing.T) { k, ctx := testutil.SetupKeeper(t) - amount := math.NewInt(1000) + now := time.Now() + params := k.GetParams(ctx) + epochDuration := *params.EpochDuration creationHeight := int64(10) - unbondTime := time.Now().Add(1 * time.Hour) - unlockTime := time.Now().Add(2 * time.Hour) + amount := math.NewInt(1000) delAddr, err := sdk.AccAddressFromBech32("source1m4f5a896t7fzd9vc7pfgmc3fxkj8n24s68fcw9") require.Nil(t, err) valAddr, err := sdk.ValAddressFromBech32("sourcevaloper1cy0p47z24ejzvq55pu3lesxwf73xnrnd0pzkqm") require.Nil(t, err) - k.SetLockup(ctx, false, delAddr, valAddr, amount, creationHeight, &unbondTime, &unlockTime) + ctx = ctx.WithBlockHeight(creationHeight).WithBlockTime(now) + + unbondTime := ctx.BlockTime().Add(epochDuration * time.Duration(params.UnlockingEpochs)) + unlockTime := unbondTime + + k.SetLockup(ctx, false, delAddr, valAddr, amount, nil) store := k.GetAllLockups(ctx) require.Len(t, store, 1) @@ -62,20 +69,36 @@ func TestAddLockup(t *testing.T) { func TestSubtractLockup(t *testing.T) { k, ctx := testutil.SetupKeeper(t) - amount := math.NewInt(1000) + lockupAmount := math.NewInt(1000) + partialSubtractAmount := math.NewInt(500) + invalidSubtractAmount := math.NewInt(2000) delAddr, err := sdk.AccAddressFromBech32("source1m4f5a896t7fzd9vc7pfgmc3fxkj8n24s68fcw9") require.Nil(t, err) valAddr, err := sdk.ValAddressFromBech32("sourcevaloper1cy0p47z24ejzvq55pu3lesxwf73xnrnd0pzkqm") require.Nil(t, err) - k.AddLockup(ctx, delAddr, valAddr, amount) + k.AddLockup(ctx, delAddr, valAddr, lockupAmount) - err = k.SubtractLockup(ctx, delAddr, valAddr, math.NewInt(500)) + // subtract a partial amount + err = k.SubtractLockup(ctx, delAddr, valAddr, partialSubtractAmount) require.NoError(t, err) - lockup := k.GetLockupAmount(ctx, delAddr, valAddr) - require.Equal(t, math.NewInt(500), lockup) + lockedAmt := k.GetLockupAmount(ctx, delAddr, valAddr) + require.Equal(t, partialSubtractAmount, lockedAmt) + + // attempt to subtract more than the locked amount + err = k.SubtractLockup(ctx, delAddr, valAddr, invalidSubtractAmount) + require.Error(t, err) + + // subtract the remaining amount + err = k.SubtractLockup(ctx, delAddr, valAddr, partialSubtractAmount) + require.NoError(t, err) + + // verify that the lockup has been removed + lockedAmt = k.GetLockupAmount(ctx, delAddr, valAddr) + require.True(t, lockedAmt.IsZero(), "remaining lockup amount should be zero") + require.False(t, k.HasLockup(ctx, delAddr, valAddr), "lockup should be removed") } func TestGetAllLockups(t *testing.T) { @@ -91,11 +114,11 @@ func TestGetAllLockups(t *testing.T) { delAddr2, err := sdk.AccAddressFromBech32("source1m4f5a896t7fzd9vc7pfgmc3fxkj8n24s68fcw9") require.Nil(t, err) - valAddr2, err := sdk.ValAddressFromBech32("sourcevaloper1cy0p47z24ejzvq55pu3lesxwf73xnrnd0pzkqm") + valAddr2, err := sdk.ValAddressFromBech32("sourcevaloper13fj7t2yptf9k6ad6fv38434znzay4s4pjk0r4f") require.Nil(t, err) - k.SetLockup(ctx, false, delAddr1, valAddr1, amount1, 1, nil, nil) - k.SetLockup(ctx, false, delAddr2, valAddr2, amount2, 2, nil, nil) + k.SetLockup(ctx, false, delAddr1, valAddr1, amount1, nil) + k.SetLockup(ctx, false, delAddr2, valAddr2, amount2, nil) lockups := k.GetAllLockups(ctx) require.Len(t, lockups, 2) @@ -139,10 +162,7 @@ func TestMustIterateUnlockingLockups(t *testing.T) { valAddr, err := sdk.ValAddressFromBech32("sourcevaloper1cy0p47z24ejzvq55pu3lesxwf73xnrnd0pzkqm") require.Nil(t, err) - unbondTime := time.Now().Add(24 * time.Hour) - unlockTime := time.Now().Add(48 * time.Hour) - - k.SetLockup(ctx, true, delAddr, valAddr, amount, 1, &unbondTime, &unlockTime) + k.SetLockup(ctx, true, delAddr, valAddr, amount, nil) count := 0 k.MustIterateUnlockingLockups(ctx, func(delAddr sdk.AccAddress, valAddr sdk.ValAddress, creationHeight int64, lockup types.Lockup) { @@ -166,17 +186,21 @@ func TestIterateLockups(t *testing.T) { delAddr2, err := sdk.AccAddressFromBech32("source1m4f5a896t7fzd9vc7pfgmc3fxkj8n24s68fcw9") require.Nil(t, err) - valAddr2, err := sdk.ValAddressFromBech32("sourcevaloper1cy0p47z24ejzvq55pu3lesxwf73xnrnd0pzkqm") + valAddr2, err := sdk.ValAddressFromBech32("sourcevaloper13fj7t2yptf9k6ad6fv38434znzay4s4pjk0r4f") require.Nil(t, err) - k.SetLockup(ctx, false, delAddr1, valAddr1, math.NewInt(1000), 1, nil, nil) - k.SetLockup(ctx, false, delAddr2, valAddr2, math.NewInt(500), 2, nil, nil) + ctx = ctx.WithBlockHeight(1) + k.SetLockup(ctx, false, delAddr1, valAddr1, math.NewInt(1000), nil) + k.SetLockup(ctx, false, delAddr2, valAddr2, math.NewInt(500), nil) + + ctx = ctx.WithBlockHeight(2) + k.SetLockup(ctx, true, delAddr1, valAddr1, math.NewInt(200), nil) - unbondTime := time.Now().Add(24 * time.Hour) - unlockTime := time.Now().Add(48 * time.Hour) - k.SetLockup(ctx, true, delAddr1, valAddr1, math.NewInt(200), 3, &unbondTime, &unlockTime) - k.SetLockup(ctx, true, delAddr1, valAddr1, math.NewInt(200), 4, &unbondTime, &unbondTime) - k.SetLockup(ctx, true, delAddr1, valAddr1, math.NewInt(200), 5, &unbondTime, &unbondTime) + ctx = ctx.WithBlockHeight(3) + k.SetLockup(ctx, true, delAddr1, valAddr1, math.NewInt(200), nil) + + ctx = ctx.WithBlockHeight(4) + k.SetLockup(ctx, true, delAddr1, valAddr1, math.NewInt(200), nil) lockupsCount := 0 err = k.IterateLockups(ctx, false, func(delAddr sdk.AccAddress, valAddr sdk.ValAddress, creationHeight int64, lockup types.Lockup) error { @@ -201,4 +225,196 @@ func TestIterateLockups(t *testing.T) { }) require.NoError(t, err) require.Equal(t, 3, unlockingLockupsCount) + + err = k.IterateLockups(ctx, false, func(delAddr sdk.AccAddress, valAddr sdk.ValAddress, creationHeight int64, lockup types.Lockup) error { + return errors.New("not found") + }) + require.Error(t, err) +} + +func TestTotalAmountByAddr(t *testing.T) { + k, ctx := testutil.SetupKeeper(t) + + delAddr1, err := sdk.AccAddressFromBech32("source1m4f5a896t7fzd9vc7pfgmc3fxkj8n24s68fcw9") + require.NoError(t, err) + valAddr1, err := sdk.ValAddressFromBech32("sourcevaloper1cy0p47z24ejzvq55pu3lesxwf73xnrnd0pzkqm") + require.NoError(t, err) + + delAddr2, err := sdk.AccAddressFromBech32("source1wjj5v5rlf57kayyeskncpu4hwev25ty645p2et") + require.NoError(t, err) + valAddr2, err := sdk.ValAddressFromBech32("sourcevaloper13fj7t2yptf9k6ad6fv38434znzay4s4pjk0r4f") + require.NoError(t, err) + + k.AddLockup(ctx, delAddr1, valAddr1, math.NewInt(1000)) + k.AddLockup(ctx, delAddr1, valAddr1, math.NewInt(500)) + k.AddLockup(ctx, delAddr2, valAddr2, math.NewInt(700)) + + totalDel1 := k.TotalAmountByAddr(ctx, delAddr1) + require.Equal(t, math.NewInt(1500), totalDel1, "delAddr1 should have a total of 1500") + + totalDel2 := k.TotalAmountByAddr(ctx, delAddr2) + require.Equal(t, math.NewInt(700), totalDel2, "delAddr2 should have a total of 700") + + delAddr3, err := sdk.AccAddressFromBech32("source1n34fvpteuanu2nx2a4hql4jvcrcnal3gsrjppy") + require.NoError(t, err) + totalDel3 := k.TotalAmountByAddr(ctx, delAddr3) + require.True(t, totalDel3.IsZero(), "delAddr3 should have no lockups") +} + +func TestHasLockup(t *testing.T) { + k, ctx := testutil.SetupKeeper(t) + + delAddr, err := sdk.AccAddressFromBech32("source1m4f5a896t7fzd9vc7pfgmc3fxkj8n24s68fcw9") + require.NoError(t, err) + valAddr, err := sdk.ValAddressFromBech32("sourcevaloper1cy0p47z24ejzvq55pu3lesxwf73xnrnd0pzkqm") + require.NoError(t, err) + + require.False(t, k.HasLockup(ctx, delAddr, valAddr)) + + k.AddLockup(ctx, delAddr, valAddr, math.NewInt(100)) + require.True(t, k.HasLockup(ctx, delAddr, valAddr)) + + err = k.SubtractLockup(ctx, delAddr, valAddr, math.NewInt(100)) + require.NoError(t, err) + require.False(t, k.HasLockup(ctx, delAddr, valAddr), "lockup should no longer exist after removing the entire amount") +} + +func TestGetUnlockingLockup(t *testing.T) { + k, ctx := testutil.SetupKeeper(t) + + now := time.Now() + params := k.GetParams(ctx) + epochDuration := *params.EpochDuration + creationHeight := int64(10) + amount := math.NewInt(300) + + delAddr, err := sdk.AccAddressFromBech32("source1m4f5a896t7fzd9vc7pfgmc3fxkj8n24s68fcw9") + require.NoError(t, err) + valAddr, err := sdk.ValAddressFromBech32("sourcevaloper1cy0p47z24ejzvq55pu3lesxwf73xnrnd0pzkqm") + require.NoError(t, err) + + ctx = ctx.WithBlockHeight(creationHeight).WithBlockTime(now) + + unbondTime := ctx.BlockTime().Add(epochDuration * time.Duration(params.UnlockingEpochs)) + unlockTime := unbondTime + + k.SetLockup(ctx, true, delAddr, valAddr, amount, nil) + + found, amt, gotUnbondTime, gotUnlockTime := k.GetUnlockingLockup(ctx, delAddr, valAddr, creationHeight) + require.True(t, found, "unlocking lockup should be found") + require.Equal(t, amount, amt, "amount should match the one set") + require.Equal(t, unbondTime, gotUnbondTime, "unbondTime should match the one set") + require.Equal(t, unlockTime, gotUnlockTime, "unlockTime should match the one set") + + found, amt, gotUnbondTime, gotUnlockTime = k.GetUnlockingLockup(ctx, delAddr, valAddr, creationHeight+1) + require.False(t, found, "this unlocking lockup does not exist") + require.True(t, amt.IsZero(), "amount should be zero") + require.True(t, gotUnbondTime.IsZero(), "unbond time should be zero") + require.True(t, gotUnlockTime.IsZero(), "unlock time should be zero") +} + +func TestGetLockup(t *testing.T) { + k, ctx := testutil.SetupKeeper(t) + + now := time.Now() + creationHeight := int64(10) + amount := math.NewInt(1000) + + delAddr, err := sdk.AccAddressFromBech32("source1m4f5a896t7fzd9vc7pfgmc3fxkj8n24s68fcw9") + require.NoError(t, err) + valAddr, err := sdk.ValAddressFromBech32("sourcevaloper1cy0p47z24ejzvq55pu3lesxwf73xnrnd0pzkqm") + require.NoError(t, err) + + ctx = ctx.WithBlockHeight(creationHeight).WithBlockTime(now) + + params := k.GetParams(ctx) + unbondTime := ctx.BlockTime().Add(*params.EpochDuration * time.Duration(params.UnlockingEpochs)) + unlockTime := unbondTime + + k.SetLockup(ctx, false, delAddr, valAddr, amount, nil) + + lockup := k.GetLockup(ctx, delAddr, valAddr) + + require.NotNil(t, lockup, "lockup should exist") + require.Equal(t, delAddr.String(), lockup.DelegatorAddress, "delegator address should match") + require.Equal(t, valAddr.String(), lockup.ValidatorAddress, "validator address should match") + require.Equal(t, amount, lockup.Amount, "amount should match") + require.Equal(t, creationHeight, lockup.CreationHeight, "creation height should match") + require.Equal(t, unbondTime.UTC(), *lockup.UnbondTime, "unbond time should match") + require.Equal(t, unlockTime.UTC(), *lockup.UnlockTime, "unlock time should match") + + nonExistentValAddr, err := sdk.ValAddressFromBech32("sourcevaloper13fj7t2yptf9k6ad6fv38434znzay4s4pjk0r4f") + require.NoError(t, err) + + nonExistentLockup := k.GetLockup(ctx, delAddr, nonExistentValAddr) + require.Nil(t, nonExistentLockup, "lockup should not exist for the given validator") +} + +func TestGetLockups(t *testing.T) { + k, ctx := testutil.SetupKeeper(t) + + amount1 := math.NewInt(1000) + amount2 := math.NewInt(500) + + delAddr, err := sdk.AccAddressFromBech32("source1m4f5a896t7fzd9vc7pfgmc3fxkj8n24s68fcw9") + require.NoError(t, err) + valAddr1, err := sdk.ValAddressFromBech32("sourcevaloper13fj7t2yptf9k6ad6fv38434znzay4s4pjk0r4f") + require.NoError(t, err) + valAddr2, err := sdk.ValAddressFromBech32("sourcevaloper1cy0p47z24ejzvq55pu3lesxwf73xnrnd0pzkqm") + require.NoError(t, err) + + ctx = ctx.WithBlockHeight(10).WithBlockTime(time.Now()) + k.SetLockup(ctx, false, delAddr, valAddr1, amount1, nil) + + ctx = ctx.WithBlockHeight(11).WithBlockTime(time.Now().Add(time.Minute)) + k.SetLockup(ctx, false, delAddr, valAddr2, amount2, nil) + + lockups := k.GetLockups(ctx, delAddr) + + require.Len(t, lockups, 2, "delegator should have 2 lockups") + + require.Equal(t, delAddr.String(), lockups[0].DelegatorAddress) + require.Equal(t, valAddr1.String(), lockups[0].ValidatorAddress) + require.Equal(t, amount1, lockups[0].Amount) + + require.Equal(t, delAddr.String(), lockups[1].DelegatorAddress) + require.Equal(t, valAddr2.String(), lockups[1].ValidatorAddress) + require.Equal(t, amount2, lockups[1].Amount) +} + +func TestSubtractUnlockingLockup(t *testing.T) { + k, ctx := testutil.SetupKeeper(t) + + unlockingLockupAmount := math.NewInt(1000) + cancelUnlockAmount := math.NewInt(500) + cancelUnlockAmount2 := math.NewInt(2000) + creationHeight := int64(10) + + delAddr, err := sdk.AccAddressFromBech32("source1m4f5a896t7fzd9vc7pfgmc3fxkj8n24s68fcw9") + require.NoError(t, err) + valAddr, err := sdk.ValAddressFromBech32("sourcevaloper1cy0p47z24ejzvq55pu3lesxwf73xnrnd0pzkqm") + require.NoError(t, err) + + ctx = ctx.WithBlockHeight(creationHeight) + k.SetLockup(ctx, true, delAddr, valAddr, unlockingLockupAmount, nil) + + // subtract partial amount + err = k.SubtractUnlockingLockup(ctx, delAddr, valAddr, creationHeight, cancelUnlockAmount) + require.NoError(t, err) + + found, lockedAmt, _, _ := k.GetUnlockingLockup(ctx, delAddr, valAddr, creationHeight) + require.True(t, found) + require.Equal(t, cancelUnlockAmount, lockedAmt) + + // try to subtract more than the locked amount + err = k.SubtractUnlockingLockup(ctx, delAddr, valAddr, creationHeight, cancelUnlockAmount2) + require.Error(t, err) + + // subtract remaining amount + err = k.SubtractUnlockingLockup(ctx, delAddr, valAddr, creationHeight, cancelUnlockAmount) + require.NoError(t, err) + + found, lockedAmt, _, _ = k.GetUnlockingLockup(ctx, delAddr, valAddr, creationHeight) + require.False(t, found) + require.True(t, lockedAmt.IsZero()) } diff --git a/x/tier/keeper/msg_cancel_unlocking_test.go b/x/tier/keeper/msg_cancel_unlocking_test.go new file mode 100644 index 00000000..615eda59 --- /dev/null +++ b/x/tier/keeper/msg_cancel_unlocking_test.go @@ -0,0 +1,181 @@ +package keeper_test + +import ( + "testing" + + "cosmossdk.io/math" + sdk "github.com/cosmos/cosmos-sdk/types" + stakingkeeper "github.com/cosmos/cosmos-sdk/x/staking/keeper" + "github.com/stretchr/testify/require" + + appparams "github.com/sourcenetwork/sourcehub/app/params" + "github.com/sourcenetwork/sourcehub/x/tier/keeper" + "github.com/sourcenetwork/sourcehub/x/tier/types" +) + +type TestCase struct { + name string + input *types.MsgCancelUnlocking + expErr bool + expErrMsg string + expectedAmount math.Int +} + +func runMsgTestCase(t *testing.T, tc TestCase, k keeper.Keeper, ms types.MsgServer, initState func() sdk.Context, delAddress sdk.AccAddress, valAddress sdk.ValAddress) { + ctx := initState() + + err := tc.input.ValidateBasic() + if err != nil { + if tc.expErr { + require.Contains(t, err.Error(), tc.expErrMsg) + return + } + t.Fatalf("unexpected error in ValidateBasic: %v", err) + } + + resp, err := ms.CancelUnlocking(ctx, tc.input) + + if tc.expErr { + require.Error(t, err) + require.Contains(t, err.Error(), tc.expErrMsg) + } else { + require.NoError(t, err) + require.NotNil(t, resp, "Response should not be nil for valid cancel unlocking") + + lockup := k.GetLockup(ctx, delAddress, valAddress) + require.NotNil(t, lockup, "Lockup should not be nil after cancel unlocking") + require.Equal(t, tc.expectedAmount, lockup.Amount, "Lockup amount should match expected after cancel unlocking") + } +} + +func TestMsgCancelUnlocking(t *testing.T) { + k, ms, ctx := setupMsgServer(t) + sdkCtx := sdk.UnwrapSDKContext(ctx) + + validCoin := sdk.NewCoin(appparams.DefaultBondDenom, math.NewInt(100)) + partialCancelCoin := sdk.NewCoin(appparams.DefaultBondDenom, math.NewInt(50)) + excessCoin := sdk.NewCoin(appparams.DefaultBondDenom, math.NewInt(500)) + zeroCoin := sdk.NewCoin(appparams.DefaultBondDenom, math.ZeroInt()) + negativeAmount := math.NewInt(-100) + initialDelegatorBalance := math.NewInt(2000) + initialValidatorBalance := math.NewInt(1000) + + delAddr := "source1m4f5a896t7fzd9vc7pfgmc3fxkj8n24s68fcw9" + valAddr := "sourcevaloper1cy0p47z24ejzvq55pu3lesxwf73xnrnd0pzkqm" + + delAddress, err := sdk.AccAddressFromBech32(delAddr) + require.NoError(t, err) + valAddress, err := sdk.ValAddressFromBech32(valAddr) + require.NoError(t, err) + + initState := func() sdk.Context { + ctx, _ := sdkCtx.CacheContext() + initializeDelegator(t, &k, ctx, delAddress, initialDelegatorBalance) + initializeValidator(t, k.GetStakingKeeper().(*stakingkeeper.Keeper), ctx, valAddress, initialValidatorBalance) + err = k.Lock(ctx, delAddress, valAddress, validCoin.Amount) + require.NoError(t, err) + _, _, _, err = k.Unlock(ctx, delAddress, valAddress, validCoin.Amount) + require.NoError(t, err) + return ctx + } + + testCases := []TestCase{ + { + name: "invalid stake amount (zero)", + input: &types.MsgCancelUnlocking{ + DelegatorAddress: delAddr, + ValidatorAddress: valAddr, + CreationHeight: 1, + Stake: zeroCoin, + }, + expErr: true, + expErrMsg: "invalid amount", + }, + { + name: "invalid stake amount (negative)", + input: &types.MsgCancelUnlocking{ + DelegatorAddress: delAddr, + ValidatorAddress: valAddr, + CreationHeight: 1, + Stake: sdk.Coin{ + Denom: appparams.DefaultBondDenom, + Amount: negativeAmount, + }, + }, + expErr: true, + expErrMsg: "invalid amount", + }, + { + name: "non-existent unlocking", + input: &types.MsgCancelUnlocking{ + DelegatorAddress: delAddr, + ValidatorAddress: valAddr, + CreationHeight: 100, + Stake: validCoin, + }, + expErr: true, + expErrMsg: "no unbonding delegation found", + }, + { + name: "invalid delegator address", + input: &types.MsgCancelUnlocking{ + DelegatorAddress: "invalid-address", + ValidatorAddress: valAddr, + CreationHeight: 1, + Stake: validCoin, + }, + expErr: true, + expErrMsg: "delegator address", + }, + { + name: "invalid validator address", + input: &types.MsgCancelUnlocking{ + DelegatorAddress: delAddr, + ValidatorAddress: "invalid-validator-address", + CreationHeight: 1, + Stake: validCoin, + }, + expErr: true, + expErrMsg: "validator address", + }, + { + name: "valid cancel unlocking (partial)", + input: &types.MsgCancelUnlocking{ + DelegatorAddress: delAddr, + ValidatorAddress: valAddr, + CreationHeight: 1, + Stake: partialCancelCoin, + }, + expErr: false, + expectedAmount: validCoin.Amount.Sub(partialCancelCoin.Amount), + }, + { + name: "valid cancel unlocking (full)", + input: &types.MsgCancelUnlocking{ + DelegatorAddress: delAddr, + ValidatorAddress: valAddr, + CreationHeight: 1, + Stake: validCoin, + }, + expErr: false, + expectedAmount: validCoin.Amount.Sub(math.OneInt()), + }, + { + name: "excess unlocking amount", + input: &types.MsgCancelUnlocking{ + DelegatorAddress: delAddr, + ValidatorAddress: valAddr, + CreationHeight: 1, + Stake: excessCoin, + }, + expErr: false, + expectedAmount: validCoin.Amount.Sub(math.OneInt()), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + runMsgTestCase(t, tc, k, ms, initState, delAddress, valAddress) + }) + } +} diff --git a/x/tier/keeper/msg_lock_test.go b/x/tier/keeper/msg_lock_test.go new file mode 100644 index 00000000..779e4d26 --- /dev/null +++ b/x/tier/keeper/msg_lock_test.go @@ -0,0 +1,128 @@ +package keeper_test + +import ( + "testing" + + "cosmossdk.io/math" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/stretchr/testify/require" + + stakingkeeper "github.com/cosmos/cosmos-sdk/x/staking/keeper" + appparams "github.com/sourcenetwork/sourcehub/app/params" + "github.com/sourcenetwork/sourcehub/x/tier/types" +) + +func TestMsgLock(t *testing.T) { + k, ms, ctx := setupMsgServer(t) + sdkCtx := sdk.UnwrapSDKContext(ctx) + + validCoin1 := sdk.NewCoin(appparams.DefaultBondDenom, math.NewInt(100)) + validCoin2 := sdk.NewCoin(appparams.DefaultBondDenom, math.NewInt(3000)) + zeroCoin := sdk.NewCoin(appparams.DefaultBondDenom, math.ZeroInt()) + negativeAmount := math.NewInt(-1000) + initialDelegatorBalance := math.NewInt(2000) + initialValidatorBalance := math.NewInt(1000) + + delAddr := "source1m4f5a896t7fzd9vc7pfgmc3fxkj8n24s68fcw9" + valAddr := "sourcevaloper1cy0p47z24ejzvq55pu3lesxwf73xnrnd0pzkqm" + + delAddress, err := sdk.AccAddressFromBech32(delAddr) + require.NoError(t, err) + valAddress, err := sdk.ValAddressFromBech32(valAddr) + require.NoError(t, err) + + initializeDelegator(t, &k, sdkCtx, delAddress, initialDelegatorBalance) + initializeValidator(t, k.GetStakingKeeper().(*stakingkeeper.Keeper), sdkCtx, valAddress, initialValidatorBalance) + + testCases := []struct { + name string + input *types.MsgLock + expErr bool + expErrMsg string + }{ + { + name: "valid lock", + input: &types.MsgLock{ + DelegatorAddress: delAddr, + ValidatorAddress: valAddr, + Stake: validCoin1, + }, + expErr: false, + }, + { + name: "insufficient funds", + input: &types.MsgLock{ + DelegatorAddress: delAddr, + ValidatorAddress: valAddr, + Stake: validCoin2, + }, + expErr: true, + expErrMsg: "insufficient funds", + }, + { + name: "invalid stake amount (zero)", + input: &types.MsgLock{ + DelegatorAddress: delAddr, + ValidatorAddress: valAddr, + Stake: zeroCoin, + }, + expErr: true, + expErrMsg: "invalid amount", + }, + { + name: "invalid stake amount (negative)", + input: &types.MsgLock{ + DelegatorAddress: delAddr, + ValidatorAddress: valAddr, + Stake: sdk.Coin{ + Denom: appparams.DefaultBondDenom, + Amount: negativeAmount, + }, + }, + expErr: true, + expErrMsg: "invalid amount", + }, + { + name: "invalid delegator address", + input: &types.MsgLock{ + DelegatorAddress: "invalid-delegator-address", + ValidatorAddress: valAddr, + Stake: validCoin1, + }, + expErr: true, + expErrMsg: "delegator address", + }, + { + name: "invalid validator address", + input: &types.MsgLock{ + DelegatorAddress: delAddr, + ValidatorAddress: "invalid-validator-address", + Stake: validCoin1, + }, + expErr: true, + expErrMsg: "validator address", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := tc.input.ValidateBasic() + if err != nil { + if tc.expErr { + require.Contains(t, err.Error(), tc.expErrMsg) + return + } + t.Fatalf("unexpected error in ValidateBasic: %v", err) + } + + _, err = ms.Lock(sdkCtx, tc.input) + + if tc.expErr { + require.Error(t, err) + require.Contains(t, err.Error(), tc.expErrMsg) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/x/tier/keeper/msg_redelegate_test.go b/x/tier/keeper/msg_redelegate_test.go new file mode 100644 index 00000000..e52ff01e --- /dev/null +++ b/x/tier/keeper/msg_redelegate_test.go @@ -0,0 +1,158 @@ +package keeper_test + +import ( + "testing" + "time" + + "cosmossdk.io/math" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/stretchr/testify/require" + + stakingkeeper "github.com/cosmos/cosmos-sdk/x/staking/keeper" + appparams "github.com/sourcenetwork/sourcehub/app/params" + "github.com/sourcenetwork/sourcehub/x/tier/types" +) + +func TestMsgRedelegate(t *testing.T) { + k, ms, ctx := setupMsgServer(t) + sdkCtx := sdk.UnwrapSDKContext(ctx) + + validCoin := sdk.NewCoin(appparams.DefaultBondDenom, math.NewInt(100)) + zeroCoin := sdk.NewCoin(appparams.DefaultBondDenom, math.ZeroInt()) + negativeAmount := math.NewInt(-100) + initialDelegatorBalance := math.NewInt(2000) + initialSrcValidatorBalance := math.NewInt(1000) + initialDstValidatorBalance := math.NewInt(500) + + delAddr := "source1m4f5a896t7fzd9vc7pfgmc3fxkj8n24s68fcw9" + srcValAddr := "sourcevaloper1cy0p47z24ejzvq55pu3lesxwf73xnrnd0pzkqm" + dstValAddr := "sourcevaloper13fj7t2yptf9k6ad6fv38434znzay4s4pjk0r4f" + + delAddress, err := sdk.AccAddressFromBech32(delAddr) + require.NoError(t, err) + srcValAddress, err := sdk.ValAddressFromBech32(srcValAddr) + require.NoError(t, err) + dstValAddress, err := sdk.ValAddressFromBech32(dstValAddr) + require.NoError(t, err) + + initializeDelegator(t, &k, sdkCtx, delAddress, initialDelegatorBalance) + initializeValidator(t, k.GetStakingKeeper().(*stakingkeeper.Keeper), sdkCtx, srcValAddress, initialSrcValidatorBalance) + initializeValidator(t, k.GetStakingKeeper().(*stakingkeeper.Keeper), sdkCtx, dstValAddress, initialDstValidatorBalance) + + // lock some tokens to test redelegate logic + err = k.Lock(ctx, delAddress, srcValAddress, validCoin.Amount) + require.NoError(t, err) + + stakingParams, err := k.GetStakingKeeper().(*stakingkeeper.Keeper).GetParams(ctx) + require.NoError(t, err) + + // expectedCompletionTime should match the default staking unbonding time (e.g. 21 days) + expectedCompletionTime := sdkCtx.BlockTime().Add(stakingParams.UnbondingTime) + + testCases := []struct { + name string + input *types.MsgRedelegate + expErr bool + expErrMsg string + }{ + { + name: "valid redelegate", + input: &types.MsgRedelegate{ + DelegatorAddress: delAddr, + SrcValidatorAddress: srcValAddr, + DstValidatorAddress: dstValAddr, + Stake: validCoin, + }, + expErr: false, + }, + { + name: "insufficient lockup", + input: &types.MsgRedelegate{ + DelegatorAddress: delAddr, + SrcValidatorAddress: srcValAddr, + DstValidatorAddress: dstValAddr, + Stake: sdk.NewCoin(appparams.DefaultBondDenom, math.NewInt(500)), + }, + expErr: true, + expErrMsg: "subtract locked stake from source validator", + }, + { + name: "invalid stake amount (zero)", + input: &types.MsgRedelegate{ + DelegatorAddress: delAddr, + SrcValidatorAddress: srcValAddr, + DstValidatorAddress: dstValAddr, + Stake: zeroCoin, + }, + expErr: true, + expErrMsg: "invalid amount", + }, + { + name: "invalid stake amount (negative)", + input: &types.MsgRedelegate{ + DelegatorAddress: delAddr, + SrcValidatorAddress: srcValAddr, + DstValidatorAddress: dstValAddr, + Stake: sdk.Coin{ + Denom: appparams.DefaultBondDenom, + Amount: negativeAmount, + }, + }, + expErr: true, + expErrMsg: "invalid amount", + }, + { + name: "non-existent lockup", + input: &types.MsgRedelegate{ + DelegatorAddress: "source1wjj5v5rlf57kayyeskncpu4hwev25ty645p2et", + SrcValidatorAddress: srcValAddr, + DstValidatorAddress: dstValAddr, + Stake: validCoin, + }, + expErr: true, + expErrMsg: "subtract locked stake from source validator", + }, + { + name: "source and destination validator are the same", + input: &types.MsgRedelegate{ + DelegatorAddress: delAddr, + SrcValidatorAddress: srcValAddr, + DstValidatorAddress: srcValAddr, + Stake: validCoin, + }, + expErr: true, + expErrMsg: "src and dst validator addresses are the same", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := tc.input.ValidateBasic() + if err != nil { + if tc.expErr { + require.Contains(t, err.Error(), tc.expErrMsg) + return + } + t.Fatalf("unexpected error in ValidateBasic: %v", err) + } + + resp, err := ms.Redelegate(sdkCtx, tc.input) + + if tc.expErr { + require.Error(t, err) + require.Contains(t, err.Error(), tc.expErrMsg) + } else { + require.NoError(t, err) + require.NotNil(t, resp) + + srcLockup := k.GetLockupAmount(sdkCtx, delAddress, srcValAddress) + require.Equal(t, math.ZeroInt(), srcLockup, "Source validator lockup should be zero after valid redelegate") + + dstLockup := k.GetLockupAmount(sdkCtx, delAddress, dstValAddress) + require.Equal(t, validCoin.Amount, dstLockup, "Destination validator lockup should equal redelegated amount") + + require.WithinDuration(t, expectedCompletionTime, resp.CompletionTime, time.Second) + } + }) + } +} diff --git a/x/tier/keeper/msg_server.go b/x/tier/keeper/msg_server.go index 98a698b3..b6aaf8a4 100644 --- a/x/tier/keeper/msg_server.go +++ b/x/tier/keeper/msg_server.go @@ -22,7 +22,7 @@ func NewMsgServerImpl(keeper Keeper) types.MsgServer { func (m msgServer) UpdateParams(ctx context.Context, msg *types.MsgUpdateParams) (*types.MsgUpdateParamsResponse, error) { authority := m.Keeper.GetAuthority() if msg.Authority != authority { - return nil, types.ErrUnauthorized.Wrapf("expected authority: %s, got: %s", authority, msg.Authority) + return nil, types.ErrUnauthorized.Wrapf("invalid authority: %s", msg.Authority) } err := msg.Params.Validate() @@ -68,7 +68,7 @@ func (m msgServer) CancelUnlocking(ctx context.Context, msg *types.MsgCancelUnlo delAddr := sdk.MustAccAddressFromBech32(msg.DelegatorAddress) valAddr := types.MustValAddressFromBech32(msg.ValidatorAddress) - err := m.Keeper.CancelUnlocking(ctx, delAddr, valAddr, msg.Stake.Amount) + err := m.Keeper.CancelUnlocking(ctx, delAddr, valAddr, msg.CreationHeight, &msg.Stake.Amount) if err != nil { return nil, errorsmod.Wrap(err, "cancel unlocking") } diff --git a/x/tier/keeper/msg_server_test.go b/x/tier/keeper/msg_server_test.go new file mode 100644 index 00000000..328fa2ec --- /dev/null +++ b/x/tier/keeper/msg_server_test.go @@ -0,0 +1,24 @@ +package keeper_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + keepertest "github.com/sourcenetwork/sourcehub/testutil/keeper" + "github.com/sourcenetwork/sourcehub/x/tier/keeper" + "github.com/sourcenetwork/sourcehub/x/tier/types" +) + +func setupMsgServer(t testing.TB) (keeper.Keeper, types.MsgServer, context.Context) { + k, ctx := keepertest.TierKeeper(t) + return k, keeper.NewMsgServerImpl(k), ctx +} + +func TestMsgServer(t *testing.T) { + k, ms, ctx := setupMsgServer(t) + require.NotNil(t, ms) + require.NotNil(t, ctx) + require.NotEmpty(t, k) +} diff --git a/x/tier/keeper/msg_unlock_test.go b/x/tier/keeper/msg_unlock_test.go new file mode 100644 index 00000000..c1365bf3 --- /dev/null +++ b/x/tier/keeper/msg_unlock_test.go @@ -0,0 +1,153 @@ +package keeper_test + +import ( + "testing" + "time" + + "cosmossdk.io/math" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/stretchr/testify/require" + + stakingkeeper "github.com/cosmos/cosmos-sdk/x/staking/keeper" + appparams "github.com/sourcenetwork/sourcehub/app/params" + "github.com/sourcenetwork/sourcehub/x/tier/types" +) + +func TestMsgUnlock(t *testing.T) { + k, ms, ctx := setupMsgServer(t) + sdkCtx := sdk.UnwrapSDKContext(ctx) + + validCoin := sdk.NewCoin(appparams.DefaultBondDenom, math.NewInt(100)) + zeroCoin := sdk.NewCoin(appparams.DefaultBondDenom, math.ZeroInt()) + negativeAmount := math.NewInt(-100) + initialDelegatorBalance := math.NewInt(2000) + initialValidatorBalance := math.NewInt(1000) + + delAddr := "source1m4f5a896t7fzd9vc7pfgmc3fxkj8n24s68fcw9" + valAddr := "sourcevaloper1cy0p47z24ejzvq55pu3lesxwf73xnrnd0pzkqm" + + delAddress, err := sdk.AccAddressFromBech32(delAddr) + require.NoError(t, err) + valAddress, err := sdk.ValAddressFromBech32(valAddr) + require.NoError(t, err) + + initializeDelegator(t, &k, sdkCtx, delAddress, initialDelegatorBalance) + initializeValidator(t, k.GetStakingKeeper().(*stakingkeeper.Keeper), sdkCtx, valAddress, initialValidatorBalance) + + // lock some tokens to test the unlock logic + err = k.Lock(ctx, delAddress, valAddress, validCoin.Amount) + require.NoError(t, err) + + // expectedUnlockTime should match the SetLockup logic for setting the unlock time + params := k.GetParams(ctx) + epochDuration := *params.EpochDuration + expectedUnlockTime := sdkCtx.BlockTime().Add(epochDuration * time.Duration(params.UnlockingEpochs)) + + testCases := []struct { + name string + input *types.MsgUnlock + expErr bool + expErrMsg string + }{ + { + name: "valid unlock", + input: &types.MsgUnlock{ + DelegatorAddress: delAddr, + ValidatorAddress: valAddr, + Stake: validCoin, + }, + expErr: false, + }, + { + name: "insufficient lockup", + input: &types.MsgUnlock{ + DelegatorAddress: delAddr, + ValidatorAddress: valAddr, + Stake: sdk.NewCoin(appparams.DefaultBondDenom, math.NewInt(500)), + }, + expErr: true, + expErrMsg: "subtract lockup", + }, + { + name: "invalid stake amount (zero)", + input: &types.MsgUnlock{ + DelegatorAddress: delAddr, + ValidatorAddress: valAddr, + Stake: zeroCoin, + }, + expErr: true, + expErrMsg: "invalid amount", + }, + { + name: "invalid stake amount (negative)", + input: &types.MsgUnlock{ + DelegatorAddress: delAddr, + ValidatorAddress: valAddr, + Stake: sdk.Coin{ + Denom: appparams.DefaultBondDenom, + Amount: negativeAmount, + }, + }, + expErr: true, + expErrMsg: "invalid amount", + }, + { + name: "non-existent lockup", + input: &types.MsgUnlock{ + DelegatorAddress: "source1wjj5v5rlf57kayyeskncpu4hwev25ty645p2et", + ValidatorAddress: valAddr, + Stake: validCoin, + }, + expErr: true, + expErrMsg: "subtract lockup", + }, + { + name: "invalid delegator address", + input: &types.MsgUnlock{ + DelegatorAddress: "invalid-delegator-address", + ValidatorAddress: valAddr, + Stake: validCoin, + }, + expErr: true, + expErrMsg: "delegator address", + }, + { + name: "invalid validator address", + input: &types.MsgUnlock{ + DelegatorAddress: delAddr, + ValidatorAddress: "invalid-validator-address", + Stake: validCoin, + }, + expErr: true, + expErrMsg: "validator address", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := tc.input.ValidateBasic() + if err != nil { + if tc.expErr { + require.Contains(t, err.Error(), tc.expErrMsg) + return + } + t.Fatalf("unexpected error in ValidateBasic: %v", err) + } + + resp, err := ms.Unlock(sdkCtx, tc.input) + + if tc.expErr { + require.Error(t, err) + require.Contains(t, err.Error(), tc.expErrMsg) + } else { + require.NoError(t, err) + require.NotNil(t, resp) + + lockup := k.GetLockup(sdkCtx, delAddress, valAddress) + require.Nil(t, lockup, "Lockup should be nil after valid unlock") + + require.WithinDuration(t, expectedUnlockTime, resp.CompletionTime, time.Second) + } + }) + } +} diff --git a/x/tier/keeper/msg_update_params_test.go b/x/tier/keeper/msg_update_params_test.go new file mode 100644 index 00000000..8ba5ec3b --- /dev/null +++ b/x/tier/keeper/msg_update_params_test.go @@ -0,0 +1,64 @@ +package keeper_test + +import ( + "testing" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/stretchr/testify/require" + + "github.com/sourcenetwork/sourcehub/x/tier/types" +) + +func TestMsgUpdateParams(t *testing.T) { + k, ms, ctx := setupMsgServer(t) + params := types.DefaultParams() + require.NoError(t, k.SetParams(ctx, params)) + sdkCtx := sdk.UnwrapSDKContext(ctx) + + // default params + testCases := []struct { + name string + input *types.MsgUpdateParams + expErr bool + expErrMsg string + }{ + { + name: "invalid authority", + input: &types.MsgUpdateParams{ + Authority: "invalid", + Params: params, + }, + expErr: true, + expErrMsg: "invalid authority", + }, + { + name: "send enabled param", + input: &types.MsgUpdateParams{ + Authority: k.GetAuthority(), + Params: types.Params{}, + }, + expErr: false, + }, + { + name: "all good", + input: &types.MsgUpdateParams{ + Authority: k.GetAuthority(), + Params: params, + }, + expErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + _, err := ms.UpdateParams(sdkCtx, tc.input) + + if tc.expErr { + require.Error(t, err) + require.Contains(t, err.Error(), tc.expErrMsg) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/x/tier/module/genesis.go b/x/tier/module/genesis.go index a0dc7dfb..d16a74b4 100644 --- a/x/tier/module/genesis.go +++ b/x/tier/module/genesis.go @@ -21,7 +21,7 @@ func InitGenesis(ctx context.Context, k keeper.Keeper, genState types.GenesisSta if k.HasLockup(ctx, delAddr, valAddr) { k.AddLockup(ctx, delAddr, valAddr, lockup.Amount) } else { - k.SetLockup(ctx, lockup.UnlockTime != nil, delAddr, valAddr, lockup.Amount, lockup.CreationHeight, lockup.UnbondTime, lockup.UnlockTime) + k.SaveLockup(ctx, lockup.UnlockTime != nil, delAddr, valAddr, lockup.Amount, lockup.CreationHeight, lockup.UnbondTime, lockup.UnlockTime) } } } diff --git a/x/tier/types/errors.go b/x/tier/types/errors.go index a6b37021..a4c88453 100644 --- a/x/tier/types/errors.go +++ b/x/tier/types/errors.go @@ -8,9 +8,10 @@ import ( // x/tier module sentinel errors var ( - ErrUnauthorized = sdkerrors.Register(ModuleName, 1101, "unathorized") + ErrUnauthorized = sdkerrors.Register(ModuleName, 1101, "unauthorized") ErrNotFound = sdkerrors.Register(ModuleName, 1102, "not found") ErrInvalidRequest = sdkerrors.Register(ModuleName, 1103, "invalid request") ErrInvalidAddress = sdkerrors.Register(ModuleName, 1104, "invalid address") ErrInvalidDenom = sdkerrors.Register(ModuleName, 1105, "invalid denom") + ErrInvalidAmount = sdkerrors.Register(ModuleName, 1106, "invalid amount") ) diff --git a/x/tier/types/events.go b/x/tier/types/events.go index 02d11621..48c6f3e6 100644 --- a/x/tier/types/events.go +++ b/x/tier/types/events.go @@ -1,5 +1,16 @@ package types const ( - EventTypeCancelUnlocking = "cancel_unlocking" + EventTypeCancelUnlocking = "cancel_unlocking" + EventTypeCompleteUnlocking = "complete_unlocking" + EventTypeLock = "lock" + EventTypeRedelegate = "redelegate" + EventTypeUnlock = "unlock" + + AttributeKeyCompletionTime = "completion_time" + AttributeKeyCreationHeight = "creation_height" + AttributeKeyDestinationValidator = "destination_validator" + AttributeKeySourceValidator = "source_validator" + AttributeKeyUnbondTime = "unbond_time" + AttributeKeyUnlockTime = "unlock_time" ) diff --git a/x/tier/types/expected_keepers.go b/x/tier/types/expected_keepers.go index 6efe1867..15b01147 100644 --- a/x/tier/types/expected_keepers.go +++ b/x/tier/types/expected_keepers.go @@ -26,6 +26,8 @@ type StakingKeeper interface { sharesAmount math.LegacyDec) (completionTime time.Time, err error) BondDenom(ctx context.Context) (string, error) GetValidator(ctx context.Context, addr sdk.ValAddress) (validator stakingtypes.Validator, err error) + IterateValidators(context.Context, func(index int64, validator stakingtypes.ValidatorI) (stop bool)) error + TotalBondedTokens(context.Context) (math.Int, error) ValidateUnbondAmount(ctx context.Context, delAddr sdk.AccAddress, valAddr sdk.ValAddress, amt math.Int) ( shares math.LegacyDec, err error) GetUnbondingDelegation(ctx context.Context, delAddr sdk.AccAddress, valAddr sdk.ValAddress) ( diff --git a/x/tier/types/messages.go b/x/tier/types/messages.go index 869a6818..f28d3c3c 100644 --- a/x/tier/types/messages.go +++ b/x/tier/types/messages.go @@ -91,7 +91,7 @@ func NewMsgRedelegate(delAddress, srcValAddr, dstValAddr string, stake sdk.Coin) func (msg *MsgRedelegate) ValidateBasic() error { if msg.SrcValidatorAddress == msg.DstValidatorAddress { - return ErrInvalidAddress.Wrapf("src and dst validator addresses are the sames") + return ErrInvalidAddress.Wrapf("src and dst validator addresses are the same") } if err := validateAccAddr(msg.DelegatorAddress); err != nil { return err