diff --git a/contracts/smart-account/interfaces/modules/ISecurityPolicyPlugin.sol b/contracts/smart-account/interfaces/modules/ISecurityPolicyPlugin.sol index 25e9c069..974a79e4 100644 --- a/contracts/smart-account/interfaces/modules/ISecurityPolicyPlugin.sol +++ b/contracts/smart-account/interfaces/modules/ISecurityPolicyPlugin.sol @@ -9,8 +9,5 @@ interface ISecurityPolicyPlugin { /// set in the security policy of the smart contract wallet. /// @param _scw The address of the smart contract wallet /// @param _plugin The address of the plugin to be installed - function validateSecurityPolicy( - address _scw, - address _plugin - ) external view; + function validateSecurityPolicy(address _scw, address _plugin) external; } diff --git a/contracts/smart-account/modules/SecurityPolicyManagerPlugin.sol b/contracts/smart-account/modules/SecurityPolicyManagerPlugin.sol index 1d3e8bf3..4a195bfc 100644 --- a/contracts/smart-account/modules/SecurityPolicyManagerPlugin.sol +++ b/contracts/smart-account/modules/SecurityPolicyManagerPlugin.sol @@ -19,6 +19,9 @@ contract SecurityPolicyManagerPlugin is ISecurityPolicyManagerPlugin { address _setupContract, bytes calldata _setupData ) external override returns (address) { + // The Setup Contract must satisfy all security policies + _validateSecurityPolicies(msg.sender, _setupContract); + // Instruct the SA to install the module and return the address ISmartAccount sa = ISmartAccount(msg.sender); (bool success, bytes memory returndata) = sa @@ -37,8 +40,11 @@ contract SecurityPolicyManagerPlugin is ISecurityPolicyManagerPlugin { address module = abi.decode(returndata, (address)); - // Validate the security policies - _validateSecurityPolicies(msg.sender, module); + // If the setup contract differs from the installed module, + // Validate the module as well + if (module != _setupContract) { + _validateSecurityPolicies(msg.sender, module); + } return module; } diff --git a/test/foundry/module/SecurityPolicy/SecurityPolicyManagerPlugin.ModuleInstallation.t.sol b/test/foundry/module/SecurityPolicy/SecurityPolicyManagerPlugin.ModuleInstallation.t.sol index c8b7fd2a..844a07fd 100644 --- a/test/foundry/module/SecurityPolicy/SecurityPolicyManagerPlugin.ModuleInstallation.t.sol +++ b/test/foundry/module/SecurityPolicy/SecurityPolicyManagerPlugin.ModuleInstallation.t.sol @@ -8,13 +8,29 @@ import {SecurityPolicyManagerPlugin, SENTINEL_MODULE_ADDRESS} from "modules/Secu import {ISecurityPolicyPlugin} from "interfaces/modules/ISecurityPolicyPlugin.sol"; import {ISecurityPolicyManagerPlugin, ISecurityPolicyManagerPluginEventsErrors} from "interfaces/modules/ISecurityPolicyManagerPlugin.sol"; import {UserOperation} from "aa-core/EntryPoint.sol"; +import {MultichainECDSAValidator} from "modules/MultiChainECDSAValidator.sol"; import "forge-std/console2.sol"; contract TestSecurityPolicyPlugin is ISecurityPolicyPlugin { bool public shouldRevert; + bool public wasCalled; - function validateSecurityPolicy(address, address) external view override { - require(!shouldRevert, "TestSecurityPolicyPlugin: shouldRevert"); + mapping(address => bool) public blacklist; + + constructor() { + blacklist[address(0x2)] = true; + } + + error TestSecurityPolicyPluginError(address); + + function validateSecurityPolicy( + address, + address _plugin + ) external override { + wasCalled = true; + if (shouldRevert || blacklist[_plugin]) { + revert TestSecurityPolicyPluginError(address(this)); + } } function setShouldRevert(bool _shouldRevert) external { @@ -22,6 +38,12 @@ contract TestSecurityPolicyPlugin is ISecurityPolicyPlugin { } } +contract TestSetupContractBlacklistReturn { + function initForSmartAccount(address) external view returns (address) { + return address(0x2); + } +} + contract SecurityPolicyManagerPluginModuleInstallationTest is SATestBase, ISecurityPolicyManagerPluginEventsErrors @@ -33,6 +55,8 @@ contract SecurityPolicyManagerPluginModuleInstallationTest is TestSecurityPolicyPlugin p3; TestSecurityPolicyPlugin p4; + MultichainECDSAValidator validator; + function setUp() public virtual override { super.setUp(); @@ -95,5 +119,140 @@ contract SecurityPolicyManagerPluginModuleInstallationTest is alice ); entryPoint.handleOps(arraifyOps(op), owner.addr); + + // Create MultichainValidator + validator = new MultichainECDSAValidator(); + } + + function testModuleInstallation() external { + bytes memory setupData = abi.encodeCall( + validator.initForSmartAccount, + (alice.addr) + ); + + UserOperation memory op = makeEcdsaModuleUserOp( + getSmartAccountExecuteCalldata( + address(spmp), + 0, + abi.encodeCall( + ISecurityPolicyManagerPlugin.checkSetupAndEnableModule, + (address(validator), setupData) + ) + ), + sa, + 0, + alice + ); + + vm.expectEmit(true, true, true, true); + emit ModuleValidated(address(sa), address(validator)); + + entryPoint.handleOps(arraifyOps(op), owner.addr); + + assertTrue(p1.wasCalled()); + assertTrue(p2.wasCalled()); + assertTrue(p3.wasCalled()); + assertTrue(p4.wasCalled()); + assertTrue(sa.isModuleEnabled(address(validator))); + } + + function testShouldRevertModuleInstallationIfSecurityPolicyIsNotSatisifedOnSetupContract() + external + { + TestSetupContractBlacklistReturn blacklistReturn = new TestSetupContractBlacklistReturn(); + + bytes memory setupData = abi.encodeCall( + validator.initForSmartAccount, + (alice.addr) + ); + + UserOperation memory op = makeEcdsaModuleUserOp( + getSmartAccountExecuteCalldata( + address(spmp), + 0, + abi.encodeCall( + ISecurityPolicyManagerPlugin.checkSetupAndEnableModule, + (address(blacklistReturn), setupData) + ) + ), + sa, + 0, + alice + ); + + vm.recordLogs(); + entryPoint.handleOps(arraifyOps(op), owner.addr); + Vm.Log[] memory logs = vm.getRecordedLogs(); + UserOperationEventData memory eventData = getUserOperationEventData( + logs + ); + assertFalse(eventData.success); + UserOperationRevertReasonEventData + memory revertReasonEventData = getUserOperationRevertReasonEventData( + logs + ); + assertEq( + keccak256(revertReasonEventData.revertReason), + keccak256( + abi.encodeWithSelector( + TestSecurityPolicyPlugin + .TestSecurityPolicyPluginError + .selector, + p4 + ) + ) + ); + + assertFalse(sa.isModuleEnabled(address(validator))); + } + + function testShouldRevertModuleInstallationIfSecurityPolicyIsNotSatisifedOnInstalledPlugin() + external + { + bytes memory setupData = abi.encodeCall( + validator.initForSmartAccount, + (alice.addr) + ); + + UserOperation memory op = makeEcdsaModuleUserOp( + getSmartAccountExecuteCalldata( + address(spmp), + 0, + abi.encodeCall( + ISecurityPolicyManagerPlugin.checkSetupAndEnableModule, + (address(validator), setupData) + ) + ), + sa, + 0, + alice + ); + + p4.setShouldRevert(true); + + vm.recordLogs(); + entryPoint.handleOps(arraifyOps(op), owner.addr); + Vm.Log[] memory logs = vm.getRecordedLogs(); + UserOperationEventData memory eventData = getUserOperationEventData( + logs + ); + assertFalse(eventData.success); + UserOperationRevertReasonEventData + memory revertReasonEventData = getUserOperationRevertReasonEventData( + logs + ); + assertEq( + keccak256(revertReasonEventData.revertReason), + keccak256( + abi.encodeWithSelector( + TestSecurityPolicyPlugin + .TestSecurityPolicyPluginError + .selector, + p4 + ) + ) + ); + + assertFalse(sa.isModuleEnabled(address(validator))); } }