diff --git a/src/BaseRewardStreams.sol b/src/BaseRewardStreams.sol index cadc014..f8a6c79 100644 --- a/src/BaseRewardStreams.sol +++ b/src/BaseRewardStreams.sol @@ -20,7 +20,7 @@ abstract contract BaseRewardStreams is IRewardStreams, EVCUtil, ReentrancyGuard uint256 public constant MAX_EPOCHS_AHEAD = 5; uint256 public constant MAX_DISTRIBUTION_LENGTH = 25; uint256 public constant MAX_REWARDS_ENABLED = 5; - uint256 internal constant MAX_EPOCH_DURATION = 7 days; + uint256 internal constant MIN_EPOCH_DURATION = 7 days; uint256 internal constant EPOCHS_PER_SLOT = 2; // this value is used to increase the precision of the calculations due to the fact that distributed amount of @@ -58,6 +58,9 @@ abstract contract BaseRewardStreams is IRewardStreams, EVCUtil, ReentrancyGuard /// @notice Rewards-related error. Thrown when user tries to enable too many rewards. error TooManyRewardsEnabled(); + /// @notice Recipient-related error. Throws when the recipient is invalid. + error InvalidRecipient(); + /// @notice Struct to store distribution data per rewarded and reward tokens. struct DistributionStorage { /// @notice The last timestamp when the distribution was updated. @@ -99,7 +102,7 @@ abstract contract BaseRewardStreams is IRewardStreams, EVCUtil, ReentrancyGuard /// @param _evc The Ethereum Vault Connector contract. /// @param _epochDuration The duration of an epoch. constructor(address _evc, uint48 _epochDuration) EVCUtil(_evc) { - if (_epochDuration < MAX_EPOCH_DURATION) { + if (_epochDuration < MIN_EPOCH_DURATION) { revert InvalidEpoch(); } @@ -450,16 +453,14 @@ abstract contract BaseRewardStreams is IRewardStreams, EVCUtil, ReentrancyGuard /// @notice Claims the earned reward for a specific account, rewarded token, and reward token, and transfers it to /// the recipient. /// @dev If recipient is address(0) or there is no reward to claim, this function does nothing. - /// @param msgSender The address of the account claiming the reward. + /// @param account The address of the account claiming the reward. /// @param rewarded The address of the rewarded token. /// @param reward The address of the reward token. /// @param recipient The address to which the claimed reward will be transferred. - function claim(address msgSender, address rewarded, address reward, address recipient) internal virtual { - if (recipient == address(0)) { - return; - } + function claim(address account, address rewarded, address reward, address recipient) internal virtual { + if (recipient == address(0)) revert InvalidRecipient(); - EarnStorage storage earnStorage = accountEarnedData[msgSender][rewarded][reward]; + EarnStorage storage earnStorage = accountEarnedData[account][rewarded][reward]; uint128 amount = earnStorage.claimable; // If there is a reward token to claim, transfer it to the recipient and emit an event. @@ -475,7 +476,7 @@ abstract contract BaseRewardStreams is IRewardStreams, EVCUtil, ReentrancyGuard earnStorage.claimable = 0; IERC20(reward).safeTransfer(recipient, amount); - emit RewardClaimed(msgSender, rewarded, reward, amount); + emit RewardClaimed(account, rewarded, reward, amount); } } diff --git a/test/unit/Scenarios.t.sol b/test/unit/Scenarios.t.sol index d41f646..6a2c954 100644 --- a/test/unit/Scenarios.t.sol +++ b/test/unit/Scenarios.t.sol @@ -1906,9 +1906,18 @@ contract ScenarioTest is Test { address participant2, address participant3 ) external { - vm.assume(participant1 != address(0) && participant1 != seeder && participant1 != address(evc)); - vm.assume(participant2 != address(0) && participant2 != seeder && participant2 != address(evc)); - vm.assume(participant3 != address(0) && participant3 != seeder && participant3 != address(evc)); + vm.assume( + participant1 != address(0) && participant1 != seeder && participant1 != address(evc) + && participant1.code.length == 0 + ); + vm.assume( + participant2 != address(0) && participant2 != seeder && participant2 != address(evc) + && participant2.code.length == 0 + ); + vm.assume( + participant3 != address(0) && participant3 != seeder && participant3 != address(evc) + && participant3.code.length == 0 + ); vm.assume(participant1 != participant2 && participant1 != participant3 && participant2 != participant3); blockTimestamp = uint48(bound(blockTimestamp, 1, type(uint48).max - 365 days)); @@ -2323,4 +2332,29 @@ contract ScenarioTest is Test { vm.expectRevert(); trackingDistributor.claimReward(_rewarded, _reward, _account, true); } + + function test_RevertWhenRecipientInvalid_Claim( + address _rewarded, + address _reward, + address _receiver, + bool _forfeitRecentReward + ) external { + vm.assume(_receiver != address(0)); + + vm.expectRevert(BaseRewardStreams.InvalidRecipient.selector); + stakingDistributor.claimReward(_rewarded, _reward, address(0), _forfeitRecentReward); + stakingDistributor.claimReward(_rewarded, _reward, _receiver, _forfeitRecentReward); + + vm.expectRevert(BaseRewardStreams.InvalidRecipient.selector); + trackingDistributor.claimReward(_rewarded, _reward, address(0), _forfeitRecentReward); + trackingDistributor.claimReward(_rewarded, _reward, _receiver, _forfeitRecentReward); + + vm.expectRevert(BaseRewardStreams.InvalidRecipient.selector); + stakingDistributor.claimSpilloverReward(_rewarded, _reward, address(0)); + stakingDistributor.claimSpilloverReward(_rewarded, _reward, _receiver); + + vm.expectRevert(BaseRewardStreams.InvalidRecipient.selector); + trackingDistributor.claimSpilloverReward(_rewarded, _reward, address(0)); + trackingDistributor.claimSpilloverReward(_rewarded, _reward, _receiver); + } }