Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: mock interfaces #5

Merged
merged 6 commits into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/get-external-functions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ export const getExternalMockFunctions = (contractNode: ContractDefinitionNode):
// Filter the nodes and keep only the FunctionDefinition related ones
const functionNodes = contractNode.nodes.filter((node) => node.nodeType === 'FunctionDefinition') as FunctionDefinitionNode[];

// Get contract kind
const contractKind = contractNode.contractKind;

const externalFunctions: ExternalFunctionOptions[] = [];
// Loop through the function nodes
functionNodes.forEach((funcNode: FunctionDefinitionNode) => {
Expand All @@ -17,6 +20,10 @@ export const getExternalMockFunctions = (contractNode: ContractDefinitionNode):
// Check if the function is external or public
if (funcNode.visibility != 'external' && funcNode.visibility != 'public') return;

// Save state mutability
const stateMutability = funcNode.stateMutability;
const stateMutabilityString = stateMutability == 'nonpayable' ? ' ' : ` ${stateMutability} `;

// Get the parameters of the function, if there are no parameters then we use an empty array
const parameters: VariableDeclarationNode[] = funcNode.parameters.parameters ? funcNode.parameters.parameters : [];

Expand Down Expand Up @@ -95,6 +102,10 @@ export const getExternalMockFunctions = (contractNode: ContractDefinitionNode):
signature: signature,
inputsStringNames: inputsStringNames,
outputsStringNames: outputsStringNames,
inputString: inputsString,
outputString: outputsString,
isInterface: contractKind === 'interface',
stateMutabilityString: stateMutabilityString,
};

externalFunctions.push(externalMockFunction);
Expand Down
15 changes: 13 additions & 2 deletions src/mock-contract-generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@ import { StateVariablesOptions, ContractDefinitionNode } from './types';
* @param compiledArtifactsDir The directory where the compiled artifacts are located
* @param generatedContractsDir The directory where the mock contracts will be generated
*/
export const generateMockContracts = async (contractsDir: string, compiledArtifactsDir: string, generatedContractsDir: string) => {
export const generateMockContracts = async (
contractsDir: string[],
compiledArtifactsDir: string,
generatedContractsDir: string,
ignoreDir: string[],
) => {
const templateContent: string = registerHandlebarsTemplates();
const template = Handlebars.compile(templateContent);
try {
Expand All @@ -31,7 +36,7 @@ export const generateMockContracts = async (contractsDir: string, compiledArtifa
}
console.log('Parsing contracts...');
// Get all contracts directories
const contractPaths: string[] = getContractNames(contractsDir);
const contractPaths: string[] = getContractNames(contractsDir, ignoreDir);
// Loop for each contract path
contractPaths.forEach(async (contractPath: string) => {
// Get the sub dir name
Expand Down Expand Up @@ -71,11 +76,16 @@ export const generateMockContracts = async (contractsDir: string, compiledArtifa
const contractImport: string = ast.absolutePath;
if (!contractImport) return;

// Get all exported entities
const exportedSymbols = Object.keys(ast.exportedSymbols);

// Get the contract node and check if it's a library
// Also check if is another contract inside the file and avoid it
const contractNode = ast.nodes.find(
(node) => node.nodeType === 'ContractDefinition' && node.canonicalName === contractName,
) as ContractDefinitionNode;

// Skip unneeded contracts
if (!contractNode || contractNode.abstract || contractNode.contractKind === 'library') return;

const functions: StateVariablesOptions = getStateVariables(contractNode);
Expand All @@ -84,6 +94,7 @@ export const generateMockContracts = async (contractsDir: string, compiledArtifa
const data = {
contractName: contractName,
contractImport: contractImport,
exportedSymbols: exportedSymbols.join(', '),
import: getImports(ast),
constructor: getConstructor(contractNode),
mockExternalFunctions: getExternalMockFunctions(contractNode),
Expand Down
15 changes: 11 additions & 4 deletions src/run.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@ import { hideBin } from 'yargs/helpers';
import { generateMockContracts } from './index';

(async () => {
const { contracts, out, genDir } = getProcessArguments();
generateMockContracts(contracts, out, genDir);
const { contracts, out, genDir, ignore } = getProcessArguments();
generateMockContracts(contracts, out, genDir, ignore);
})();

function getProcessArguments() {
return yargs(hideBin(process.argv))
.options({
contracts: {
describe: 'Contracts directory',
describe: 'Contracts directories',
demandOption: true,
type: 'string',
type: 'array',
string: true,
},
out: {
describe: 'Foundry compiled output directory',
Expand All @@ -27,6 +28,12 @@ function getProcessArguments() {
default: './solidity/test/mock-contracts',
type: 'string',
},
ignore: {
describe: 'Ignore folders',
default: [],
type: 'array',
string: true,
},
})
.parseSync();
}
2 changes: 1 addition & 1 deletion src/templates/mockContractTemplate.hbs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
pragma solidity ^0.8.0;

import {Test} from 'forge-std/Test.sol';
import { {{~contractName~}} } from '{{contractImport}}';
import { {{~exportedSymbols~}} } from '{{contractImport}}';
{{#each import}}
{{this}};
{{/each}}
Expand Down
3 changes: 3 additions & 0 deletions src/templates/mockExternalFunctionTemplate.hbs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
{{#if isInterface}}
function {{functionName}}({{inputString}}) external{{stateMutabilityString}}returns ({{outputString}}) {}
{{/if}}
function mock_call_{{functionName}}({{arguments}}) public {
vm.mockCall(
address(this),
Expand Down
6 changes: 6 additions & 0 deletions src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ export interface FunctionDefinitionNode {
};
virtual: boolean;
visibility: string;
stateMutability: string;
}

export interface Ast {
Expand All @@ -57,6 +58,7 @@ export interface Ast {
src: string;
nodes: AstNode[];
license: string;
exportedSymbols: { [key: string]: number[] };
}

export interface ImportDirectiveNode {
Expand Down Expand Up @@ -135,6 +137,10 @@ export interface ExternalFunctionOptions {
signature: string;
inputsStringNames: string;
outputsStringNames: string;
inputString: string;
outputString: string;
isInterface: boolean;
stateMutabilityString: string;
}

export interface InternalFunctionOptions {
Expand Down
8 changes: 5 additions & 3 deletions src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ export const registerHandlebarsTemplates = (): string => {
* @param contractsDir The directory where the contracts are located
* @returns The names of the contracts in the given directory and its subdirectories
*/
export const getContractNames = (contractsDir: string): string[] => {
export const getContractNames = (contractsDir: string[], ignoreDir: string[]): string[] => {
const contractFileNames: string[] = [];
// Recursive function to traverse the directory and its subdirectories
function traverseDirectory(currentPath: string) {
Expand All @@ -100,13 +100,15 @@ export const getContractNames = (contractsDir: string): string[] => {
// If the file is a contract then we add it to the array, if it is a directory then we call the function again
if (stats.isFile() && fileName.endsWith('.sol')) {
contractFileNames.push(fileName);
} else if (stats.isDirectory()) {
} else if (stats.isDirectory() && !ignoreDir.includes(fileName)) {
traverseDirectory(filePath);
}
});
}

traverseDirectory(contractsDir);
contractsDir.map((dir: string) => {
traverseDirectory(dir);
});

return contractFileNames;
};
Expand Down
129 changes: 129 additions & 0 deletions test/e2e/get-external-functions.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
// Write e2e tests for the getExternalFunctions function here like the other tests.
import { expect } from 'chai';
import { ContractDefinitionNode, FunctionDefinitionNode } from '../../src/types';
import { generateMockContracts } from '../../src/index';
import { resolve } from 'path';

// We use the describe function to group together related tests
describe('E2E: getExternalMockFunctions', () => {
// We use the beforeEach function to reset the contract node before each test
let contractNodes: { [name: string]: ContractDefinitionNode };
before(async () => {
// generate mock contracts
const contractsDir = ['solidity/contracts', 'solidity/interfaces'];
const compiledArtifactsDir = 'out';
const generatedContractsDir = 'solidity/test/mock-contracts';
const ignoreDir = [];
await generateMockContracts(contractsDir, compiledArtifactsDir, generatedContractsDir, ignoreDir);

const contractsNames = ['ContractTest', 'IContractTest'];

contractsNames.forEach((contractName: string) => {
const mockName = `Mock${contractName}`;
const compiledArtifactsPath = resolve(compiledArtifactsDir, `${mockName}.sol`, `${mockName}.json`);
const ast = require(compiledArtifactsPath).ast;
if (!ast) throw new Error(`AST for ${mockName} not found`);
const contractNode = ast.nodes.find(
(node) => node.nodeType === 'ContractDefinition' && node.canonicalName === mockName,
) as ContractDefinitionNode;
if (!contractNode || contractNode.abstract || contractNode.contractKind === 'library') throw new Error(`Contract ${mockName} not found`);

contractNodes = { ...contractNodes, [mockName]: contractNode };
});
});

// We use the it function to create a test
it('MockContractTest must include constructor', async () => {
const contractNode = contractNodes['MockContractTest'];
const constructor = contractNode.nodes.find(
(node) => node.nodeType === 'FunctionDefinition' && node.kind === 'constructor',
) as FunctionDefinitionNode;
expect(constructor).to.not.be.undefined;

const param1 = constructor.parameters.parameters.find((param) => param.name === '_uintVariable');
expect(param1).to.not.be.undefined;
expect(param1?.typeDescriptions.typeString).to.equal('uint256');
});

it('MockContractTest must include mock call', async () => {
const contractNode = contractNodes['MockContractTest'];
const func = contractNode.nodes.find(
(node) => node.nodeType === 'FunctionDefinition' && node.name === 'mock_call_setVariables' && node.parameters.parameters.length === 2,
) as FunctionDefinitionNode;
expect(func).to.not.be.undefined;
expect(func.visibility).to.equal('public');

const param1 = func.parameters.parameters.find((param) => param.name === '_newValue');
expect(param1).to.not.be.undefined;
expect(param1?.typeDescriptions.typeString).to.equal('uint256');

const param2 = func.parameters.parameters.find((param) => param.name === '_result');
expect(param2).to.not.be.undefined;
expect(param2?.typeDescriptions.typeString).to.equal('bool');
});

it('MockIContractTest must include interface functions', async () => {
const contractNode = contractNodes['MockIContractTest'];
const func = contractNode.nodes.find(
(node) => node.nodeType === 'FunctionDefinition' && node.name === 'setVariables' && node.parameters.parameters.length === 8,
) as FunctionDefinitionNode;
expect(func).to.not.be.undefined;

const param1 = func.parameters.parameters.find((param) => param.name === '_newValue');
expect(param1).to.not.be.undefined;
expect(param1?.typeDescriptions.typeString).to.equal('uint256');

const param2 = func.parameters.parameters.find((param) => param.name === '_newString');
expect(param2).to.not.be.undefined;
expect(param2?.typeDescriptions.typeString).to.equal('string');
expect(param2?.storageLocation).to.equal('memory');

const param3 = func.parameters.parameters.find((param) => param.name === '_newBool');
expect(param3).to.not.be.undefined;
expect(param3?.typeDescriptions.typeString).to.equal('bool');

const param4 = func.parameters.parameters.find((param) => param.name === '_newAddress');
expect(param4).to.not.be.undefined;
expect(param4?.typeDescriptions.typeString).to.equal('address');

const param5 = func.parameters.parameters.find((param) => param.name === '_newBytes32');
expect(param5).to.not.be.undefined;
expect(param5?.typeDescriptions.typeString).to.equal('bytes32');

const param6 = func.parameters.parameters.find((param) => param.name === '_addressArray');
expect(param6).to.not.be.undefined;
expect(param6?.typeDescriptions.typeString).to.equal('address[]');
expect(param6?.storageLocation).to.equal('memory');

const param7 = func.parameters.parameters.find((param) => param.name === '_uint256Array');
expect(param7).to.not.be.undefined;
expect(param7?.typeDescriptions.typeString).to.equal('uint256[]');
expect(param7?.storageLocation).to.equal('memory');

const param8 = func.parameters.parameters.find((param) => param.name === '_bytes32Array');
expect(param8).to.not.be.undefined;
expect(param8?.typeDescriptions.typeString).to.equal('bytes32[]');
expect(param8?.storageLocation).to.equal('memory');
});

it('MockIContractTest must include mock call', async () => {
const contractNode = contractNodes['MockIContractTest'];
const func = contractNode.nodes.find(
(node) => node.nodeType === 'FunctionDefinition' && node.name === 'mock_call_setVariables' && node.parameters.parameters.length === 3,
) as FunctionDefinitionNode;
expect(func).to.not.be.undefined;
expect(func.visibility).to.equal('public');

const param0 = func.parameters.parameters.find((param) => param.name === '_param0');
expect(param0).to.not.be.undefined;
expect(param0?.typeDescriptions.typeString).to.equal('uint256');

const param1 = func.parameters.parameters.find((param) => param.name === '_param1');
expect(param1).to.not.be.undefined;
expect(param1?.typeDescriptions.typeString).to.equal('bool');

const param2 = func.parameters.parameters.find((param) => param.name === '_return0');
expect(param2).to.not.be.undefined;
expect(param2?.typeDescriptions.typeString).to.equal('bool');
});
});
9 changes: 5 additions & 4 deletions test/e2e/get-internal-functions.spec.ts
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
// Write unit tests for the getInternalFunctions function here like the other tests.
// Write e2e tests for the getInternalFunctions function here like the other tests.
import { expect } from 'chai';
import { ContractDefinitionNode, FunctionDefinitionNode } from '../../src/types';
import { generateMockContracts } from '../../src/index';
import { resolve } from 'path';

// We use the describe function to group together related tests
describe('getInternalMockFunctions', () => {
describe('E2E: getInternalMockFunctions', () => {
// We use the beforeEach function to reset the contract node before each test
let contractNodes: { [name: string]: ContractDefinitionNode };
before(async () => {
// generate mock contracts
const contractsDir = 'solidity/contracts';
const contractsDir = ['solidity/contracts', 'solidity/interfaces'];
const compiledArtifactsDir = 'out';
const generatedContractsDir = 'solidity/test/mock-contracts';
await generateMockContracts(contractsDir, compiledArtifactsDir, generatedContractsDir);
const ignoreDir = [];
await generateMockContracts(contractsDir, compiledArtifactsDir, generatedContractsDir, ignoreDir);

const contractsNames = ['ContractD'];

Expand Down
7 changes: 4 additions & 3 deletions test/e2e/set-variables.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@ import { ContractDefinitionNode, FunctionDefinitionNode } from '../../src/types'
import { generateMockContracts } from '../../src/index';
import { resolve } from 'path';

describe('getInternalMockFunctions', () => {
describe('E2E: getStateVariables', () => {
let contractNodes: { [name: string]: ContractDefinitionNode };
before(async () => {
// generate mock contracts
const contractsDir = 'solidity/contracts';
const contractsDir = ['solidity/contracts', 'solidity/interfaces'];
const compiledArtifactsDir = 'out';
const generatedContractsDir = 'solidity/test/mock-contracts';
await generateMockContracts(contractsDir, compiledArtifactsDir, generatedContractsDir);
const ignoreDir = [];
await generateMockContracts(contractsDir, compiledArtifactsDir, generatedContractsDir, ignoreDir);

const contractsNames = ['ContractTest'];

Expand Down
Loading