From e73d94e4a20842fbd2b7addaa070120e15ef3979 Mon Sep 17 00:00:00 2001 From: dristpunk <107591874+dristpunk@users.noreply.github.com> Date: Wed, 1 Nov 2023 09:34:24 +0300 Subject: [PATCH 1/2] feat: mock interfaces (#5) --- src/get-external-functions.ts | 11 ++ src/mock-contract-generator.ts | 15 +- src/run.ts | 15 +- src/templates/mockContractTemplate.hbs | 2 +- .../mockExternalFunctionTemplate.hbs | 3 + src/types.ts | 6 + src/utils.ts | 8 +- test/e2e/get-external-functions.spec.ts | 129 ++++++++++++++++++ test/e2e/get-internal-functions.spec.ts | 9 +- test/e2e/set-variables.spec.ts | 7 +- test/unit/get-constructor.spec.ts | 6 + test/unit/get-external-functions.spec.ts | 90 ++++++++++++ test/unit/get-imports.spec.ts | 1 + test/unit/get-internal-functions.spec.ts | 9 ++ 14 files changed, 294 insertions(+), 17 deletions(-) create mode 100644 test/e2e/get-external-functions.spec.ts diff --git a/src/get-external-functions.ts b/src/get-external-functions.ts index bc155d1..86638a8 100644 --- a/src/get-external-functions.ts +++ b/src/get-external-functions.ts @@ -9,6 +9,9 @@ export const getExternalMockFunctions = (contractNode: ContractDefinitionNode): // Filter the nodes and keep only the FunctionDefinition related ones const functionNodes = contractNode.nodes.filter((node) => node.nodeType === 'FunctionDefinition') as FunctionDefinitionNode[]; + // Get contract kind + const contractKind = contractNode.contractKind; + const externalFunctions: ExternalFunctionOptions[] = []; // Loop through the function nodes functionNodes.forEach((funcNode: FunctionDefinitionNode) => { @@ -17,6 +20,10 @@ export const getExternalMockFunctions = (contractNode: ContractDefinitionNode): // Check if the function is external or public if (funcNode.visibility != 'external' && funcNode.visibility != 'public') return; + // Save state mutability + const stateMutability = funcNode.stateMutability; + const stateMutabilityString = stateMutability == 'nonpayable' ? ' ' : ` ${stateMutability} `; + // Get the parameters of the function, if there are no parameters then we use an empty array const parameters: VariableDeclarationNode[] = funcNode.parameters.parameters ? funcNode.parameters.parameters : []; @@ -95,6 +102,10 @@ export const getExternalMockFunctions = (contractNode: ContractDefinitionNode): signature: signature, inputsStringNames: inputsStringNames, outputsStringNames: outputsStringNames, + inputString: inputsString, + outputString: outputsString, + isInterface: contractKind === 'interface', + stateMutabilityString: stateMutabilityString, }; externalFunctions.push(externalMockFunction); diff --git a/src/mock-contract-generator.ts b/src/mock-contract-generator.ts index 8283af1..673a579 100644 --- a/src/mock-contract-generator.ts +++ b/src/mock-contract-generator.ts @@ -12,7 +12,12 @@ import { StateVariablesOptions, ContractDefinitionNode } from './types'; * @param compiledArtifactsDir The directory where the compiled artifacts are located * @param generatedContractsDir The directory where the mock contracts will be generated */ -export const generateMockContracts = async (contractsDir: string, compiledArtifactsDir: string, generatedContractsDir: string) => { +export const generateMockContracts = async ( + contractsDir: string[], + compiledArtifactsDir: string, + generatedContractsDir: string, + ignoreDir: string[], +) => { const templateContent: string = registerHandlebarsTemplates(); const template = Handlebars.compile(templateContent); try { @@ -31,7 +36,7 @@ export const generateMockContracts = async (contractsDir: string, compiledArtifa } console.log('Parsing contracts...'); // Get all contracts directories - const contractPaths: string[] = getContractNames(contractsDir); + const contractPaths: string[] = getContractNames(contractsDir, ignoreDir); // Loop for each contract path contractPaths.forEach(async (contractPath: string) => { // Get the sub dir name @@ -71,11 +76,16 @@ export const generateMockContracts = async (contractsDir: string, compiledArtifa const contractImport: string = ast.absolutePath; if (!contractImport) return; + // Get all exported entities + const exportedSymbols = Object.keys(ast.exportedSymbols); + // Get the contract node and check if it's a library // Also check if is another contract inside the file and avoid it const contractNode = ast.nodes.find( (node) => node.nodeType === 'ContractDefinition' && node.canonicalName === contractName, ) as ContractDefinitionNode; + + // Skip unneeded contracts if (!contractNode || contractNode.abstract || contractNode.contractKind === 'library') return; const functions: StateVariablesOptions = getStateVariables(contractNode); @@ -84,6 +94,7 @@ export const generateMockContracts = async (contractsDir: string, compiledArtifa const data = { contractName: contractName, contractImport: contractImport, + exportedSymbols: exportedSymbols.join(', '), import: getImports(ast), constructor: getConstructor(contractNode), mockExternalFunctions: getExternalMockFunctions(contractNode), diff --git a/src/run.ts b/src/run.ts index 3969c4b..dd9311d 100644 --- a/src/run.ts +++ b/src/run.ts @@ -5,17 +5,18 @@ import { hideBin } from 'yargs/helpers'; import { generateMockContracts } from './index'; (async () => { - const { contracts, out, genDir } = getProcessArguments(); - generateMockContracts(contracts, out, genDir); + const { contracts, out, genDir, ignore } = getProcessArguments(); + generateMockContracts(contracts, out, genDir, ignore); })(); function getProcessArguments() { return yargs(hideBin(process.argv)) .options({ contracts: { - describe: 'Contracts directory', + describe: 'Contracts directories', demandOption: true, - type: 'string', + type: 'array', + string: true, }, out: { describe: 'Foundry compiled output directory', @@ -27,6 +28,12 @@ function getProcessArguments() { default: './solidity/test/mock-contracts', type: 'string', }, + ignore: { + describe: 'Ignore folders', + default: [], + type: 'array', + string: true, + }, }) .parseSync(); } diff --git a/src/templates/mockContractTemplate.hbs b/src/templates/mockContractTemplate.hbs index abf81e5..1ea0463 100644 --- a/src/templates/mockContractTemplate.hbs +++ b/src/templates/mockContractTemplate.hbs @@ -2,7 +2,7 @@ pragma solidity ^0.8.0; import {Test} from 'forge-std/Test.sol'; -import { {{~contractName~}} } from '{{contractImport}}'; +import { {{~exportedSymbols~}} } from '{{contractImport}}'; {{#each import}} {{this}}; {{/each}} diff --git a/src/templates/mockExternalFunctionTemplate.hbs b/src/templates/mockExternalFunctionTemplate.hbs index 8c29cb4..a9a162f 100644 --- a/src/templates/mockExternalFunctionTemplate.hbs +++ b/src/templates/mockExternalFunctionTemplate.hbs @@ -1,3 +1,6 @@ +{{#if isInterface}} +function {{functionName}}({{inputString}}) external{{stateMutabilityString}}returns ({{outputString}}) {} +{{/if}} function mock_call_{{functionName}}({{arguments}}) public { vm.mockCall( address(this), diff --git a/src/types.ts b/src/types.ts index 5998081..e49c98b 100644 --- a/src/types.ts +++ b/src/types.ts @@ -48,6 +48,7 @@ export interface FunctionDefinitionNode { }; virtual: boolean; visibility: string; + stateMutability: string; } export interface Ast { @@ -57,6 +58,7 @@ export interface Ast { src: string; nodes: AstNode[]; license: string; + exportedSymbols: { [key: string]: number[] }; } export interface ImportDirectiveNode { @@ -135,6 +137,10 @@ export interface ExternalFunctionOptions { signature: string; inputsStringNames: string; outputsStringNames: string; + inputString: string; + outputString: string; + isInterface: boolean; + stateMutabilityString: string; } export interface InternalFunctionOptions { diff --git a/src/utils.ts b/src/utils.ts index fb6cdf4..f48861d 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -88,7 +88,7 @@ export const registerHandlebarsTemplates = (): string => { * @param contractsDir The directory where the contracts are located * @returns The names of the contracts in the given directory and its subdirectories */ -export const getContractNames = (contractsDir: string): string[] => { +export const getContractNames = (contractsDir: string[], ignoreDir: string[]): string[] => { const contractFileNames: string[] = []; // Recursive function to traverse the directory and its subdirectories function traverseDirectory(currentPath: string) { @@ -100,13 +100,15 @@ export const getContractNames = (contractsDir: string): string[] => { // If the file is a contract then we add it to the array, if it is a directory then we call the function again if (stats.isFile() && fileName.endsWith('.sol')) { contractFileNames.push(fileName); - } else if (stats.isDirectory()) { + } else if (stats.isDirectory() && !ignoreDir.includes(fileName)) { traverseDirectory(filePath); } }); } - traverseDirectory(contractsDir); + contractsDir.map((dir: string) => { + traverseDirectory(dir); + }); return contractFileNames; }; diff --git a/test/e2e/get-external-functions.spec.ts b/test/e2e/get-external-functions.spec.ts new file mode 100644 index 0000000..44b2aef --- /dev/null +++ b/test/e2e/get-external-functions.spec.ts @@ -0,0 +1,129 @@ +// Write e2e tests for the getExternalFunctions function here like the other tests. +import { expect } from 'chai'; +import { ContractDefinitionNode, FunctionDefinitionNode } from '../../src/types'; +import { generateMockContracts } from '../../src/index'; +import { resolve } from 'path'; + +// We use the describe function to group together related tests +describe('E2E: getExternalMockFunctions', () => { + // We use the beforeEach function to reset the contract node before each test + let contractNodes: { [name: string]: ContractDefinitionNode }; + before(async () => { + // generate mock contracts + const contractsDir = ['solidity/contracts', 'solidity/interfaces']; + const compiledArtifactsDir = 'out'; + const generatedContractsDir = 'solidity/test/mock-contracts'; + const ignoreDir = []; + await generateMockContracts(contractsDir, compiledArtifactsDir, generatedContractsDir, ignoreDir); + + const contractsNames = ['ContractTest', 'IContractTest']; + + contractsNames.forEach((contractName: string) => { + const mockName = `Mock${contractName}`; + const compiledArtifactsPath = resolve(compiledArtifactsDir, `${mockName}.sol`, `${mockName}.json`); + const ast = require(compiledArtifactsPath).ast; + if (!ast) throw new Error(`AST for ${mockName} not found`); + const contractNode = ast.nodes.find( + (node) => node.nodeType === 'ContractDefinition' && node.canonicalName === mockName, + ) as ContractDefinitionNode; + if (!contractNode || contractNode.abstract || contractNode.contractKind === 'library') throw new Error(`Contract ${mockName} not found`); + + contractNodes = { ...contractNodes, [mockName]: contractNode }; + }); + }); + + // We use the it function to create a test + it('MockContractTest must include constructor', async () => { + const contractNode = contractNodes['MockContractTest']; + const constructor = contractNode.nodes.find( + (node) => node.nodeType === 'FunctionDefinition' && node.kind === 'constructor', + ) as FunctionDefinitionNode; + expect(constructor).to.not.be.undefined; + + const param1 = constructor.parameters.parameters.find((param) => param.name === '_uintVariable'); + expect(param1).to.not.be.undefined; + expect(param1?.typeDescriptions.typeString).to.equal('uint256'); + }); + + it('MockContractTest must include mock call', async () => { + const contractNode = contractNodes['MockContractTest']; + const func = contractNode.nodes.find( + (node) => node.nodeType === 'FunctionDefinition' && node.name === 'mock_call_setVariables' && node.parameters.parameters.length === 2, + ) as FunctionDefinitionNode; + expect(func).to.not.be.undefined; + expect(func.visibility).to.equal('public'); + + const param1 = func.parameters.parameters.find((param) => param.name === '_newValue'); + expect(param1).to.not.be.undefined; + expect(param1?.typeDescriptions.typeString).to.equal('uint256'); + + const param2 = func.parameters.parameters.find((param) => param.name === '_result'); + expect(param2).to.not.be.undefined; + expect(param2?.typeDescriptions.typeString).to.equal('bool'); + }); + + it('MockIContractTest must include interface functions', async () => { + const contractNode = contractNodes['MockIContractTest']; + const func = contractNode.nodes.find( + (node) => node.nodeType === 'FunctionDefinition' && node.name === 'setVariables' && node.parameters.parameters.length === 8, + ) as FunctionDefinitionNode; + expect(func).to.not.be.undefined; + + const param1 = func.parameters.parameters.find((param) => param.name === '_newValue'); + expect(param1).to.not.be.undefined; + expect(param1?.typeDescriptions.typeString).to.equal('uint256'); + + const param2 = func.parameters.parameters.find((param) => param.name === '_newString'); + expect(param2).to.not.be.undefined; + expect(param2?.typeDescriptions.typeString).to.equal('string'); + expect(param2?.storageLocation).to.equal('memory'); + + const param3 = func.parameters.parameters.find((param) => param.name === '_newBool'); + expect(param3).to.not.be.undefined; + expect(param3?.typeDescriptions.typeString).to.equal('bool'); + + const param4 = func.parameters.parameters.find((param) => param.name === '_newAddress'); + expect(param4).to.not.be.undefined; + expect(param4?.typeDescriptions.typeString).to.equal('address'); + + const param5 = func.parameters.parameters.find((param) => param.name === '_newBytes32'); + expect(param5).to.not.be.undefined; + expect(param5?.typeDescriptions.typeString).to.equal('bytes32'); + + const param6 = func.parameters.parameters.find((param) => param.name === '_addressArray'); + expect(param6).to.not.be.undefined; + expect(param6?.typeDescriptions.typeString).to.equal('address[]'); + expect(param6?.storageLocation).to.equal('memory'); + + const param7 = func.parameters.parameters.find((param) => param.name === '_uint256Array'); + expect(param7).to.not.be.undefined; + expect(param7?.typeDescriptions.typeString).to.equal('uint256[]'); + expect(param7?.storageLocation).to.equal('memory'); + + const param8 = func.parameters.parameters.find((param) => param.name === '_bytes32Array'); + expect(param8).to.not.be.undefined; + expect(param8?.typeDescriptions.typeString).to.equal('bytes32[]'); + expect(param8?.storageLocation).to.equal('memory'); + }); + + it('MockIContractTest must include mock call', async () => { + const contractNode = contractNodes['MockIContractTest']; + const func = contractNode.nodes.find( + (node) => node.nodeType === 'FunctionDefinition' && node.name === 'mock_call_setVariables' && node.parameters.parameters.length === 3, + ) as FunctionDefinitionNode; + expect(func).to.not.be.undefined; + expect(func.visibility).to.equal('public'); + + const param0 = func.parameters.parameters.find((param) => param.name === '_param0'); + expect(param0).to.not.be.undefined; + expect(param0?.typeDescriptions.typeString).to.equal('uint256'); + + const param1 = func.parameters.parameters.find((param) => param.name === '_param1'); + expect(param1).to.not.be.undefined; + expect(param1?.typeDescriptions.typeString).to.equal('bool'); + + const param2 = func.parameters.parameters.find((param) => param.name === '_return0'); + expect(param2).to.not.be.undefined; + expect(param2?.typeDescriptions.typeString).to.equal('bool'); + }); +}); diff --git a/test/e2e/get-internal-functions.spec.ts b/test/e2e/get-internal-functions.spec.ts index 9f64ddc..b2b48b0 100644 --- a/test/e2e/get-internal-functions.spec.ts +++ b/test/e2e/get-internal-functions.spec.ts @@ -1,19 +1,20 @@ -// Write unit tests for the getInternalFunctions function here like the other tests. +// Write e2e tests for the getInternalFunctions function here like the other tests. import { expect } from 'chai'; import { ContractDefinitionNode, FunctionDefinitionNode } from '../../src/types'; import { generateMockContracts } from '../../src/index'; import { resolve } from 'path'; // We use the describe function to group together related tests -describe('getInternalMockFunctions', () => { +describe('E2E: getInternalMockFunctions', () => { // We use the beforeEach function to reset the contract node before each test let contractNodes: { [name: string]: ContractDefinitionNode }; before(async () => { // generate mock contracts - const contractsDir = 'solidity/contracts'; + const contractsDir = ['solidity/contracts', 'solidity/interfaces']; const compiledArtifactsDir = 'out'; const generatedContractsDir = 'solidity/test/mock-contracts'; - await generateMockContracts(contractsDir, compiledArtifactsDir, generatedContractsDir); + const ignoreDir = []; + await generateMockContracts(contractsDir, compiledArtifactsDir, generatedContractsDir, ignoreDir); const contractsNames = ['ContractD']; diff --git a/test/e2e/set-variables.spec.ts b/test/e2e/set-variables.spec.ts index 6a8b797..1a292dd 100644 --- a/test/e2e/set-variables.spec.ts +++ b/test/e2e/set-variables.spec.ts @@ -3,14 +3,15 @@ import { ContractDefinitionNode, FunctionDefinitionNode } from '../../src/types' import { generateMockContracts } from '../../src/index'; import { resolve } from 'path'; -describe('getInternalMockFunctions', () => { +describe('E2E: getStateVariables', () => { let contractNodes: { [name: string]: ContractDefinitionNode }; before(async () => { // generate mock contracts - const contractsDir = 'solidity/contracts'; + const contractsDir = ['solidity/contracts', 'solidity/interfaces']; const compiledArtifactsDir = 'out'; const generatedContractsDir = 'solidity/test/mock-contracts'; - await generateMockContracts(contractsDir, compiledArtifactsDir, generatedContractsDir); + const ignoreDir = []; + await generateMockContracts(contractsDir, compiledArtifactsDir, generatedContractsDir, ignoreDir); const contractsNames = ['ContractTest']; diff --git a/test/unit/get-constructor.spec.ts b/test/unit/get-constructor.spec.ts index 96adcde..ab28439 100644 --- a/test/unit/get-constructor.spec.ts +++ b/test/unit/get-constructor.spec.ts @@ -34,6 +34,7 @@ describe('getConstructor', () => { }, virtual: false, visibility: 'public', + stateMutability: 'nonpayable', }, ]; const constructorSignature = getConstructor(contractNode); @@ -63,6 +64,7 @@ describe('getConstructor', () => { }, virtual: false, visibility: 'public', + stateMutability: 'nonpayable', }, ]; const constructorSignature = getConstructor(contractNode); @@ -92,6 +94,7 @@ describe('getConstructor', () => { }, virtual: false, visibility: 'public', + stateMutability: 'nonpayable', }, ]; const constructorSignature = getConstructor(contractNode); @@ -121,6 +124,7 @@ describe('getConstructor', () => { }, virtual: false, visibility: 'public', + stateMutability: 'nonpayable', }, ]; const constructorSignature = getConstructor(contractNode); @@ -150,6 +154,7 @@ describe('getConstructor', () => { }, virtual: false, visibility: 'public', + stateMutability: 'nonpayable', }, ]; const constructorSignature = getConstructor(contractNode); @@ -179,6 +184,7 @@ describe('getConstructor', () => { }, virtual: false, visibility: 'public', + stateMutability: 'nonpayable', }, ]; const constructorSignature = getConstructor(contractNode); diff --git a/test/unit/get-external-functions.spec.ts b/test/unit/get-external-functions.spec.ts index 78fd0dd..953cc33 100644 --- a/test/unit/get-external-functions.spec.ts +++ b/test/unit/get-external-functions.spec.ts @@ -39,6 +39,7 @@ describe('getExternalMockFunctions', () => { }, virtual: false, visibility: 'internal', + stateMutability: 'nonpayable', }, { name: 'myFunction2', @@ -52,6 +53,7 @@ describe('getExternalMockFunctions', () => { }, virtual: false, visibility: 'private', + stateMutability: 'nonpayable', }, ]; const externalFunctions = getExternalMockFunctions(contractNode); @@ -72,6 +74,7 @@ describe('getExternalMockFunctions', () => { }, virtual: false, visibility: 'public', + stateMutability: 'nonpayable', }, ]; const externalFunctions = getExternalMockFunctions(contractNode); @@ -109,6 +112,7 @@ describe('getExternalMockFunctions', () => { }, virtual: false, visibility: 'public', + stateMutability: 'view', }, ]; const externalFunctions = getExternalMockFunctions(contractNode); @@ -119,6 +123,10 @@ describe('getExternalMockFunctions', () => { signature: 'myFunction(string,string)', inputsStringNames: ', _param, _param2', outputsStringNames: '', + inputString: 'string memory _param, string calldata _param2', + outputString: '', + isInterface: false, + stateMutabilityString: ' view ', }, ]; expect(externalFunctions).to.be.an('array').that.is.not.empty; @@ -164,6 +172,7 @@ describe('getExternalMockFunctions', () => { }, virtual: false, visibility: 'public', + stateMutability: 'nonpayable', }, ]; const externalFunctions = getExternalMockFunctions(contractNode); @@ -174,6 +183,10 @@ describe('getExternalMockFunctions', () => { signature: 'myFunction(IERC20,MyStruct,MyEnum)', inputsStringNames: ', _param, _param2, _param3', outputsStringNames: '', + inputString: 'IERC20 _param, MyStruct _param2, MyEnum _param3', + outputString: '', + isInterface: false, + stateMutabilityString: ' ', }, ]; expect(externalFunctions).to.be.an('array').that.is.not.empty; @@ -211,6 +224,7 @@ describe('getExternalMockFunctions', () => { }, virtual: false, visibility: 'public', + stateMutability: 'nonpayable', }, ]; const externalFunctions = getExternalMockFunctions(contractNode); @@ -221,6 +235,10 @@ describe('getExternalMockFunctions', () => { signature: 'myFunction()', inputsStringNames: '', outputsStringNames: '_param, _param2', + inputString: '', + outputString: 'string memory _param, string calldata _param2', + isInterface: false, + stateMutabilityString: ' ', }, ]; expect(externalFunctions).to.be.an('array').that.is.not.empty; @@ -266,6 +284,7 @@ describe('getExternalMockFunctions', () => { }, virtual: false, visibility: 'public', + stateMutability: 'nonpayable', }, ]; const externalFunctions = getExternalMockFunctions(contractNode); @@ -276,6 +295,10 @@ describe('getExternalMockFunctions', () => { signature: 'myFunction()', inputsStringNames: '', outputsStringNames: '_param, _param2, _param3', + inputString: '', + outputString: 'IERC20 _param, MyStruct _param2, MyEnum _param3', + isInterface: false, + stateMutabilityString: ' ', }, ]; expect(externalFunctions).to.be.an('array').that.is.not.empty; @@ -314,6 +337,7 @@ describe('getExternalMockFunctions', () => { }, virtual: false, visibility: 'public', + stateMutability: 'nonpayable', }, ]; const externalFunctions = getExternalMockFunctions(contractNode); @@ -324,6 +348,72 @@ describe('getExternalMockFunctions', () => { signature: 'myFunction(string)', inputsStringNames: ', _param', outputsStringNames: '_output', + inputString: 'string memory _param', + outputString: 'string memory _output', + isInterface: false, + stateMutabilityString: ' ', + }, + ]; + expect(externalFunctions).to.be.an('array').that.is.not.empty; + expect(externalFunctions).to.deep.equal(expectedData); + }); + + it('should return the correct function data if the contract is interface', async () => { + contractNode = { + nodeType: 'ContractDefinition', + canonicalName: 'MyContract', + nodes: [], + abstract: false, + contractKind: 'interface', + name: 'MyContract', + }; + + contractNode.nodes = [ + { + name: 'myFunction', + nodeType: 'FunctionDefinition', + kind: 'function', + parameters: { + parameters: [ + { + name: '_param', + nodeType: 'VariableDeclaration', + typeDescriptions: { + typeString: 'string', + }, + storageLocation: 'memory', + }, + ], + }, + returnParameters: { + parameters: [ + { + name: '_output', + nodeType: 'VariableDeclaration', + typeDescriptions: { + typeString: 'string', + }, + storageLocation: 'memory', + }, + ], + }, + virtual: false, + visibility: 'public', + stateMutability: 'nonpayable', + }, + ]; + const externalFunctions = getExternalMockFunctions(contractNode); + const expectedData: ExternalFunctionOptions[] = [ + { + functionName: 'myFunction', + arguments: 'string memory _param, string memory _output', + signature: 'myFunction(string)', + inputsStringNames: ', _param', + outputsStringNames: '_output', + inputString: 'string memory _param', + outputString: 'string memory _output', + isInterface: true, + stateMutabilityString: ' ', }, ]; expect(externalFunctions).to.be.an('array').that.is.not.empty; diff --git a/test/unit/get-imports.spec.ts b/test/unit/get-imports.spec.ts index 9e35201..f37aaa3 100644 --- a/test/unit/get-imports.spec.ts +++ b/test/unit/get-imports.spec.ts @@ -12,6 +12,7 @@ describe('getImports', () => { src: '', nodes: [], license: '', + exportedSymbols: {}, }; }); it('should return an empty array if there are no import directives', async () => { diff --git a/test/unit/get-internal-functions.spec.ts b/test/unit/get-internal-functions.spec.ts index 24f0d6d..7ae2d33 100644 --- a/test/unit/get-internal-functions.spec.ts +++ b/test/unit/get-internal-functions.spec.ts @@ -70,6 +70,7 @@ describe('getInternalMockFunctions', () => { }, virtual: false, visibility: 'external', + stateMutability: 'nonpayable', }, { name: 'myFunction2', @@ -83,6 +84,7 @@ describe('getInternalMockFunctions', () => { }, virtual: false, visibility: 'private', + stateMutability: 'nonpayable', }, ]; const internalFunctions = getInternalMockFunctions(contractNode); @@ -103,6 +105,7 @@ describe('getInternalMockFunctions', () => { }, virtual: false, visibility: 'public', + stateMutability: 'nonpayable', }, ]; const internalFunctions = getInternalMockFunctions(contractNode); @@ -123,6 +126,7 @@ describe('getInternalMockFunctions', () => { }, virtual: true, visibility: 'internal', + stateMutability: 'nonpayable', }, ]; const internalFunctions = getInternalMockFunctions(contractNode); @@ -156,6 +160,7 @@ describe('getInternalMockFunctions', () => { }, virtual: false, visibility: 'internal', + stateMutability: 'nonpayable', }, ]; const internalFunctions = getInternalMockFunctions(contractNode); @@ -176,6 +181,7 @@ describe('getInternalMockFunctions', () => { }, virtual: true, visibility: 'internal', + stateMutability: 'nonpayable', }, ]; const internalFunctions = getInternalMockFunctions(contractNode); @@ -209,6 +215,7 @@ describe('getInternalMockFunctions', () => { }, virtual: true, visibility: 'internal', + stateMutability: 'nonpayable', }, ]; const internalFunctions = getInternalMockFunctions(contractNode); @@ -242,6 +249,7 @@ describe('getInternalMockFunctions', () => { }, virtual: true, visibility: 'internal', + stateMutability: 'nonpayable', }, ]; const internalFunctions = getInternalMockFunctions(contractNode); @@ -275,6 +283,7 @@ describe('getInternalMockFunctions', () => { }, virtual: true, visibility: 'internal', + stateMutability: 'nonpayable', }, ]; const internalFunctions = getInternalMockFunctions(contractNode); From 3718fc8274957c8cc4cac5e739a292a82889aa30 Mon Sep 17 00:00:00 2001 From: dristpunk <107591874+dristpunk@users.noreply.github.com> Date: Wed, 1 Nov 2023 16:26:56 +0300 Subject: [PATCH 2/2] feat: preserve folders (#6) --- solidity/test/ContractTest.t.sol | 2 +- src/mock-contract-generator.ts | 28 +++++++++++++++++++--------- src/utils.ts | 15 +++++++++------ 3 files changed, 29 insertions(+), 16 deletions(-) diff --git a/solidity/test/ContractTest.t.sol b/solidity/test/ContractTest.t.sol index 22a11a8..b83a5d1 100644 --- a/solidity/test/ContractTest.t.sol +++ b/solidity/test/ContractTest.t.sol @@ -3,7 +3,7 @@ pragma solidity ^0.8.0; import {Test} from 'forge-std/Test.sol'; import {IERC20} from 'isolmate/interfaces/tokens/IERC20.sol'; -import {MockContractTest} from 'test/mock-contracts/MockContractTest.sol'; +import {MockContractTest} from 'test/mock-contracts/contracts/MockContractTest.sol'; import {console} from 'forge-std/console.sol'; contract CommonE2EBase is Test { diff --git a/src/mock-contract-generator.ts b/src/mock-contract-generator.ts index 673a579..3cab41e 100644 --- a/src/mock-contract-generator.ts +++ b/src/mock-contract-generator.ts @@ -1,5 +1,5 @@ import { getExternalMockFunctions, getInternalMockFunctions, getConstructor, getImports, getStateVariables, Ast } from './index'; -import { getSubDirNameFromPath, registerHandlebarsTemplates, getContractNames, compileSolidityFilesFoundry } from './utils'; +import { getSubDirNameFromPath, registerHandlebarsTemplates, getContractNamesAndFolders, compileSolidityFilesFoundry } from './utils'; import Handlebars from 'handlebars'; import { writeFileSync, existsSync, readdirSync } from 'fs'; import { ensureDir, emptyDir } from 'fs-extra'; @@ -35,12 +35,14 @@ export const generateMockContracts = async ( console.error('Error while trying to empty the mock directory: ', error); } console.log('Parsing contracts...'); - // Get all contracts directories - const contractPaths: string[] = getContractNames(contractsDir, ignoreDir); + + // Get all contracts names and paths + const [contractFileNames, contractFolders] = getContractNamesAndFolders(contractsDir, ignoreDir); + // Loop for each contract path - contractPaths.forEach(async (contractPath: string) => { + contractFileNames.forEach(async (contractFileName: string, ind: number) => { // Get the sub dir name - const subDirName: string = getSubDirNameFromPath(contractPath); + const subDirName: string = getSubDirNameFromPath(contractFileName); // Get contract name // If the contract and the file have different names, it will be modified. @@ -48,11 +50,11 @@ export const generateMockContracts = async ( // Get the compiled path // If the contract and the file have different names, it will be modified. - let compiledArtifactsPath = resolve(compiledArtifactsDir, contractPath, subDirName); + let compiledArtifactsPath = resolve(compiledArtifactsDir, contractFileName, subDirName); // Check if contract and file have different names if (!existsSync(compiledArtifactsPath)) { - const directoryPath = resolve(compiledArtifactsDir, contractPath); + const directoryPath = resolve(compiledArtifactsDir, contractFileName); // If the directory path does not exist, the contract is not compiled. if (!existsSync(directoryPath)) return; @@ -62,7 +64,7 @@ export const generateMockContracts = async ( // Get the real path of the json file // If this !path means that the file is not compiled - compiledArtifactsPath = resolve(compiledArtifactsDir, contractPath, subDirContractName[0]); + compiledArtifactsPath = resolve(compiledArtifactsDir, contractFileName, subDirContractName[0]); if (!compiledArtifactsPath) return; contractName = subDirContractName[0].replace('.json', ''); @@ -116,7 +118,15 @@ export const generateMockContracts = async ( .replace(/;;/g, ';'); // Write the contract - writeFileSync(`${generatedContractsDir}/Mock${contractName}.sol`, cleanedCode); + const contractFolder = `${generatedContractsDir}/${contractFolders[ind]}`; + // Create the directory if it doesn't exist + try { + await ensureDir(contractFolder); + } catch (error) { + console.error('Error while creating the mock directory: ', error); + } + + writeFileSync(`${contractFolder}/Mock${contractName}.sol`, cleanedCode); }); console.log('Mock contracts generated successfully'); diff --git a/src/utils.ts b/src/utils.ts index f48861d..cdd5c6e 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -1,5 +1,5 @@ import { arrayRegex, memoryTypes, structRegex } from './types'; -import { resolve, join } from 'path'; +import { resolve, join, relative, dirname } from 'path'; import { readFileSync, readdirSync, statSync } from 'fs'; import { exec } from 'child_process'; import Handlebars from 'handlebars'; @@ -88,10 +88,12 @@ export const registerHandlebarsTemplates = (): string => { * @param contractsDir The directory where the contracts are located * @returns The names of the contracts in the given directory and its subdirectories */ -export const getContractNames = (contractsDir: string[], ignoreDir: string[]): string[] => { +export const getContractNamesAndFolders = (contractsDir: string[], ignoreDir: string[]): [string[], string[]] => { + const contractFileNames: string[] = []; + const contractFolders: string[] = []; // Recursive function to traverse the directory and its subdirectories - function traverseDirectory(currentPath: string) { + function traverseDirectory(currentPath: string, baseDir: string) { const fileNames = readdirSync(currentPath); // Loop through the files and directories fileNames.forEach((fileName: string) => { @@ -100,17 +102,18 @@ export const getContractNames = (contractsDir: string[], ignoreDir: string[]): s // If the file is a contract then we add it to the array, if it is a directory then we call the function again if (stats.isFile() && fileName.endsWith('.sol')) { contractFileNames.push(fileName); + contractFolders.push(dirname(relative(baseDir, filePath))); } else if (stats.isDirectory() && !ignoreDir.includes(fileName)) { - traverseDirectory(filePath); + traverseDirectory(filePath, baseDir); } }); } contractsDir.map((dir: string) => { - traverseDirectory(dir); + traverseDirectory(dir, dirname(dir)); }); - return contractFileNames; + return [contractFileNames, contractFolders]; }; /**