diff --git a/solidity/contracts/ContractTest.sol b/solidity/contracts/ContractTest.sol index 3a68fc9..676b0cd 100644 --- a/solidity/contracts/ContractTest.sol +++ b/solidity/contracts/ContractTest.sol @@ -128,6 +128,17 @@ contract ContractTest is IContractTest { _string = 'test'; } + function internalPureVirtualFunction(uint256 _newValue) + internal + pure + virtual + returns (bool _result, uint256 _value, string memory _string) + { + _result = true; + _value = _newValue; + _string = 'test'; + } + function internalNonVirtualFunction( uint256 _newValue, bool @@ -172,4 +183,12 @@ contract ContractTest is IContractTest { { (_result, _value, _string) = internalViewVirtualFunction(_newValue); } + + function callInternalPureVirtualFunction(uint256 _newValue) + public + pure + returns (bool _result, uint256 _value, string memory _string) + { + (_result, _value, _string) = internalPureVirtualFunction(_newValue); + } } diff --git a/solidity/test/ContractTest.t.sol b/solidity/test/ContractTest.t.sol index ec199df..f6502cf 100644 --- a/solidity/test/ContractTest.t.sol +++ b/solidity/test/ContractTest.t.sol @@ -352,4 +352,17 @@ contract E2EMockContractTest_Mock_call_Internal_Func is CommonE2EBase { assertEq(_res2, 1); assertEq(_res3, 'test'); } + + function test_MockCall_InternalPureVirtualFunction() public { + _contractTest.mock_call_internalPureVirtualFunction(10, false, 12, 'TEST'); + (bool _res1, uint256 _res2, string memory _res3) = _contractTest.callInternalPureVirtualFunction(10); + assertEq(_res1, false); + assertEq(_res2, 12); + assertEq(_res3, 'TEST'); + + (_res1, _res2, _res3) = _contractTest.callInternalPureVirtualFunction(11); + assertEq(_res1, true); + assertEq(_res2, 11); + assertEq(_res3, 'test'); + } } diff --git a/src/context.ts b/src/context.ts index c37d4b0..bc1934b 100644 --- a/src/context.ts +++ b/src/context.ts @@ -18,7 +18,7 @@ import { extractStructFieldsNames, extractConstructorsParameters, } from './utils'; -import { FunctionDefinition, VariableDeclaration, Identifier, ImportDirective } from 'solc-typed-ast'; +import { FunctionDefinition, VariableDeclaration, Identifier, ImportDirective, FunctionStateMutability } from 'solc-typed-ast'; export function internalFunctionContext(node: FunctionDefinition): InternalFunctionContext { // Check if the function is internal @@ -27,7 +27,9 @@ export function internalFunctionContext(node: FunctionDefinition): InternalFunct if (!node.virtual) throw new Error('The function is not virtual'); const { functionParameters, parameterTypes, parameterNames } = extractParameters(node.vParameters.vParameters); - const { functionReturnParameters, returnParameterTypes, returnParameterNames } = extractReturnParameters(node.vReturnParameters.vParameters); + const { functionReturnParameters, returnParameterTypes, returnParameterNames, returnExplicitParameterTypes } = extractReturnParameters( + node.vReturnParameters.vParameters, + ); const signature = parameterTypes ? `${node.name}(${parameterTypes.join(',')})` : `${node.name}()`; // Create the string that will be used in the mock function signature @@ -43,8 +45,9 @@ export function internalFunctionContext(node: FunctionDefinition): InternalFunct params = `${inputs}, ${outputs}`; } - // Check if the function is view - const isView = node.stateMutability === 'view'; + // Check if the function is view or pure + const isView = node.stateMutability === FunctionStateMutability.View; + const isPure = node.stateMutability === FunctionStateMutability.Pure; // Save the internal function information return { @@ -55,9 +58,11 @@ export function internalFunctionContext(node: FunctionDefinition): InternalFunct outputs: outputs, inputTypes: parameterTypes, outputTypes: returnParameterTypes, + explicitOutputTypes: returnExplicitParameterTypes, inputNames: parameterNames, outputNames: returnParameterNames, isView: isView, + isPure: isPure, implemented: node.implemented, }; } diff --git a/src/templates/partials/internal-function.hbs b/src/templates/partials/internal-function.hbs index 4efc9ae..fd530f4 100644 --- a/src/templates/partials/internal-function.hbs +++ b/src/templates/partials/internal-function.hbs @@ -1,62 +1,17 @@ -{{#if isView}} - {{#if outputs}} - struct {{functionName}}Output { - {{#each outputTypes}} - {{this}} {{lookup ../outputNames @index}}; - {{/each}} - } - - mapping(bytes32 => {{functionName}}Output) private {{functionName}}Outputs; - {{/if}} - - bytes32[] private {{functionName}}InputHashes; -{{/if}} - function mock_call_{{functionName}}({{parameters}}) public { - {{#if isView}} - bytes32 _key = keccak256(abi.encode({{#each inputNames}}{{this}}{{#unless @last}}, {{/unless}}{{/each}})); - {{#if outputs}} - {{functionName}}Outputs[_key] = {{functionName}}Output({{#each outputNames}}{{this}}{{#unless @last}}, {{/unless}}{{/each}}); - {{/if}} - - for (uint256 _i; _i < {{functionName}}InputHashes.length; ++_i) { - if (_key == {{functionName}}InputHashes[_i]) return; - } - - {{functionName}}InputHashes.push(_key); - {{else}} - vm.mockCall( - address(this), - abi.encodeWithSignature('{{signature}}'{{#if inputs}}, {{#each inputNames}}{{this}}{{#unless @last}}, {{/unless}}{{/each}}{{/if}}), - abi.encode({{#each outputNames}}{{this}}{{#unless @last}}, {{/unless}}{{/each}}) - ); - {{/if}} + vm.mockCall( + address(this), + abi.encodeWithSignature('{{signature}}'{{#if inputs}}, {{#each inputNames}}{{this}}{{#unless @last}}, {{/unless}}{{/each}}{{/if}}), + abi.encode({{#each outputNames}}{{this}}{{#unless @last}}, {{/unless}}{{/each}}) + ); } -function {{functionName}}({{inputs}}) internal {{#if isView}}view {{/if}}override {{#if outputs}}returns ({{outputs}}){{/if}} { - {{#if isView}} - bytes32 _key = keccak256(abi.encode({{#each inputNames}}{{this}}{{#unless @last}}, {{/unless}}{{/each}})); - - for (uint256 _i; _i < {{functionName}}InputHashes.length; ++_i) { - if (_key == {{functionName}}InputHashes[_i]) { - {{#if outputs}} - {{functionName}}Output memory _output = {{functionName}}Outputs[_key]; - {{/if}} - - return ({{#each outputNames}}_output.{{this}}{{#unless @last}}, {{/unless}}{{/each}}); - } - } - - {{#if implemented}} - return super.{{functionName}}({{#each inputNames}}{{this}}{{#unless @last}}, {{/unless}}{{/each}}); - {{/if}} - {{else}} - (bool _success, bytes memory _data) = address(this).call(abi.encodeWithSignature('{{signature}}'{{#if inputs}}, {{#each inputNames}}{{this}}{{#unless @last}}, {{/unless}}{{/each}}{{/if}})); - - if (_success) return abi.decode(_data, ({{#each outputTypes}}{{this}}{{#unless @last}}, {{/unless}}{{/each}})); - - {{#if implemented}} - else return super.{{functionName}}({{#each inputNames}}{{this}}{{#unless @last}}, {{/unless}}{{/each}}); - {{/if}} +function {{functionName}}({{inputs}}) internal override {{#if outputs}}returns ({{outputs}}){{/if}} { + (bool _success, bytes memory _data) = address(this).call(abi.encodeWithSignature('{{signature}}'{{#if inputs}}, {{#each inputNames}}{{this}}{{#unless @last}}, {{/unless}}{{/each}}{{/if}})); + + if (_success) return abi.decode(_data, ({{#each outputTypes}}{{this}}{{#unless @last}}, {{/unless}}{{/each}})); + + {{#if implemented}} + else return super.{{functionName}}({{#each inputNames}}{{this}}{{#unless @last}}, {{/unless}}{{/each}}); {{/if}} } diff --git a/src/templates/partials/internal-view-function.hbs b/src/templates/partials/internal-view-function.hbs new file mode 100644 index 0000000..56ea106 --- /dev/null +++ b/src/templates/partials/internal-view-function.hbs @@ -0,0 +1,60 @@ +{{#if outputs}} + struct {{functionName}}Output { + {{#each outputTypes}} + {{this}} {{lookup ../outputNames @index}}; + {{/each}} + } + + mapping(bytes32 => {{functionName}}Output) private {{functionName}}Outputs; +{{/if}} + +bytes32[] private {{functionName}}InputHashes; + +function mock_call_{{functionName}}({{parameters}}) public { + bytes32 _key = keccak256(abi.encode({{#each inputNames}}{{this}}{{#unless @last}}, {{/unless}}{{/each}})); + {{#if outputs}} + {{functionName}}Outputs[_key] = {{functionName}}Output({{#each outputNames}}{{this}}{{#unless @last}}, {{/unless}}{{/each}}); + {{/if}} + + for (uint256 _i; _i < {{functionName}}InputHashes.length; ++_i) { + if (_key == {{functionName}}InputHashes[_i]) return; + } + + {{functionName}}InputHashes.push(_key); +} + +function {{functionName}}{{#if isPure}}Helper{{/if}}({{inputs}}) internal view {{#unless isPure}}override{{/unless}} {{#if outputs}}returns ({{outputs}}){{/if}} { + bytes32 _key = keccak256(abi.encode({{#each inputNames}}{{this}}{{#unless @last}}, {{/unless}}{{/each}})); + + for (uint256 _i; _i < {{functionName}}InputHashes.length; ++_i) { + if (_key == {{functionName}}InputHashes[_i]) { + {{#if outputs}} + {{functionName}}Output memory _output = {{functionName}}Outputs[_key]; + {{/if}} + + return ({{#each outputNames}}_output.{{this}}{{#unless @last}}, {{/unless}}{{/each}}); + } + } + + {{#if implemented}} + return super.{{functionName}}({{#each inputNames}}{{this}}{{#unless @last}}, {{/unless}}{{/each}}); + {{/if}} +} + +{{#if isPure}} + + function _{{functionName}}CastToPure(function({{#each inputTypes}}{{this}}{{#unless @last}}, {{/unless}}{{/each}}) internal view {{#if outputs}}returns ({{explicitOutputTypes}}){{/if}} fnIn) + internal + pure + returns (function({{#each inputTypes}}{{this}}{{#unless @last}}, {{/unless}}{{/each}}) internal pure {{#if outputs}}returns ({{explicitOutputTypes}}){{/if}} fnOut) + { + assembly { + fnOut := fnIn + } + } + + function {{functionName}}({{inputs}}) internal pure override {{#if outputs}}returns ({{outputs}}){{/if}} { + return _{{functionName}}CastToPure({{functionName}}Helper)({{#each inputNames}}{{this}}{{#unless @last}}, {{/unless}}{{/each}}); + } + +{{/if}} \ No newline at end of file diff --git a/src/types.ts b/src/types.ts index d82ae5d..acad2ae 100644 --- a/src/types.ts +++ b/src/types.ts @@ -26,7 +26,9 @@ export interface ExternalFunctionContext { export interface InternalFunctionContext extends Omit { inputTypes: string[]; outputTypes: string[]; + explicitOutputTypes: string[]; isView: boolean; + isPure: boolean; } export interface ImportContext { diff --git a/src/utils.ts b/src/utils.ts index fbae3db..7685fbf 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -159,6 +159,7 @@ export function extractReturnParameters(returnParameters: VariableDeclaration[]) functionReturnParameters: string[]; returnParameterTypes: string[]; returnParameterNames: string[]; + returnExplicitParameterTypes: string[]; } { const functionReturnParameters = returnParameters.map((parameter: VariableDeclaration, index: number) => { const typeName: string = sanitizeParameterType(parameter.typeString); @@ -169,11 +170,15 @@ export function extractReturnParameters(returnParameters: VariableDeclaration[]) const returnParameterTypes = returnParameters.map((parameter) => sanitizeParameterType(parameter.typeString)); const returnParameterNames = returnParameters.map((parameter, index) => parameter.name || `_returnParam${index}`); + const returnExplicitParameterTypes = returnParameters.map((parameter) => + explicitTypeStorageLocation(sanitizeParameterType(parameter.typeString)), + ); return { functionReturnParameters, returnParameterTypes, returnParameterNames, + returnExplicitParameterTypes, }; } @@ -188,6 +193,7 @@ export async function renderNodeMock(node: ASTNode): Promise { constructor: constructorContext, 'external-or-public-function': externalOrPublicFunctionContext, 'internal-function': internalFunctionContext, + 'internal-view-function': internalFunctionContext, import: importContext, }; @@ -217,7 +223,11 @@ export function partialName(node: ASTNode): string { } else if (node.visibility === 'external' || node.visibility === 'public') { return 'external-or-public-function'; } else if (node.visibility === 'internal' && node.virtual) { - return 'internal-function'; + if (node.stateMutability === 'view' || node.stateMutability === 'pure') { + return 'internal-view-function'; + } else { + return 'internal-function'; + } } } else if (node instanceof ImportDirective) { return 'import'; diff --git a/test/unit/context/internalFunctionContext.spec.ts b/test/unit/context/internalFunctionContext.spec.ts index e516ab9..5191273 100644 --- a/test/unit/context/internalFunctionContext.spec.ts +++ b/test/unit/context/internalFunctionContext.spec.ts @@ -42,6 +42,8 @@ describe('internalFunctionContext', () => { outputTypes: [], implemented: true, isView: false, + explicitOutputTypes: [], + isPure: false, }); }); @@ -65,6 +67,8 @@ describe('internalFunctionContext', () => { outputTypes: [], implemented: true, isView: false, + explicitOutputTypes: [], + isPure: false, }); }); @@ -85,6 +89,8 @@ describe('internalFunctionContext', () => { outputTypes: [], implemented: true, isView: false, + explicitOutputTypes: [], + isPure: false, }); }); @@ -105,6 +111,8 @@ describe('internalFunctionContext', () => { outputTypes: ['uint256', 'boolean'], implemented: true, isView: false, + explicitOutputTypes: ['uint256', 'boolean'], + isPure: false, }); }); @@ -128,6 +136,8 @@ describe('internalFunctionContext', () => { outputTypes: ['uint256', 'boolean'], implemented: true, isView: false, + explicitOutputTypes: ['uint256', 'boolean'], + isPure: false, }); }); @@ -153,6 +163,8 @@ describe('internalFunctionContext', () => { outputTypes: [], implemented: true, isView: false, + explicitOutputTypes: [], + isPure: false, }); }); @@ -178,6 +190,8 @@ describe('internalFunctionContext', () => { outputTypes: ['uint256', 'string', 'bytes', 'boolean'], implemented: true, isView: false, + explicitOutputTypes: ['uint256', 'string memory', 'bytes memory', 'boolean'], + isPure: false, }); }); @@ -190,6 +204,15 @@ describe('internalFunctionContext', () => { } }); + it('determines whether the function is pure or not', () => { + for (const stateMutability in FunctionStateMutability) { + const isPure = FunctionStateMutability[stateMutability] === FunctionStateMutability.Pure; + const node = mockFunctionDefinition({ ...defaultAttributes, stateMutability: FunctionStateMutability[stateMutability] }); + const context = internalFunctionContext(node); + expect(context.isPure).to.be.equal(isPure); + } + }); + it('determines whether the function is implemented or not', () => { for (const implemented of [true, false]) { const node = mockFunctionDefinition({ ...defaultAttributes, implemented: implemented }); @@ -232,6 +255,8 @@ describe('internalFunctionContext', () => { outputTypes: [], implemented: true, isView: false, + explicitOutputTypes: [], + isPure: false, }); }); });