From d4f3bc6f489e310529cbabc9ee92a337b3a293b8 Mon Sep 17 00:00:00 2001 From: Filipp Makarov Date: Fri, 13 Oct 2023 08:57:01 +0300 Subject: [PATCH] refactor methods, test suite structure --- .solhintignore | 3 +- .../modules/AccountRecoveryModule.sol | 119 +++++++++--------- test/module/AccountRecovery.Module.specs.ts | 33 ++++- 3 files changed, 96 insertions(+), 59 deletions(-) diff --git a/.solhintignore b/.solhintignore index 24ef6876..a2a9c423 100644 --- a/.solhintignore +++ b/.solhintignore @@ -1,4 +1,5 @@ node_modules artifacts contracts/smart-account/test -contracts/smart-account/libs \ No newline at end of file +contracts/smart-account/libs +contracts/smart-account/modules/AccountRecoveryModule.sol \ No newline at end of file diff --git a/contracts/smart-account/modules/AccountRecoveryModule.sol b/contracts/smart-account/modules/AccountRecoveryModule.sol index 4a458753..0f2a49af 100644 --- a/contracts/smart-account/modules/AccountRecoveryModule.sol +++ b/contracts/smart-account/modules/AccountRecoveryModule.sol @@ -52,12 +52,12 @@ contract AccountRecoveryModule is BaseAuthorizationModule { // see https://docs.soliditylang.org/en/v0.8.15/internals/layout_in_storage.html#mappings-and-dynamic-arrays mapping(bytes32 => mapping(address => TimeFrame)) internal _guardians; + //mapping(address => bytes32[]) internal _smartAccountGuardiansLists; + mapping(address => SaSettings) internal _smartAccountSettings; mapping(address => RecoveryRequest) internal _smartAccountRequests; - // TODO - // EVENTS event RecoveryRequestSubmitted( address indexed smartAccount, bytes indexed requestCallData @@ -77,24 +77,16 @@ contract AccountRecoveryModule is BaseAuthorizationModule { event ThresholdChanged(address indexed smartAccount, uint8 threshold); error AlreadyInitedForSmartAccount(address smartAccount); - error ThresholdNotSetForSmartAccount(address smartAccount); - error InvalidSignaturesLength(); - error NotUniqueGuardianOrInvalidOrder( - address lastGuardian, - address currentGuardian - ); - error ZeroGuardian(); error InvalidTimeFrame(uint48 validUntil, uint48 validAfter); error ExpiredValidUntil(uint48 validUntil); error GuardianAlreadySet(bytes32 guardian, address smartAccount); - + error GuardianNotSet(bytes32 guardian, address smartAccount); error ThresholdTooHigh(uint8 threshold, uint256 guardiansExist); error ZeroThreshold(); error InvalidAmountOfGuardianParams(); error GuardiansAreIdentical(); - error LastGuardianRemovalAttempt(bytes32 lastGuardian); - + error GuardianNotExpired(bytes32 guardian, address smartAccount); error EmptyRecoveryCallData(); error RecoveryRequestAlreadyExists( address smartAccount, @@ -227,8 +219,9 @@ contract AccountRecoveryModule is BaseAuthorizationModule { userOp.sender ].validUntil; - // 0,0 means the `currentGuardian` has not been set as guardian for the userOp.sender smartAccount - if (validUntil == 0 && validAfter == 0) { + // validUntil == 0 means the `currentGuardian` has not been set as guardian for the userOp.sender smartAccount + // validUntil can never be 0 as it is set to type(uint48).max in initForSmartAccount + if (validUntil == 0) { return SIG_VALIDATION_FAILED; } @@ -290,15 +283,6 @@ contract AccountRecoveryModule is BaseAuthorizationModule { } } - // NOTE - if validUntil is 0, guardian is considered active forever - // Thus we put type(uint48).max as value for validUntil in this case, - // so the calldata itself doesn't need to contain this big value and thus - // txn is cheaper - // we need to explicitly change 0 to type(uint48).max, so the algorithm of intersecting - // validUntil's and validAfter's for several guardians works correctly - // @note if validAfter is less then now + securityDelay, it is set to now + securityDelay - // as for security reasons new guardian is only active after securityDelay - function addGuardian( bytes32 guardian, uint48 validUntil, @@ -308,16 +292,7 @@ contract AccountRecoveryModule is BaseAuthorizationModule { if (_guardians[guardian][msg.sender].validUntil != 0) revert GuardianAlreadySet(guardian, msg.sender); - if (validUntil == 0) validUntil = type(uint48).max; - uint48 minimalSecureValidAfter = uint48( - block.timestamp + _smartAccountSettings[msg.sender].securityDelay - ); - validAfter = validAfter < minimalSecureValidAfter - ? minimalSecureValidAfter - : validAfter; - if (validUntil < validAfter) - revert InvalidTimeFrame(validUntil, validAfter); - if (validUntil < block.timestamp) revert ExpiredValidUntil(validUntil); + (validUntil, validAfter) = _checkAndAdjustValidUntilValidAfter(validUntil, validAfter); // TODO: // make a test case that it fails if validAfter + securityDelay together overflow uint48 @@ -338,20 +313,13 @@ contract AccountRecoveryModule is BaseAuthorizationModule { uint48 validUntil, uint48 validAfter ) external { + if (_guardians[guardian][msg.sender].validUntil == 0) + revert GuardianNotSet(guardian, msg.sender); if (guardian == newGuardian) revert GuardiansAreIdentical(); if (guardian == bytes32(0)) revert ZeroGuardian(); if (newGuardian == bytes32(0)) revert ZeroGuardian(); - if (validUntil == 0) validUntil = type(uint48).max; - uint48 minimalSecureValidAfter = uint48( - block.timestamp + _smartAccountSettings[msg.sender].securityDelay - ); - validAfter = validAfter < minimalSecureValidAfter - ? minimalSecureValidAfter - : validAfter; - if (validUntil < validAfter) - revert InvalidTimeFrame(validUntil, validAfter); - if (validUntil < block.timestamp) revert ExpiredValidUntil(validUntil); + (validUntil, validAfter) = _checkAndAdjustValidUntilValidAfter(validUntil, validAfter); // make the new one valid _guardians[newGuardian][msg.sender] = TimeFrame( @@ -381,27 +349,66 @@ contract AccountRecoveryModule is BaseAuthorizationModule { // natspec function removeGuardian(bytes32 guardian) external { - delete _guardians[guardian][msg.sender]; - --_smartAccountSettings[msg.sender].guardiansCount; - if (_smartAccountSettings[msg.sender].guardiansCount == 0) - revert LastGuardianRemovalAttempt(guardian); - emit GuardianRemoved(msg.sender, guardian); + if (_guardians[guardian][msg.sender].validUntil == 0) + revert GuardianNotSet(guardian, msg.sender); + _removeGuardianAndChangeTresholdIfNeeded(guardian, msg.sender); + } + + // natspec + // REMOVE EXPIRED GUARDIAN + // not permissioned - anyone can call it but the check if the guardian is expired is on-chain + // it will allow us clearing expired guardians from the backend and maintain the list of guardians actual + function removeExpiredGuardian(bytes32 guardian, address smartAccount) external { + uint48 validUntil = _guardians[guardian][smartAccount].validUntil; + if (validUntil == 0) + revert GuardianNotSet(guardian, smartAccount); + if(validUntil>=block.timestamp) + revert GuardianNotExpired(guardian, smartAccount); + _removeGuardianAndChangeTresholdIfNeeded(guardian, smartAccount); + } + + // NOTE - if validUntil is 0, guardian is considered active forever + // Thus we put type(uint48).max as value for validUntil in this case, + // so the calldata itself doesn't need to contain this big value and thus + // txn is cheaper + // we need to explicitly change 0 to type(uint48).max, so the algorithm of intersecting + // validUntil's and validAfter's for several guardians works correctly + // @note if validAfter is less then now + securityDelay, it is set to now + securityDelay + // as for security reasons new guardian is only active after securityDelay + function _checkAndAdjustValidUntilValidAfter( + uint48 validUntil, + uint48 validAfter + ) internal view returns (uint48, uint48) { + if (validUntil == 0) validUntil = type(uint48).max; + uint48 minimalSecureValidAfter = uint48( + block.timestamp + _smartAccountSettings[msg.sender].securityDelay + ); + validAfter = validAfter < minimalSecureValidAfter + ? minimalSecureValidAfter + : validAfter; + if (validUntil < validAfter) + revert InvalidTimeFrame(validUntil, validAfter); + if (validUntil < block.timestamp) revert ExpiredValidUntil(validUntil); + return (validUntil, validAfter); + } + + function _removeGuardianAndChangeTresholdIfNeeded(bytes32 guardian, address smartAccount) internal { + delete _guardians[guardian][smartAccount]; + --_smartAccountSettings[smartAccount].guardiansCount; + emit GuardianRemoved(smartAccount, guardian); // if number of guardians became less than threshold, lower the threshold if ( - _smartAccountSettings[msg.sender].guardiansCount < - _smartAccountSettings[msg.sender].recoveryThreshold + _smartAccountSettings[smartAccount].guardiansCount < + _smartAccountSettings[smartAccount].recoveryThreshold ) { - _smartAccountSettings[msg.sender].recoveryThreshold--; + _smartAccountSettings[smartAccount].recoveryThreshold--; emit ThresholdChanged( - msg.sender, - _smartAccountSettings[msg.sender].recoveryThreshold + smartAccount, + _smartAccountSettings[smartAccount].recoveryThreshold ); } } - // DISABLE ACCOUNT RECOVERY - // Requires to explicitly list all the guardians to delete them - // change timeframe function changeGuardianParams( bytes32 guardian, diff --git a/test/module/AccountRecovery.Module.specs.ts b/test/module/AccountRecovery.Module.specs.ts index 1e46ba8c..69e1dc23 100644 --- a/test/module/AccountRecovery.Module.specs.ts +++ b/test/module/AccountRecovery.Module.specs.ts @@ -2125,11 +2125,27 @@ describe("Account Recovery Module: ", async () => { expect(guardiansAfter).to.equal(guardiansBefore + 1); }); + /* + it("Should revert if zero guardian is provided", async () => {}); + + it("Should revert if such a guardian has already been set", async () => {}); + + it("Should set validUntil to uint48.max if validUntil = 0 is provided", async () => {}); + + it("Should set validAfter as guardian.timeframe.validAfter if it is bigger than now+securityDelay", async () => {}); + + it("Should set now+securityDelay if validAfter is less than it", async () => {}); + + it("Should revert if validUntil is less than resulting validAfter", async () => {}); + + it("Should revert if validUntil < now", async () => {}); + */ + // it("_________", async () => {}); }); /* - describe("changeGuardian", async () => { + describe("replaceGuardian", async () => { it("_________", async () => {}); }); @@ -2137,8 +2153,21 @@ describe("Account Recovery Module: ", async () => { it("_________", async () => {}); }); + describe("removeExpiredGuardian", async () => { + it("_________", async () => {}); + }); - // DISABLE ACC RECOVERY + describe("changeGuardianParams", async () => { + it("_________", async () => {}); + }); + + describe("setThreshold", async () => { + it("_________", async () => {}); + }); + + describe("setSecurityDelay", async () => { + it("_________", async () => {}); + }); // Check all the errors declarations to be actually used in the contract code