Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: fix mock calling structs #59

Merged
merged 2 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion solidity/contracts/utils/ContractG.sol
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,12 @@ contract ContractG {
Set _set;
}

mapping(bytes32 _disputeId => bool _finished) internal _finished;
struct NestedStruct {
uint256 _counter;
CommonStruct _common;
}

mapping(bytes32 _disputeId => bool _finished) public _finished;

mapping(bytes32 _disputeId => Set _votersSet) internal _votersA;

Expand All @@ -36,4 +41,8 @@ contract ContractG {
mapping(bytes32 _disputeId => AddressSets _votersSets) internal _votersD;

mapping(bytes32 _disputeId => ComplexStruct _complexStruct) internal _complexStructs;

mapping(bytes32 _disputeId => NestedStruct _nestedStruct) public _nestedStructs;

NestedStruct public nestedStruct;
}
48 changes: 24 additions & 24 deletions solidity/test/ContractTest.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -205,12 +205,12 @@ contract E2EMockContractTest_Mock_call_Simple_Vars is CommonE2EBase {
assertEq(_contractTest.bytes32Variable(), bytes32('40'));
}

// function test_MockCall_MyStructVar() public {
// _contractTest.mock_call_myStructVariable(IContractTest.MyStruct(100, 'hundred'));
// (uint256 _value, string memory _name) = _contractTest.myStructVariable();
// assertEq(_value, 100);
// assertEq(_name, 'hundred');
// }
function test_MockCall_MyStructVar() public {
_contractTest.mock_call_myStructVariable(IContractTest.MyStruct(100, 'hundred'));
(uint256 _value, string memory _name) = _contractTest.myStructVariable();
assertEq(_value, 100);
assertEq(_name, 'hundred');
}

