Skip to content

Commit

Permalink
feat: mock internal pure functions (#69)
Browse files Browse the repository at this point in the history
  • Loading branch information
agusduha authored Apr 23, 2024
1 parent 143e5c4 commit ebb0895
Show file tree
Hide file tree
Showing 8 changed files with 151 additions and 62 deletions.
19 changes: 19 additions & 0 deletions solidity/contracts/ContractTest.sol
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,17 @@ contract ContractTest is IContractTest {
_string = 'test';
}

function internalPureVirtualFunction(uint256 _newValue)
internal
pure
virtual
returns (bool _result, uint256 _value, string memory _string)
{
_result = true;
_value = _newValue;
_string = 'test';
}

function internalNonVirtualFunction(
uint256 _newValue,
bool
Expand Down Expand Up @@ -172,4 +183,12 @@ contract ContractTest is IContractTest {
{
(_result, _value, _string) = internalViewVirtualFunction(_newValue);
}

function callInternalPureVirtualFunction(uint256 _newValue)
public
pure
returns (bool _result, uint256 _value, string memory _string)
{
(_result, _value, _string) = internalPureVirtualFunction(_newValue);
}
}
13 changes: 13 additions & 0 deletions solidity/test/ContractTest.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -352,4 +352,17 @@ contract E2EMockContractTest_Mock_call_Internal_Func is CommonE2EBase {
assertEq(_res2, 1);
assertEq(_res3, 'test');
}

function test_MockCall_InternalPureVirtualFunction() public {
_contractTest.mock_call_internalPureVirtualFunction(10, false, 12, 'TEST');
(bool _res1, uint256 _res2, string memory _res3) = _contractTest.callInternalPureVirtualFunction(10);
assertEq(_res1, false);
assertEq(_res2, 12);
assertEq(_res3, 'TEST');

(_res1, _res2, _res3) = _contractTest.callInternalPureVirtualFunction(11);
assertEq(_res1, true);
assertEq(_res2, 11);
assertEq(_res3, 'test');
}
}
13 changes: 9 additions & 4 deletions src/context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import {
extractStructFieldsNames,
extractConstructorsParameters,
} from './utils';
import { FunctionDefinition, VariableDeclaration, Identifier, ImportDirective } from 'solc-typed-ast';
import { FunctionDefinition, VariableDeclaration, Identifier, ImportDirective, FunctionStateMutability } from 'solc-typed-ast';

export function internalFunctionContext(node: FunctionDefinition): InternalFunctionContext {
// Check if the function is internal
Expand All @@ -27,7 +27,9 @@ export function internalFunctionContext(node: FunctionDefinition): InternalFunct
if (!node.virtual) throw new Error('The function is not virtual');

const { functionParameters, parameterTypes, parameterNames } = extractParameters(node.vParameters.vParameters);
const { functionReturnParameters, returnParameterTypes, returnParameterNames } = extractReturnParameters(node.vReturnParameters.vParameters);
const { functionReturnParameters, returnParameterTypes, returnParameterNames, returnExplicitParameterTypes } = extractReturnParameters(
node.vReturnParameters.vParameters,
);
const signature = parameterTypes ? `${node.name}(${parameterTypes.join(',')})` : `${node.name}()`;

// Create the string that will be used in the mock function signature
Expand All @@ -43,8 +45,9 @@ export function internalFunctionContext(node: FunctionDefinition): InternalFunct
params = `${inputs}, ${outputs}`;
}

// Check if the function is view
const isView = node.stateMutability === 'view';
// Check if the function is view or pure
const isView = node.stateMutability === FunctionStateMutability.View;
const isPure = node.stateMutability === FunctionStateMutability.Pure;

// Save the internal function information
return {
Expand All @@ -55,9 +58,11 @@ export function internalFunctionContext(node: FunctionDefinition): InternalFunct
outputs: outputs,
inputTypes: parameterTypes,
outputTypes: returnParameterTypes,
explicitOutputTypes: returnExplicitParameterTypes,
inputNames: parameterNames,
outputNames: returnParameterNames,
isView: isView,
isPure: isPure,
implemented: node.implemented,
};
}
Expand Down
69 changes: 12 additions & 57 deletions src/templates/partials/internal-function.hbs
Original file line number Diff line number Diff line change
@@ -1,62 +1,17 @@
{{#if isView}}
{{#if outputs}}
struct {{functionName}}Output {
{{#each outputTypes}}
{{this}} {{lookup ../outputNames @index}};
{{/each}}
}

mapping(bytes32 => {{functionName}}Output) private {{functionName}}Outputs;
{{/if}}

bytes32[] private {{functionName}}InputHashes;
{{/if}}

function mock_call_{{functionName}}({{parameters}}) public {
{{#if isView}}
bytes32 _key = keccak256(abi.encode({{#each inputNames}}{{this}}{{#unless @last}}, {{/unless}}{{/each}}));
{{#if outputs}}
{{functionName}}Outputs[_key] = {{functionName}}Output({{#each outputNames}}{{this}}{{#unless @last}}, {{/unless}}{{/each}});
{{/if}}

for (uint256 _i; _i < {{functionName}}InputHashes.length; ++_i) {
if (_key == {{functionName}}InputHashes[_i]) return;
}

{{functionName}}InputHashes.push(_key);
{{else}}
vm.mockCall(
address(this),
abi.encodeWithSignature('{{signature}}'{{#if inputs}}, {{#each inputNames}}{{this}}{{#unless @last}}, {{/unless}}{{/each}}{{/if}}),
abi.encode({{#each outputNames}}{{this}}{{#unless @last}}, {{/unless}}{{/each}})
);
{{/if}}
vm.mockCall(
address(this),
abi.encodeWithSignature('{{signature}}'{{#if inputs}}, {{#each inputNames}}{{this}}{{#unless @last}}, {{/unless}}{{/each}}{{/if}}),
abi.encode({{#each outputNames}}{{this}}{{#unless @last}}, {{/unless}}{{/each}})
);
}

function {{functionName}}({{inputs}}) internal {{#if isView}}view {{/if}}override {{#if outputs}}returns ({{outputs}}){{/if}} {
{{#if isView}}
bytes32 _key = keccak256(abi.encode({{#each inputNames}}{{this}}{{#unless @last}}, {{/unless}}{{/each}}));

for (uint256 _i; _i < {{functionName}}InputHashes.length; ++_i) {
if (_key == {{functionName}}InputHashes[_i]) {
{{#if outputs}}
{{functionName}}Output memory _output = {{functionName}}Outputs[_key];
{{/if}}

return ({{#each outputNames}}_output.{{this}}{{#unless @last}}, {{/unless}}{{/each}});
}
}

{{#if implemented}}
return super.{{functionName}}({{#each inputNames}}{{this}}{{#unless @last}}, {{/unless}}{{/each}});
{{/if}}
{{else}}
(bool _success, bytes memory _data) = address(this).call(abi.encodeWithSignature('{{signature}}'{{#if inputs}}, {{#each inputNames}}{{this}}{{#unless @last}}, {{/unless}}{{/each}}{{/if}}));

if (_success) return abi.decode(_data, ({{#each outputTypes}}{{this}}{{#unless @last}}, {{/unless}}{{/each}}));

{{#if implemented}}
else return super.{{functionName}}({{#each inputNames}}{{this}}{{#unless @last}}, {{/unless}}{{/each}});
{{/if}}
function {{functionName}}({{inputs}}) internal override {{#if outputs}}returns ({{outputs}}){{/if}} {
(bool _success, bytes memory _data) = address(this).call(abi.encodeWithSignature('{{signature}}'{{#if inputs}}, {{#each inputNames}}{{this}}{{#unless @last}}, {{/unless}}{{/each}}{{/if}}));

if (_success) return abi.decode(_data, ({{#each outputTypes}}{{this}}{{#unless @last}}, {{/unless}}{{/each}}));

{{#if implemented}}
else return super.{{functionName}}({{#each inputNames}}{{this}}{{#unless @last}}, {{/unless}}{{/each}});
{{/if}}
}
60 changes: 60 additions & 0 deletions src/templates/partials/internal-view-function.hbs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
{{#if outputs}}
struct {{functionName}}Output {
{{#each outputTypes}}
{{this}} {{lookup ../outputNames @index}};
{{/each}}
}

mapping(bytes32 => {{functionName}}Output) private {{functionName}}Outputs;
{{/if}}

bytes32[] private {{functionName}}InputHashes;

function mock_call_{{functionName}}({{parameters}}) public {
bytes32 _key = keccak256(abi.encode({{#each inputNames}}{{this}}{{#unless @last}}, {{/unless}}{{/each}}));
{{#if outputs}}
{{functionName}}Outputs[_key] = {{functionName}}Output({{#each outputNames}}{{this}}{{#unless @last}}, {{/unless}}{{/each}});
{{/if}}

for (uint256 _i; _i < {{functionName}}InputHashes.length; ++_i) {
if (_key == {{functionName}}InputHashes[_i]) return;
}

{{functionName}}InputHashes.push(_key);
}

function {{functionName}}{{#if isPure}}Helper{{/if}}({{inputs}}) internal view {{#unless isPure}}override{{/unless}} {{#if outputs}}returns ({{outputs}}){{/if}} {
bytes32 _key = keccak256(abi.encode({{#each inputNames}}{{this}}{{#unless @last}}, {{/unless}}{{/each}}));

for (uint256 _i; _i < {{functionName}}InputHashes.length; ++_i) {
if (_key == {{functionName}}InputHashes[_i]) {
{{#if outputs}}
{{functionName}}Output memory _output = {{functionName}}Outputs[_key];
{{/if}}

return ({{#each outputNames}}_output.{{this}}{{#unless @last}}, {{/unless}}{{/each}});
}
}

{{#if implemented}}
return super.{{functionName}}({{#each inputNames}}{{this}}{{#unless @last}}, {{/unless}}{{/each}});
{{/if}}
}

{{#if isPure}}

function _{{functionName}}CastToPure(function({{#each inputTypes}}{{this}}{{#unless @last}}, {{/unless}}{{/each}}) internal view {{#if outputs}}returns ({{explicitOutputTypes}}){{/if}} fnIn)
internal
pure
returns (function({{#each inputTypes}}{{this}}{{#unless @last}}, {{/unless}}{{/each}}) internal pure {{#if outputs}}returns ({{explicitOutputTypes}}){{/if}} fnOut)
{
assembly {
fnOut := fnIn
}
}

function {{functionName}}({{inputs}}) internal pure override {{#if outputs}}returns ({{outputs}}){{/if}} {
return _{{functionName}}CastToPure({{functionName}}Helper)({{#each inputNames}}{{this}}{{#unless @last}}, {{/unless}}{{/each}});
}

{{/if}}
2 changes: 2 additions & 0 deletions src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ export interface ExternalFunctionContext {
export interface InternalFunctionContext extends Omit<ExternalFunctionContext, 'visibility' | 'stateMutability'> {
inputTypes: string[];
outputTypes: string[];
explicitOutputTypes: string[];
isView: boolean;
isPure: boolean;
}

export interface ImportContext {
Expand Down
12 changes: 11 additions & 1 deletion src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ export function extractReturnParameters(returnParameters: VariableDeclaration[])
functionReturnParameters: string[];
returnParameterTypes: string[];
returnParameterNames: string[];
returnExplicitParameterTypes: string[];
} {
const functionReturnParameters = returnParameters.map((parameter: VariableDeclaration, index: number) => {
const typeName: string = sanitizeParameterType(parameter.typeString);
Expand All @@ -169,11 +170,15 @@ export function extractReturnParameters(returnParameters: VariableDeclaration[])

const returnParameterTypes = returnParameters.map((parameter) => sanitizeParameterType(parameter.typeString));
const returnParameterNames = returnParameters.map((parameter, index) => parameter.name || `_returnParam${index}`);
const returnExplicitParameterTypes = returnParameters.map((parameter) =>
explicitTypeStorageLocation(sanitizeParameterType(parameter.typeString)),
);

return {
functionReturnParameters,
returnParameterTypes,
returnParameterNames,
returnExplicitParameterTypes,
};
}

Expand All @@ -188,6 +193,7 @@ export async function renderNodeMock(node: ASTNode): Promise<string> {
constructor: constructorContext,
'external-or-public-function': externalOrPublicFunctionContext,
'internal-function': internalFunctionContext,
'internal-view-function': internalFunctionContext,
import: importContext,
};

Expand Down Expand Up @@ -217,7 +223,11 @@ export function partialName(node: ASTNode): string {
} else if (node.visibility === 'external' || node.visibility === 'public') {
return 'external-or-public-function';
} else if (node.visibility === 'internal' && node.virtual) {
return 'internal-function';
if (node.stateMutability === 'view' || node.stateMutability === 'pure') {
return 'internal-view-function';
} else {
return 'internal-function';
}
}
} else if (node instanceof ImportDirective) {
return 'import';
Expand Down
25 changes: 25 additions & 0 deletions test/unit/context/internalFunctionContext.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ describe('internalFunctionContext', () => {
outputTypes: [],
implemented: true,
isView: false,
explicitOutputTypes: [],
isPure: false,
});
});

Expand All @@ -65,6 +67,8 @@ describe('internalFunctionContext', () => {
outputTypes: [],
implemented: true,
isView: false,
explicitOutputTypes: [],
isPure: false,
});
});

Expand All @@ -85,6 +89,8 @@ describe('internalFunctionContext', () => {
outputTypes: [],
implemented: true,
isView: false,
explicitOutputTypes: [],
isPure: false,
});
});

Expand All @@ -105,6 +111,8 @@ describe('internalFunctionContext', () => {
outputTypes: ['uint256', 'boolean'],
implemented: true,
isView: false,
explicitOutputTypes: ['uint256', 'boolean'],
isPure: false,
});
});

Expand All @@ -128,6 +136,8 @@ describe('internalFunctionContext', () => {
outputTypes: ['uint256', 'boolean'],
implemented: true,
isView: false,
explicitOutputTypes: ['uint256', 'boolean'],
isPure: false,
});
});

Expand All @@ -153,6 +163,8 @@ describe('internalFunctionContext', () => {
outputTypes: [],
implemented: true,
isView: false,
explicitOutputTypes: [],
isPure: false,
});
});

Expand All @@ -178,6 +190,8 @@ describe('internalFunctionContext', () => {
outputTypes: ['uint256', 'string', 'bytes', 'boolean'],
implemented: true,
isView: false,
explicitOutputTypes: ['uint256', 'string memory', 'bytes memory', 'boolean'],
isPure: false,
});
});

Expand All @@ -190,6 +204,15 @@ describe('internalFunctionContext', () => {
}
});

it('determines whether the function is pure or not', () => {
for (const stateMutability in FunctionStateMutability) {
const isPure = FunctionStateMutability[stateMutability] === FunctionStateMutability.Pure;
const node = mockFunctionDefinition({ ...defaultAttributes, stateMutability: FunctionStateMutability[stateMutability] });
const context = internalFunctionContext(node);
expect(context.isPure).to.be.equal(isPure);
}
});

it('determines whether the function is implemented or not', () => {
for (const implemented of [true, false]) {
const node = mockFunctionDefinition({ ...defaultAttributes, implemented: implemented });
Expand Down Expand Up @@ -232,6 +255,8 @@ describe('internalFunctionContext', () => {
outputTypes: [],
implemented: true,
isView: false,
explicitOutputTypes: [],
isPure: false,
});
});
});

0 comments on commit ebb0895

Please sign in to comment.