From 3fc9c34dbbad4c672767e19ed04c9e3ee2206ac7 Mon Sep 17 00:00:00 2001 From: agusduha Date: Mon, 15 Apr 2024 10:12:52 -0300 Subject: [PATCH 1/2] feat: fix mock calling structs --- solidity/contracts/utils/ContractG.sol | 11 +++- solidity/test/ContractTest.t.sol | 48 ++++++++--------- src/context.ts | 15 ++++++ .../partials/array-state-variable.hbs | 2 +- .../partials/mapping-state-variable.hbs | 2 +- src/templates/partials/state-variable.hbs | 4 +- src/types.ts | 5 ++ src/utils.ts | 45 ++++++++++++---- .../unit/context/arrayVariableContext.spec.ts | 17 +++++- .../context/mappingVariableContext.spec.ts | 53 +++++++++++++++++-- .../unit/context/stateVariableContext.spec.ts | 41 +++++++++++++- 11 files changed, 196 insertions(+), 47 deletions(-) diff --git a/solidity/contracts/utils/ContractG.sol b/solidity/contracts/utils/ContractG.sol index 78dc4a5..21e7795 100644 --- a/solidity/contracts/utils/ContractG.sol +++ b/solidity/contracts/utils/ContractG.sol @@ -25,7 +25,12 @@ contract ContractG { Set _set; } - mapping(bytes32 _disputeId => bool _finished) internal _finished; + struct NestedStruct { + uint256 _counter; + CommonStruct _common; + } + + mapping(bytes32 _disputeId => bool _finished) public _finished; mapping(bytes32 _disputeId => Set _votersSet) internal _votersA; @@ -36,4 +41,8 @@ contract ContractG { mapping(bytes32 _disputeId => AddressSets _votersSets) internal _votersD; mapping(bytes32 _disputeId => ComplexStruct _complexStruct) internal _complexStructs; + + mapping(bytes32 _disputeId => NestedStruct _nestedStruct) public _nestedStructs; + + NestedStruct public nestedStruct; } diff --git a/solidity/test/ContractTest.t.sol b/solidity/test/ContractTest.t.sol index ded8618..64ccb9f 100644 --- a/solidity/test/ContractTest.t.sol +++ b/solidity/test/ContractTest.t.sol @@ -205,12 +205,12 @@ contract E2EMockContractTest_Mock_call_Simple_Vars is CommonE2EBase { assertEq(_contractTest.bytes32Variable(), bytes32('40')); } - // function test_MockCall_MyStructVar() public { - // _contractTest.mock_call_myStructVariable(IContractTest.MyStruct(100, 'hundred')); - // (uint256 _value, string memory _name) = _contractTest.myStructVariable(); - // assertEq(_value, 100); - // assertEq(_name, 'hundred'); - // } + function test_MockCall_MyStructVar() public { + _contractTest.mock_call_myStructVariable(IContractTest.MyStruct(100, 'hundred')); + (uint256 _value, string memory _name) = _contractTest.myStructVariable(); + assertEq(_value, 100); + assertEq(_name, 'hundred'); + } function test_MockCall_InternalUintVar_Fail() public { // no mock calls for internal vars @@ -236,12 +236,12 @@ contract E2EMockContractTest_Mock_call_Array_Vars is CommonE2EBase { assertEq(_contractTest.bytes32Array(0), bytes32('40')); } - // function test_MockCall_MyStructArray() public { - // _contractTest.mock_call_myStructArray(0, IContractTest.MyStruct(100, 'hundred')); - // (uint256 _value, string memory _name) = _contractTest.myStructArray(0); - // assertEq(_value, 100); - // assertEq(_name, 'hundred'); - // } + function test_MockCall_MyStructArray() public { + _contractTest.mock_call_myStructArray(0, IContractTest.MyStruct(100, 'hundred')); + (uint256 _value, string memory _name) = _contractTest.myStructArray(0); + assertEq(_value, 100); + assertEq(_name, 'hundred'); + } function test_MockCall_InternalAddressArray_Fail() public { // no mock calls for internal vars @@ -267,24 +267,24 @@ contract E2EMockContractTest_Mock_call_Mapping_Vars is CommonE2EBase { assertEq(_contractTest.bytes32ToBytes(bytes32('40')), bytes('50')); } - // function test_MockCall_Uint256ToMyStructMappings() public { - // _contractTest.mock_call_uint256ToMyStruct(10, IContractTest.MyStruct(100, 'hundred')); - // (uint256 _value, string memory _name) = _contractTest.uint256ToMyStruct(10); - // assertEq(_value, 100); - // assertEq(_name, 'hundred'); - // } + function test_MockCall_Uint256ToMyStructMappings() public { + _contractTest.mock_call_uint256ToMyStruct(10, IContractTest.MyStruct(100, 'hundred')); + (uint256 _value, string memory _name) = _contractTest.uint256ToMyStruct(10); + assertEq(_value, 100); + assertEq(_name, 'hundred'); + } function test_MockCall_Uint256ToAddressArrayMappings() public { _contractTest.mock_call_uint256ToAddressArray(10, 0, _user); assertEq(_contractTest.uint256ToAddressArray(10, 0), _user); } - // function test_MockCall_Uint256ToMyStructArrayMappings() public { - // _contractTest.mock_call_uint256ToMyStructArray(10, 0, IContractTest.MyStruct(100, 'hundred')); - // (uint256 _value, string memory _name) = _contractTest.uint256ToMyStructArray(10, 0); - // assertEq(_value, 100); - // assertEq(_name, 'hundred'); - // } + function test_MockCall_Uint256ToMyStructArrayMappings() public { + _contractTest.mock_call_uint256ToMyStructArray(10, 0, IContractTest.MyStruct(100, 'hundred')); + (uint256 _value, string memory _name) = _contractTest.uint256ToMyStructArray(10, 0); + assertEq(_value, 100); + assertEq(_name, 'hundred'); + } function test_MockCall_Uint256ToAddressToBytes32Mappings() public { _contractTest.mock_call_uint256ToAddressToBytes32(10, _owner, bytes32('40')); diff --git a/src/context.ts b/src/context.ts index 9e50474..9b9c2a5 100644 --- a/src/context.ts +++ b/src/context.ts @@ -15,6 +15,7 @@ import { extractReturnParameters, extractOverrides, hasNestedMappings, + extractStructFields, } from './utils'; import { ContractDefinition, FunctionDefinition, VariableDeclaration, Identifier, ImportDirective } from 'solc-typed-ast'; @@ -165,6 +166,9 @@ export function mappingVariableContext(node: VariableDeclaration): MappingVariab // Check if value is a struct and has nested mappings const hasNestedMapping = hasNestedMappings(mappingTypeNameNode); + // Check if the variable is a struct and get its fields + const structFields = extractStructFields(mappingTypeNameNode); + // If the mapping is internal we don't create mockCall for it const isInternal: boolean = node.visibility === 'internal'; @@ -179,9 +183,11 @@ export function mappingVariableContext(node: VariableDeclaration): MappingVariab keyTypes: keyTypes, valueType: valueType, baseType: baseType, + structFields, }, isInternal: isInternal, isArray: isArray, + isStruct: !!structFields, isStructArray: isStructArray, hasNestedMapping, }; @@ -200,6 +206,9 @@ export function arrayVariableContext(node: VariableDeclaration): ArrayVariableCo // Struct flag const isStructArray: boolean = node.typeString.startsWith('struct '); + // Check if the variable is a struct and get its fields + const structFields = extractStructFields(node.vType); + // If the array is internal we don't create mockCall for it const isInternal: boolean = node.visibility === 'internal'; @@ -213,6 +222,7 @@ export function arrayVariableContext(node: VariableDeclaration): ArrayVariableCo functionName: arrayName, arrayType: arrayType, baseType: baseType, + structFields, }, isInternal: isInternal, isStructArray: isStructArray, @@ -229,6 +239,9 @@ export function stateVariableContext(node: VariableDeclaration): StateVariableCo // If the variable is internal we don't create mockCall for it const isInternal: boolean = node.visibility === 'internal'; + // Check if the variable is a struct and get its fields + const structFields = extractStructFields(node.vType); + // Save the state variable information return { setFunction: { @@ -239,7 +252,9 @@ export function stateVariableContext(node: VariableDeclaration): StateVariableCo mockFunction: { functionName: variableName, paramType: variableType, + structFields, }, isInternal: isInternal, + isStruct: !!structFields, }; } diff --git a/src/templates/partials/array-state-variable.hbs b/src/templates/partials/array-state-variable.hbs index 1bca793..b6b6d57 100644 --- a/src/templates/partials/array-state-variable.hbs +++ b/src/templates/partials/array-state-variable.hbs @@ -13,7 +13,7 @@ function mock_call_{{mockFunction.functionName}}(uint256 _index, {{mockFunction. vm.mockCall( address(this), abi.encodeWithSignature('{{mockFunction.functionName}}(uint256)', _index), - abi.encode(_value) + abi.encode({{#if isStructArray}}{{#each mockFunction.structFields}}_value.{{this}}{{#unless @last}}, {{/unless}}{{/each}}{{else}}_value{{/if}}) ); } {{/unless}} diff --git a/src/templates/partials/mapping-state-variable.hbs b/src/templates/partials/mapping-state-variable.hbs index 4cf946a..6cd9796 100644 --- a/src/templates/partials/mapping-state-variable.hbs +++ b/src/templates/partials/mapping-state-variable.hbs @@ -14,7 +14,7 @@ function mock_call_{{mockFunction.functionName}}({{#each mockFunction.keyTypes}} vm.mockCall( address(this), abi.encodeWithSignature('{{mockFunction.functionName}}({{#each mockFunction.keyTypes}}{{this}}{{#unless @last}},{{/unless}}{{/each}}{{#if isArray}},uint256{{/if}})'{{#each mockFunction.keyTypes}}, _key{{@index}}{{/each}}{{#if isArray}}, _index{{/if}}), - abi.encode(_value) + abi.encode({{#if isStruct}}{{#each mockFunction.structFields}}_value.{{this}}{{#unless @last}}, {{/unless}}{{/each}}{{else}}_value{{/if}}) ); } {{/unless}} diff --git a/src/templates/partials/state-variable.hbs b/src/templates/partials/state-variable.hbs index b310ae6..b0b0d46 100644 --- a/src/templates/partials/state-variable.hbs +++ b/src/templates/partials/state-variable.hbs @@ -3,11 +3,11 @@ function set_{{setFunction.functionName}}({{setFunction.paramType}} _{{setFuncti } {{#unless isInternal}} -function mock_call_{{mockFunction.functionName}}({{mockFunction.paramType}} _{{mockFunction.functionName}}) public { +function mock_call_{{mockFunction.functionName}}({{mockFunction.paramType}} _value) public { vm.mockCall( address(this), abi.encodeWithSignature('{{mockFunction.functionName}}()'), - abi.encode(_{{mockFunction.functionName}}) + abi.encode({{#if isStruct}}{{#each mockFunction.structFields}}_value.{{this}}{{#unless @last}}, {{/unless}}{{/each}}{{else}}_value{{/if}}) ); } {{/unless}} diff --git a/src/types.ts b/src/types.ts index 4617de5..5436bbb 100644 --- a/src/types.ts +++ b/src/types.ts @@ -46,9 +46,11 @@ export interface MappingVariableContext { keyTypes: string[]; valueType: string; baseType: string; + structFields?: string[]; }; isInternal: boolean; isArray: boolean; + isStruct: boolean; isStructArray: boolean; hasNestedMapping: boolean; } @@ -63,6 +65,7 @@ export interface ArrayVariableContext { functionName: string; arrayType: string; baseType: string; + structFields?: string[]; }; isInternal: boolean; isStructArray: boolean; @@ -70,6 +73,7 @@ export interface ArrayVariableContext { export interface StateVariableContext { isInternal: boolean; + isStruct: boolean; setFunction: { functionName: string; paramType: string; @@ -78,6 +82,7 @@ export interface StateVariableContext { mockFunction: { functionName: string; paramType: string; + structFields?: string[]; }; } interface Selector { diff --git a/src/utils.ts b/src/utils.ts index dbf70d2..296f7b6 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -396,6 +396,25 @@ export const extractOverrides = (node: FullFunctionDefinition): string | null => return `(${Array.from(contractsSet.contracts).join(', ')})`; }; +/** + * Returns the fields of a struct + * @param node The struct to extract the fields from + * @returns The fields of the struct + */ +const getStructFields = (node: TypeName) => { + const isStruct = node.typeString?.startsWith('struct'); + if (!isStruct) return []; + + const isArray = node.typeString.includes('[]'); + const structTypeName = (isArray ? (node as ArrayTypeName).vBaseType : node) as UserDefinedTypeName; + if (!structTypeName) return []; + + const struct = structTypeName?.vReferencedDeclaration; + if (!struct) return []; + + return struct.children || []; +}; + /** * Returns if there are nested mappings in a struct * @dev This function is recursive, loops through all the fields of the struct and nested structs @@ -405,17 +424,7 @@ export const extractOverrides = (node: FullFunctionDefinition): string | null => export const hasNestedMappings = (node: TypeName): boolean => { let result = false; - const isStruct = node.typeString.startsWith('struct'); - if (!isStruct) return false; - - const isArray = node.typeString.includes('[]'); - const structTypeName = (isArray ? (node as ArrayTypeName).vBaseType : node) as UserDefinedTypeName; - if (!structTypeName) return false; - - const struct = structTypeName?.vReferencedDeclaration; - if (!struct) return false; - - const fields = struct.children || []; + const fields = getStructFields(node); for (const member of fields) { const field = member as TypeName; @@ -435,3 +444,17 @@ export const hasNestedMappings = (node: TypeName): boolean => { return result; }; + +/** + * Extracts the fields of a struct + * @dev returns the fields names of the struct as a string array + * @param node The struct to extract the fields from + * @returns The fields names of the struct + */ +export const extractStructFields = (node: TypeName): string[] | null => { + const fields = getStructFields(node); + + if (!fields.length) return null; + + return fields.map((field) => (field as VariableDeclaration).name).filter((name) => name); +}; diff --git a/test/unit/context/arrayVariableContext.spec.ts b/test/unit/context/arrayVariableContext.spec.ts index b801afc..1d8ba10 100644 --- a/test/unit/context/arrayVariableContext.spec.ts +++ b/test/unit/context/arrayVariableContext.spec.ts @@ -1,6 +1,6 @@ import { expect } from 'chai'; import { DataLocation, StateVariableVisibility } from 'solc-typed-ast'; -import { mockArrayTypeName, mockTypeName, mockVariableDeclaration } from '../../mocks'; +import { mockArrayTypeName, mockTypeName, mockUserDefinedTypeName, mockVariableDeclaration } from '../../mocks'; import { arrayVariableContext } from '../../../src/context'; describe('arrayVariableContext', () => { @@ -25,6 +25,7 @@ describe('arrayVariableContext', () => { functionName: 'testArrayVariable', arrayType: 'uint256[] memory', baseType: 'uint256', + structFields: null, }, isInternal: false, isStructArray: false, @@ -35,7 +36,15 @@ describe('arrayVariableContext', () => { const node = mockVariableDeclaration({ ...defaultAttributes, typeString: 'struct MyStruct[]', - vType: mockArrayTypeName({ vBaseType: mockTypeName({ typeString: 'struct MyStruct' }) }), + vType: mockArrayTypeName({ + typeString: 'struct MyStruct[]', + vBaseType: mockUserDefinedTypeName({ + typeString: 'struct MyStruct', + vReferencedDeclaration: mockTypeName({ + children: [mockVariableDeclaration({ name: 'field1' }), mockVariableDeclaration({ name: 'field2' })], + }), + }), + }), }); const context = arrayVariableContext(node); @@ -49,6 +58,7 @@ describe('arrayVariableContext', () => { functionName: 'testArrayVariable', arrayType: 'MyStruct[] memory', baseType: 'MyStruct memory', + structFields: ['field1', 'field2'], }, isInternal: false, isStructArray: true, @@ -69,6 +79,7 @@ describe('arrayVariableContext', () => { functionName: 'testArrayVariable', arrayType: 'uint256[] memory', baseType: 'uint256', + structFields: null, }, isInternal: true, isStructArray: false, @@ -94,6 +105,7 @@ describe('arrayVariableContext', () => { functionName: 'testArrayVariable', arrayType: 'MyStruct[] memory', baseType: 'MyStruct memory', + structFields: null, }, isInternal: true, isStructArray: true, @@ -119,6 +131,7 @@ describe('arrayVariableContext', () => { functionName: 'testArrayVariable', arrayType: 'string[] memory', baseType: 'string memory', + structFields: null, }, isInternal: false, isStructArray: false, diff --git a/test/unit/context/mappingVariableContext.spec.ts b/test/unit/context/mappingVariableContext.spec.ts index ceb391f..a73fca2 100644 --- a/test/unit/context/mappingVariableContext.spec.ts +++ b/test/unit/context/mappingVariableContext.spec.ts @@ -26,9 +26,11 @@ describe('mappingVariableContext', () => { keyTypes: ['uint256'], valueType: 'uint256', baseType: 'uint256', + structFields: null, }, isInternal: false, isArray: false, + isStruct: false, isStructArray: false, hasNestedMapping: false, }); @@ -56,9 +58,11 @@ describe('mappingVariableContext', () => { keyTypes: ['uint256'], valueType: 'MyStruct memory', baseType: 'MyStruct memory', + structFields: null, }, isInternal: false, isArray: false, + isStruct: false, isStructArray: false, hasNestedMapping: false, }); @@ -86,9 +90,11 @@ describe('mappingVariableContext', () => { keyTypes: ['uint256'], valueType: 'uint256[] memory', baseType: 'uint256', + structFields: null, }, isInternal: false, isArray: true, + isStruct: false, isStructArray: false, hasNestedMapping: false, }); @@ -109,9 +115,11 @@ describe('mappingVariableContext', () => { keyTypes: ['uint256'], valueType: 'uint256', baseType: 'uint256', + structFields: null, }, isInternal: true, isArray: false, + isStruct: false, isStructArray: false, hasNestedMapping: false, }); @@ -123,9 +131,18 @@ describe('mappingVariableContext', () => { typeString: 'mapping(uint256 => struct MyStruct[])', vType: mockMapping({ vKeyType: mockTypeName({ typeString: 'uint256' }), - vValueType: mockArrayTypeName({ vBaseType: mockTypeName({ typeString: 'struct MyStruct' }), typeString: 'struct MyStruct[]' }), + vValueType: mockArrayTypeName({ + typeString: 'struct MyStruct[]', + vBaseType: mockUserDefinedTypeName({ + typeString: 'struct MyStruct', + vReferencedDeclaration: mockTypeName({ + children: [mockVariableDeclaration({ name: 'field1' }), mockVariableDeclaration({ name: 'field2' })], + }), + }), + }), }), }); + const context = mappingVariableContext(node); expect(context).to.eql({ @@ -139,9 +156,11 @@ describe('mappingVariableContext', () => { keyTypes: ['uint256'], valueType: 'MyStruct[] memory', baseType: 'MyStruct memory', + structFields: ['field1', 'field2'], }, isInternal: false, isArray: true, + isStruct: true, isStructArray: true, hasNestedMapping: false, }); @@ -173,9 +192,11 @@ describe('mappingVariableContext', () => { keyTypes: ['uint256', 'uint256'], valueType: 'uint256', baseType: 'uint256', + structFields: null, }, isInternal: false, isArray: false, + isStruct: false, isStructArray: false, hasNestedMapping: false, }); @@ -190,7 +211,15 @@ describe('mappingVariableContext', () => { vValueType: mockMapping({ typeString: 'mapping(uint256 => struct MyStruct[])', vKeyType: mockTypeName({ typeString: 'uint256' }), - vValueType: mockArrayTypeName({ vBaseType: mockTypeName({ typeString: 'struct MyStruct' }), typeString: 'struct MyStruct[]' }), + vValueType: mockArrayTypeName({ + typeString: 'struct MyStruct[]', + vBaseType: mockUserDefinedTypeName({ + typeString: 'struct MyStruct', + vReferencedDeclaration: mockTypeName({ + children: [mockVariableDeclaration({ name: 'field1' }), mockVariableDeclaration({ name: 'field2' })], + }), + }), + }), }), }), }); @@ -207,9 +236,11 @@ describe('mappingVariableContext', () => { keyTypes: ['uint256', 'uint256'], valueType: 'MyStruct[] memory', baseType: 'MyStruct memory', + structFields: ['field1', 'field2'], }, isInternal: false, isArray: true, + isStruct: true, isStructArray: true, hasNestedMapping: false, }); @@ -245,9 +276,11 @@ describe('mappingVariableContext', () => { keyTypes: ['uint256', 'uint128', 'uint64'], valueType: 'uint8', baseType: 'uint8', + structFields: null, }, isInternal: false, isArray: false, + isStruct: false, isStructArray: false, hasNestedMapping: false, }); @@ -261,7 +294,12 @@ describe('mappingVariableContext', () => { vKeyType: mockTypeName({ typeString: 'uint256' }), vValueType: mockUserDefinedTypeName({ typeString: 'struct MyStruct', - vReferencedDeclaration: mockUserDefinedTypeName({ children: [mockTypeName({ typeString: 'mapping(uint256 => uint256)' })] }), + vReferencedDeclaration: mockUserDefinedTypeName({ + children: [ + mockVariableDeclaration({ name: 'field1', typeString: 'mapping(uint256 => uint256)' }), + mockVariableDeclaration({ name: 'field2', typeString: 'uint256)' }), + ], + }), }), }), }); @@ -278,9 +316,11 @@ describe('mappingVariableContext', () => { keyTypes: ['uint256'], valueType: 'MyStruct memory', baseType: 'MyStruct memory', + structFields: ['field1', 'field2'], }, isInternal: false, isArray: false, + isStruct: true, isStructArray: false, hasNestedMapping: true, }); @@ -301,7 +341,10 @@ describe('mappingVariableContext', () => { typeString: 'struct MyStruct', vReferencedDeclaration: mockTypeName({ typeString: 'struct MyStruct', - children: [mockTypeName({ typeString: 'mapping(uint256 => uint256)' })], + children: [ + mockVariableDeclaration({ name: 'field1', typeString: 'mapping(uint256 => uint256)' }), + mockVariableDeclaration({ name: 'field2', typeString: 'uint256)' }), + ], }), }), }), @@ -321,9 +364,11 @@ describe('mappingVariableContext', () => { keyTypes: ['uint256', 'uint256'], valueType: 'MyStruct[] memory', baseType: 'MyStruct memory', + structFields: ['field1', 'field2'], }, isInternal: false, isArray: true, + isStruct: true, isStructArray: true, hasNestedMapping: true, }); diff --git a/test/unit/context/stateVariableContext.spec.ts b/test/unit/context/stateVariableContext.spec.ts index 69c2d2c..e61c23d 100644 --- a/test/unit/context/stateVariableContext.spec.ts +++ b/test/unit/context/stateVariableContext.spec.ts @@ -1,5 +1,5 @@ import { DataLocation, StateVariableVisibility } from 'solc-typed-ast'; -import { mockVariableDeclaration } from '../../mocks'; +import { mockTypeName, mockUserDefinedTypeName, mockVariableDeclaration } from '../../mocks'; import { stateVariableContext } from '../../../src/context'; import { expect } from 'chai'; @@ -8,6 +8,7 @@ describe('stateVariableContext', () => { name: 'testStateVariable', typeString: 'uint256', visibility: StateVariableVisibility.Default, + vType: mockVariableDeclaration({ typeString: 'uint256' }), }; it('processes state variables', () => { @@ -23,8 +24,10 @@ describe('stateVariableContext', () => { mockFunction: { functionName: 'testStateVariable', paramType: 'uint256', + structFields: null, }, isInternal: false, + isStruct: false, }); }); @@ -41,8 +44,10 @@ describe('stateVariableContext', () => { mockFunction: { functionName: 'testStateVariable', paramType: 'uint256', + structFields: null, }, isInternal: true, + isStruct: false, }); }); @@ -59,8 +64,42 @@ describe('stateVariableContext', () => { mockFunction: { functionName: 'testStateVariable', paramType: 'string memory', + structFields: null, }, isInternal: false, + isStruct: false, + }); + }); + + it('processes struct state variables', () => { + const node = mockVariableDeclaration({ + ...defaultAttributes, + typeString: 'struct MyStruct', + vType: mockUserDefinedTypeName({ + typeString: 'struct MyStruct', + vReferencedDeclaration: mockUserDefinedTypeName({ + children: [ + mockVariableDeclaration({ name: 'field1', typeString: 'mapping(uint256 => uint256)' }), + mockVariableDeclaration({ name: 'field2', typeString: 'uint256)' }), + ], + }), + }), + }); + const context = stateVariableContext(node); + + expect(context).to.eql({ + setFunction: { + functionName: 'testStateVariable', + paramType: 'MyStruct memory', + paramName: 'testStateVariable', + }, + mockFunction: { + functionName: 'testStateVariable', + paramType: 'MyStruct memory', + structFields: ['field1', 'field2'], + }, + isInternal: false, + isStruct: true, }); }); }); From eab13a1647bad4ec153f4d7ce19671de60204ea0 Mon Sep 17 00:00:00 2001 From: agusduha Date: Mon, 15 Apr 2024 11:53:57 -0300 Subject: [PATCH 2/2] fix: rename extract struct fields function --- src/context.ts | 12 ++++++------ src/utils.ts | 4 +--- test/unit/context/arrayVariableContext.spec.ts | 8 ++++---- test/unit/context/mappingVariableContext.spec.ts | 12 ++++++------ test/unit/context/stateVariableContext.spec.ts | 8 ++++---- 5 files changed, 21 insertions(+), 23 deletions(-) diff --git a/src/context.ts b/src/context.ts index 9b9c2a5..718325b 100644 --- a/src/context.ts +++ b/src/context.ts @@ -15,7 +15,7 @@ import { extractReturnParameters, extractOverrides, hasNestedMappings, - extractStructFields, + extractStructFieldsNames, } from './utils'; import { ContractDefinition, FunctionDefinition, VariableDeclaration, Identifier, ImportDirective } from 'solc-typed-ast'; @@ -167,7 +167,7 @@ export function mappingVariableContext(node: VariableDeclaration): MappingVariab const hasNestedMapping = hasNestedMappings(mappingTypeNameNode); // Check if the variable is a struct and get its fields - const structFields = extractStructFields(mappingTypeNameNode); + const structFields = extractStructFieldsNames(mappingTypeNameNode); // If the mapping is internal we don't create mockCall for it const isInternal: boolean = node.visibility === 'internal'; @@ -187,7 +187,7 @@ export function mappingVariableContext(node: VariableDeclaration): MappingVariab }, isInternal: isInternal, isArray: isArray, - isStruct: !!structFields, + isStruct: structFields.length > 0, isStructArray: isStructArray, hasNestedMapping, }; @@ -207,7 +207,7 @@ export function arrayVariableContext(node: VariableDeclaration): ArrayVariableCo const isStructArray: boolean = node.typeString.startsWith('struct '); // Check if the variable is a struct and get its fields - const structFields = extractStructFields(node.vType); + const structFields = extractStructFieldsNames(node.vType); // If the array is internal we don't create mockCall for it const isInternal: boolean = node.visibility === 'internal'; @@ -240,7 +240,7 @@ export function stateVariableContext(node: VariableDeclaration): StateVariableCo const isInternal: boolean = node.visibility === 'internal'; // Check if the variable is a struct and get its fields - const structFields = extractStructFields(node.vType); + const structFields = extractStructFieldsNames(node.vType); // Save the state variable information return { @@ -255,6 +255,6 @@ export function stateVariableContext(node: VariableDeclaration): StateVariableCo structFields, }, isInternal: isInternal, - isStruct: !!structFields, + isStruct: structFields.length > 0, }; } diff --git a/src/utils.ts b/src/utils.ts index 296f7b6..a87d3b7 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -451,10 +451,8 @@ export const hasNestedMappings = (node: TypeName): boolean => { * @param node The struct to extract the fields from * @returns The fields names of the struct */ -export const extractStructFields = (node: TypeName): string[] | null => { +export const extractStructFieldsNames = (node: TypeName): string[] | null => { const fields = getStructFields(node); - if (!fields.length) return null; - return fields.map((field) => (field as VariableDeclaration).name).filter((name) => name); }; diff --git a/test/unit/context/arrayVariableContext.spec.ts b/test/unit/context/arrayVariableContext.spec.ts index 1d8ba10..5aab853 100644 --- a/test/unit/context/arrayVariableContext.spec.ts +++ b/test/unit/context/arrayVariableContext.spec.ts @@ -25,7 +25,7 @@ describe('arrayVariableContext', () => { functionName: 'testArrayVariable', arrayType: 'uint256[] memory', baseType: 'uint256', - structFields: null, + structFields: [], }, isInternal: false, isStructArray: false, @@ -79,7 +79,7 @@ describe('arrayVariableContext', () => { functionName: 'testArrayVariable', arrayType: 'uint256[] memory', baseType: 'uint256', - structFields: null, + structFields: [], }, isInternal: true, isStructArray: false, @@ -105,7 +105,7 @@ describe('arrayVariableContext', () => { functionName: 'testArrayVariable', arrayType: 'MyStruct[] memory', baseType: 'MyStruct memory', - structFields: null, + structFields: [], }, isInternal: true, isStructArray: true, @@ -131,7 +131,7 @@ describe('arrayVariableContext', () => { functionName: 'testArrayVariable', arrayType: 'string[] memory', baseType: 'string memory', - structFields: null, + structFields: [], }, isInternal: false, isStructArray: false, diff --git a/test/unit/context/mappingVariableContext.spec.ts b/test/unit/context/mappingVariableContext.spec.ts index a73fca2..9686332 100644 --- a/test/unit/context/mappingVariableContext.spec.ts +++ b/test/unit/context/mappingVariableContext.spec.ts @@ -26,7 +26,7 @@ describe('mappingVariableContext', () => { keyTypes: ['uint256'], valueType: 'uint256', baseType: 'uint256', - structFields: null, + structFields: [], }, isInternal: false, isArray: false, @@ -58,7 +58,7 @@ describe('mappingVariableContext', () => { keyTypes: ['uint256'], valueType: 'MyStruct memory', baseType: 'MyStruct memory', - structFields: null, + structFields: [], }, isInternal: false, isArray: false, @@ -90,7 +90,7 @@ describe('mappingVariableContext', () => { keyTypes: ['uint256'], valueType: 'uint256[] memory', baseType: 'uint256', - structFields: null, + structFields: [], }, isInternal: false, isArray: true, @@ -115,7 +115,7 @@ describe('mappingVariableContext', () => { keyTypes: ['uint256'], valueType: 'uint256', baseType: 'uint256', - structFields: null, + structFields: [], }, isInternal: true, isArray: false, @@ -192,7 +192,7 @@ describe('mappingVariableContext', () => { keyTypes: ['uint256', 'uint256'], valueType: 'uint256', baseType: 'uint256', - structFields: null, + structFields: [], }, isInternal: false, isArray: false, @@ -276,7 +276,7 @@ describe('mappingVariableContext', () => { keyTypes: ['uint256', 'uint128', 'uint64'], valueType: 'uint8', baseType: 'uint8', - structFields: null, + structFields: [], }, isInternal: false, isArray: false, diff --git a/test/unit/context/stateVariableContext.spec.ts b/test/unit/context/stateVariableContext.spec.ts index e61c23d..2f5f484 100644 --- a/test/unit/context/stateVariableContext.spec.ts +++ b/test/unit/context/stateVariableContext.spec.ts @@ -1,5 +1,5 @@ import { DataLocation, StateVariableVisibility } from 'solc-typed-ast'; -import { mockTypeName, mockUserDefinedTypeName, mockVariableDeclaration } from '../../mocks'; +import { mockUserDefinedTypeName, mockVariableDeclaration } from '../../mocks'; import { stateVariableContext } from '../../../src/context'; import { expect } from 'chai'; @@ -24,7 +24,7 @@ describe('stateVariableContext', () => { mockFunction: { functionName: 'testStateVariable', paramType: 'uint256', - structFields: null, + structFields: [], }, isInternal: false, isStruct: false, @@ -44,7 +44,7 @@ describe('stateVariableContext', () => { mockFunction: { functionName: 'testStateVariable', paramType: 'uint256', - structFields: null, + structFields: [], }, isInternal: true, isStruct: false, @@ -64,7 +64,7 @@ describe('stateVariableContext', () => { mockFunction: { functionName: 'testStateVariable', paramType: 'string memory', - structFields: null, + structFields: [], }, isInternal: false, isStruct: false,