function test_MockCall_InternalUintVar_Fail() public {
// no mock calls for internal vars
Expand All @@ -236,12 +236,12 @@ contract E2EMockContractTest_Mock_call_Array_Vars is CommonE2EBase {
assertEq(_contractTest.bytes32Array(0), bytes32('40'));
}

// function test_MockCall_MyStructArray() public {
// _contractTest.mock_call_myStructArray(0, IContractTest.MyStruct(100, 'hundred'));
// (uint256 _value, string memory _name) = _contractTest.myStructArray(0);
// assertEq(_value, 100);
// assertEq(_name, 'hundred');
// }
function test_MockCall_MyStructArray() public {
_contractTest.mock_call_myStructArray(0, IContractTest.MyStruct(100, 'hundred'));
(uint256 _value, string memory _name) = _contractTest.myStructArray(0);
assertEq(_value, 100);
assertEq(_name, 'hundred');
}

function test_MockCall_InternalAddressArray_Fail() public {
// no mock calls for internal vars
Expand All @@ -267,24 +267,24 @@ contract E2EMockContractTest_Mock_call_Mapping_Vars is CommonE2EBase {
assertEq(_contractTest.bytes32ToBytes(bytes32('40')), bytes('50'));
}

// function test_MockCall_Uint256ToMyStructMappings() public {
// _contractTest.mock_call_uint256ToMyStruct(10, IContractTest.MyStruct(100, 'hundred'));
// (uint256 _value, string memory _name) = _contractTest.uint256ToMyStruct(10);
// assertEq(_value, 100);
// assertEq(_name, 'hundred');
// }
function test_MockCall_Uint256ToMyStructMappings() public {
_contractTest.mock_call_uint256ToMyStruct(10, IContractTest.MyStruct(100, 'hundred'));
(uint256 _value, string memory _name) = _contractTest.uint256ToMyStruct(10);
assertEq(_value, 100);
assertEq(_name, 'hundred');
}

function test_MockCall_Uint256ToAddressArrayMappings() public {
_contractTest.mock_call_uint256ToAddressArray(10, 0, _user);
assertEq(_contractTest.uint256ToAddressArray(10, 0), _user);
}

// function test_MockCall_Uint256ToMyStructArrayMappings() public {
// _contractTest.mock_call_uint256ToMyStructArray(10, 0, IContractTest.MyStruct(100, 'hundred'));
// (uint256 _value, string memory _name) = _contractTest.uint256ToMyStructArray(10, 0);
// assertEq(_value, 100);
// assertEq(_name, 'hundred');
// }
function test_MockCall_Uint256ToMyStructArrayMappings() public {
_contractTest.mock_call_uint256ToMyStructArray(10, 0, IContractTest.MyStruct(100, 'hundred'));
(uint256 _value, string memory _name) = _contractTest.uint256ToMyStructArray(10, 0);
assertEq(_value, 100);
assertEq(_name, 'hundred');
}

function test_MockCall_Uint256ToAddressToBytes32Mappings() public {
_contractTest.mock_call_uint256ToAddressToBytes32(10, _owner, bytes32('40'));
Expand Down
15 changes: 15 additions & 0 deletions src/context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import {
extractReturnParameters,
extractOverrides,
hasNestedMappings,
extractStructFieldsNames,
} from './utils';
import { ContractDefinition, FunctionDefinition, VariableDeclaration, Identifier, ImportDirective } from 'solc-typed-ast';

Expand Down Expand Up @@ -165,6 +166,9 @@ export function mappingVariableContext(node: VariableDeclaration): MappingVariab
// Check if value is a struct and has nested mappings
const hasNestedMapping = hasNestedMappings(mappingTypeNameNode);

// Check if the variable is a struct and get its fields
const structFields = extractStructFieldsNames(mappingTypeNameNode);

// If the mapping is internal we don't create mockCall for it
const isInternal: boolean = node.visibility === 'internal';

Expand All @@ -179,9 +183,11 @@ export function mappingVariableContext(node: VariableDeclaration): MappingVariab
keyTypes: keyTypes,
valueType: valueType,
baseType: baseType,
structFields,
},
isInternal: isInternal,
isArray: isArray,
isStruct: structFields.length > 0,
isStructArray: isStructArray,
hasNestedMapping,
};
Expand All @@ -200,6 +206,9 @@ export function arrayVariableContext(node: VariableDeclaration): ArrayVariableCo
// Struct flag
const isStructArray: boolean = node.typeString.startsWith('struct ');

// Check if the variable is a struct and get its fields
const structFields = extractStructFieldsNames(node.vType);

// If the array is internal we don't create mockCall for it
const isInternal: boolean = node.visibility === 'internal';

Expand All @@ -213,6 +222,7 @@ export function arrayVariableContext(node: VariableDeclaration): ArrayVariableCo
functionName: arrayName,
arrayType: arrayType,
baseType: baseType,
structFields,
},
isInternal: isInternal,
isStructArray: isStructArray,
Expand All @@ -229,6 +239,9 @@ export function stateVariableContext(node: VariableDeclaration): StateVariableCo
// If the variable is internal we don't create mockCall for it
const isInternal: boolean = node.visibility === 'internal';

// Check if the variable is a struct and get its fields
const structFields = extractStructFieldsNames(node.vType);

// Save the state variable information
return {
setFunction: {
Expand All @@ -239,7 +252,9 @@ export function stateVariableContext(node: VariableDeclaration): StateVariableCo
mockFunction: {
functionName: variableName,
paramType: variableType,
structFields,
},
isInternal: isInternal,
isStruct: structFields.length > 0,
};
}
2 changes: 1 addition & 1 deletion src/templates/partials/array-state-variable.hbs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ function mock_call_{{mockFunction.functionName}}(uint256 _index, {{mockFunction.
vm.mockCall(
address(this),
abi.encodeWithSignature('{{mockFunction.functionName}}(uint256)', _index),
abi.encode(_value)
abi.encode({{#if isStructArray}}{{#each mockFunction.structFields}}_value.{{this}}{{#unless @last}}, {{/unless}}{{/each}}{{else}}_value{{/if}})
);
}
{{/unless}}
2 changes: 1 addition & 1 deletion src/templates/partials/mapping-state-variable.hbs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ function mock_call_{{mockFunction.functionName}}({{#each mockFunction.keyTypes}}
vm.mockCall(
address(this),
abi.encodeWithSignature('{{mockFunction.functionName}}({{#each mockFunction.keyTypes}}{{this}}{{#unless @last}},{{/unless}}{{/each}}{{#if isArray}},uint256{{/if}})'{{#each mockFunction.keyTypes}}, _key{{@index}}{{/each}}{{#if isArray}}, _index{{/if}}),
abi.encode(_value)
abi.encode({{#if isStruct}}{{#each mockFunction.structFields}}_value.{{this}}{{#unless @last}}, {{/unless}}{{/each}}{{else}}_value{{/if}})
);
}
{{/unless}}
Expand Down
4 changes: 2 additions & 2 deletions src/templates/partials/state-variable.hbs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ function set_{{setFunction.functionName}}({{setFunction.paramType}} _{{setFuncti
}
{{#unless isInternal}}

function mock_call_{{mockFunction.functionName}}({{mockFunction.paramType}} _{{mockFunction.functionName}}) public {
function mock_call_{{mockFunction.functionName}}({{mockFunction.paramType}} _value) public {
vm.mockCall(
address(this),
abi.encodeWithSignature('{{mockFunction.functionName}}()'),
abi.encode(_{{mockFunction.functionName}})
abi.encode({{#if isStruct}}{{#each mockFunction.structFields}}_value.{{this}}{{#unless @last}}, {{/unless}}{{/each}}{{else}}_value{{/if}})
);
}
{{/unless}}
5 changes: 5 additions & 0 deletions src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,11 @@ export interface MappingVariableContext {
keyTypes: string[];
valueType: string;
baseType: string;
structFields?: string[];
};
isInternal: boolean;
isArray: boolean;
isStruct: boolean;
isStructArray: boolean;
hasNestedMapping: boolean;
}
Expand All @@ -63,13 +65,15 @@ export interface ArrayVariableContext {
functionName: string;
arrayType: string;
baseType: string;
structFields?: string[];
};
isInternal: boolean;
isStructArray: boolean;
}

export interface StateVariableContext {
isInternal: boolean;
isStruct: boolean;
setFunction: {
functionName: string;
paramType: string;
Expand All @@ -78,6 +82,7 @@ export interface StateVariableContext {
mockFunction: {
functionName: string;
paramType: string;
structFields?: string[];
};
}
interface Selector {
Expand Down
43 changes: 32 additions & 11 deletions src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,13 @@
* Registers the nested templates
* @returns The content of the template
*/
export function getContractTemplate(): HandlebarsTemplateDelegate<any> {

Check warning on line 61 in src/utils.ts

View workflow job for this annotation

GitHub Actions / Run Linters (18.x)

Unexpected any. Specify a different type
const templatePath = path.resolve(__dirname, 'templates', 'contract-template.hbs');
const templateContent = readFileSync(templatePath, 'utf8');
return Handlebars.compile(templateContent);
}

export function getSmockHelperTemplate(): HandlebarsTemplateDelegate<any> {

Check warning on line 67 in src/utils.ts

View workflow job for this annotation

GitHub Actions / Run Linters (18.x)

Unexpected any. Specify a different type
const templatePath = path.resolve(__dirname, 'templates', 'helper-template.hbs');
const templateContent = readFileSync(templatePath, 'utf8');
return Handlebars.compile(templateContent);
Expand Down Expand Up @@ -224,7 +224,7 @@
const regex = /remappings[\s|\n]*=[\s\n]*\[(?<remappings>[^\]]+)]/;
const matches = foundryConfigContent.match(regex);
if (matches) {
return matches

Check warning on line 227 in src/utils.ts

View workflow job for this annotation

GitHub Actions / Run Linters (18.x)

Forbidden non-null assertion
.groups!.remappings.split(',')
.map((line) => line.trim())
.map((line) => line.replace(/["']/g, ''))
Expand Down Expand Up @@ -396,6 +396,25 @@
return `(${Array.from(contractsSet.contracts).join(', ')})`;
};

/**
* Returns the fields of a struct
* @param node The struct to extract the fields from
* @returns The fields of the struct
*/
const getStructFields = (node: TypeName) => {
const isStruct = node.typeString?.startsWith('struct');
if (!isStruct) return [];

const isArray = node.typeString.includes('[]');
const structTypeName = (isArray ? (node as ArrayTypeName).vBaseType : node) as UserDefinedTypeName;
if (!structTypeName) return [];

const struct = structTypeName?.vReferencedDeclaration;
if (!struct) return [];

return struct.children || [];
};

/**
* Returns if there are nested mappings in a struct
* @dev This function is recursive, loops through all the fields of the struct and nested structs
Expand All @@ -405,17 +424,7 @@
export const hasNestedMappings = (node: TypeName): boolean => {
let result = false;

const isStruct = node.typeString.startsWith('struct');
if (!isStruct) return false;

const isArray = node.typeString.includes('[]');
const structTypeName = (isArray ? (node as ArrayTypeName).vBaseType : node) as UserDefinedTypeName;
if (!structTypeName) return false;

const struct = structTypeName?.vReferencedDeclaration;
if (!struct) return false;

const fields = struct.children || [];
const fields = getStructFields(node);

for (const member of fields) {
const field = member as TypeName;
Expand All @@ -435,3 +444,15 @@

return result;
};

/**
* Extracts the fields of a struct
* @dev returns the fields names of the struct as a string array
* @param node The struct to extract the fields from
* @returns The fields names of the struct
*/
export const extractStructFieldsNames = (node: TypeName): string[] | null => {
const fields = getStructFields(node);

return fields.map((field) => (field as VariableDeclaration).name).filter((name) => name);
};
17 changes: 15 additions & 2 deletions test/unit/context/arrayVariableContext.spec.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { expect } from 'chai';
import { DataLocation, StateVariableVisibility } from 'solc-typed-ast';
import { mockArrayTypeName, mockTypeName, mockVariableDeclaration } from '../../mocks';
import { mockArrayTypeName, mockTypeName, mockUserDefinedTypeName, mockVariableDeclaration } from '../../mocks';
import { arrayVariableContext } from '../../../src/context';

describe('arrayVariableContext', () => {
Expand All @@ -25,6 +25,7 @@ describe('arrayVariableContext', () => {
functionName: 'testArrayVariable',
arrayType: 'uint256[] memory',
baseType: 'uint256',
structFields: [],
},
isInternal: false,
isStructArray: false,
Expand All @@ -35,7 +36,15 @@ describe('arrayVariableContext', () => {
const node = mockVariableDeclaration({
...defaultAttributes,
typeString: 'struct MyStruct[]',
vType: mockArrayTypeName({ vBaseType: mockTypeName({ typeString: 'struct MyStruct' }) }),
vType: mockArrayTypeName({
typeString: 'struct MyStruct[]',
vBaseType: mockUserDefinedTypeName({
typeString: 'struct MyStruct',
vReferencedDeclaration: mockTypeName({
children: [mockVariableDeclaration({ name: 'field1' }), mockVariableDeclaration({ name: 'field2' })],
}),
}),
}),
});
const context = arrayVariableContext(node);

Expand All @@ -49,6 +58,7 @@ describe('arrayVariableContext', () => {
functionName: 'testArrayVariable',
arrayType: 'MyStruct[] memory',
baseType: 'MyStruct memory',
structFields: ['field1', 'field2'],
},
isInternal: false,
isStructArray: true,
Expand All @@ -69,6 +79,7 @@ describe('arrayVariableContext', () => {
functionName: 'testArrayVariable',
arrayType: 'uint256[] memory',
baseType: 'uint256',
structFields: [],
},
isInternal: true,
isStructArray: false,
Expand All @@ -94,6 +105,7 @@ describe('arrayVariableContext', () => {
functionName: 'testArrayVariable',
arrayType: 'MyStruct[] memory',
baseType: 'MyStruct memory',
structFields: [],
},
isInternal: true,
isStructArray: true,
Expand All @@ -119,6 +131,7 @@ describe('arrayVariableContext', () => {
functionName: 'testArrayVariable',
arrayType: 'string[] memory',
baseType: 'string memory',
structFields: [],
},
isInternal: false,
isStructArray: false,
Expand Down
Loading
Loading