From 88b0a79bf9bc480228e1966f6936ed62e6fba72e Mon Sep 17 00:00:00 2001 From: tuturu-tech Date: Tue, 2 Apr 2024 12:56:33 +0200 Subject: [PATCH 1/4] derive configuration without CLI arguments --- fuzz_utils/generate/FoundryTest.py | 27 +++++------------ fuzz_utils/generate/fuzzers/Echidna.py | 16 ++-------- fuzz_utils/generate/fuzzers/Medusa.py | 18 +++--------- fuzz_utils/parsing/commands/generate.py | 39 +++++++++++++++++++++---- fuzz_utils/parsing/commands/template.py | 20 +++++-------- fuzz_utils/parsing/parser_util.py | 31 ++++++++++++++++++++ fuzz_utils/template/HarnessGenerator.py | 22 ++++---------- fuzz_utils/utils/slither_utils.py | 15 ++++++++++ tests/test_harness.py | 3 +- 9 files changed, 111 insertions(+), 80 deletions(-) create mode 100644 fuzz_utils/parsing/parser_util.py create mode 100644 fuzz_utils/utils/slither_utils.py diff --git a/fuzz_utils/generate/FoundryTest.py b/fuzz_utils/generate/FoundryTest.py index 66e0ae3..d9b05c9 100644 --- a/fuzz_utils/generate/FoundryTest.py +++ b/fuzz_utils/generate/FoundryTest.py @@ -1,20 +1,19 @@ """The FoundryTest class that handles generation of unit tests from call sequences""" import os -import sys import json from typing import Any import jinja2 from slither import Slither -from slither.core.declarations.contract import Contract from fuzz_utils.utils.crytic_print import CryticPrint +from fuzz_utils.utils.slither_utils import get_target_contract from fuzz_utils.generate.fuzzers.Medusa import Medusa from fuzz_utils.generate.fuzzers.Echidna import Echidna from fuzz_utils.templates.foundry_templates import templates - -class FoundryTest: # pylint: disable=too-many-instance-attributes +# pylint: disable=too-few-public-methods,too-many-instance-attributes +class FoundryTest: """ Handles the generation of Foundry test files """ @@ -31,20 +30,10 @@ def __init__( self.test_dir = config["testsDir"] self.all_sequences = config["allSequences"] self.slither = slither - self.target = self.get_target_contract() + self.target = get_target_contract(self.slither, self.target_name) + self.target_file_name = self.target.source_mapping.filename.relative.split("/")[-1] self.fuzzer = fuzzer - def get_target_contract(self) -> Contract: - """Gets the Slither Contract object for the specified contract file""" - contracts = self.slither.get_contract_from_name(self.target_name) - # Loop in case slither fetches multiple contracts for some reason (e.g., similar names?) - for contract in contracts: - if contract.name == self.target_name: - return contract - - # TODO throw error if no contract found - sys.exit(-1) - def create_poc(self) -> str: """Takes in a directory path to the echidna reproducers and generates a test file""" @@ -79,12 +68,12 @@ def create_poc(self) -> str: # 4. Generate the test file template = jinja2.Template(templates["CONTRACT"]) - write_path = f"{self.test_dir}{self.target_name}" - inheritance_path = f"{self.inheritance_path}{self.target_name}" + write_path = os.path.join(self.test_dir, self.target_name) + inheritance_path = os.path.join(self.inheritance_path, self.target_file_name) # 5. Save the test file test_file_str = template.render( - file_path=f"{inheritance_path}.sol", + file_path=inheritance_path, target_name=self.target_name, amount=0, tests=tests_list, diff --git a/fuzz_utils/generate/fuzzers/Echidna.py b/fuzz_utils/generate/fuzzers/Echidna.py index d561745..9c25e09 100644 --- a/fuzz_utils/generate/fuzzers/Echidna.py +++ b/fuzz_utils/generate/fuzzers/Echidna.py @@ -4,7 +4,6 @@ import jinja2 from slither import Slither -from slither.core.declarations.contract import Contract from slither.core.declarations.function_contract import FunctionContract from slither.core.solidity_types.elementary_type import ElementaryType from slither.core.solidity_types.user_defined_type import UserDefinedType @@ -16,9 +15,10 @@ from fuzz_utils.templates.foundry_templates import templates from fuzz_utils.utils.encoding import parse_echidna_byte_string from fuzz_utils.utils.error_handler import handle_exit +from fuzz_utils.utils.slither_utils import get_target_contract -# pylint: disable=too-many-instance-attributes +# pylint: disable=too-few-public-methods,too-many-instance-attributes class Echidna: """ Handles the generation of Foundry test files from Echidna reproducers @@ -30,22 +30,12 @@ def __init__( self.name = "Echidna" self.target_name = target_name self.slither = slither - self.target = self.get_target_contract() + self.target = get_target_contract(slither, target_name) self.reproducer_dir = f"{corpus_path}/reproducers" self.corpus_dirs = [f"{corpus_path}/coverage", self.reproducer_dir] self.named_inputs = named_inputs self.declared_variables: set[tuple[str, str]] = set() - def get_target_contract(self) -> Contract: - """Finds and returns Slither Contract""" - contracts = self.slither.get_contract_from_name(self.target_name) - # Loop in case slither fetches multiple contracts for some reason (e.g., similar names?) - for contract in contracts: - if contract.name == self.target_name: - return contract - - handle_exit(f"\n* Slither could not find the specified contract `{self.target_name}`.") - def parse_reproducer(self, file_path: str, calls: Any, index: int) -> str: """ Takes a list of call dicts and returns a Foundry unit test string containing the call sequence. diff --git a/fuzz_utils/generate/fuzzers/Medusa.py b/fuzz_utils/generate/fuzzers/Medusa.py index f3a0ff1..3b573a1 100644 --- a/fuzz_utils/generate/fuzzers/Medusa.py +++ b/fuzz_utils/generate/fuzzers/Medusa.py @@ -4,7 +4,6 @@ from eth_abi import abi from eth_utils import to_checksum_address from slither import Slither -from slither.core.declarations.contract import Contract from slither.core.declarations.function_contract import FunctionContract from slither.core.solidity_types.elementary_type import ElementaryType from slither.core.solidity_types.user_defined_type import UserDefinedType @@ -16,9 +15,10 @@ from fuzz_utils.templates.foundry_templates import templates from fuzz_utils.utils.encoding import byte_to_escape_sequence from fuzz_utils.utils.error_handler import handle_exit +from fuzz_utils.utils.slither_utils import get_target_contract - -class Medusa: # pylint: disable=too-many-instance-attributes +# pylint: disable=too-few-public-methods,too-many-instance-attributes +class Medusa: """ Handles the generation of Foundry test files from Medusa reproducers """ @@ -30,7 +30,7 @@ def __init__( self.target_name = target_name self.corpus_path = corpus_path self.slither = slither - self.target = self.get_target_contract() + self.target = get_target_contract(slither, target_name) self.reproducer_dir = f"{corpus_path}/test_results" self.corpus_dirs = [ f"{corpus_path}/call_sequences/immutable", @@ -40,16 +40,6 @@ def __init__( self.named_inputs = named_inputs self.declared_variables: set[tuple[str, str]] = set() - def get_target_contract(self) -> Contract: - """Finds and returns Slither Contract""" - contracts = self.slither.get_contract_from_name(self.target_name) - # Loop in case slither fetches multiple contracts for some reason (e.g., similar names?) - for contract in contracts: - if contract.name == self.target_name: - return contract - - handle_exit(f"\n* Slither could not find the specified contract `{self.target_name}`.") - def parse_reproducer(self, file_path: str, calls: Any, index: int) -> str: """ Takes a list of call dicts and returns a Foundry unit test string containing the call sequence. diff --git a/fuzz_utils/parsing/commands/generate.py b/fuzz_utils/parsing/commands/generate.py index e1ef8b8..4dfa967 100644 --- a/fuzz_utils/parsing/commands/generate.py +++ b/fuzz_utils/parsing/commands/generate.py @@ -1,5 +1,5 @@ """Defines the flags and logic associated with the `generate` command""" -import json +from pathlib import Path from argparse import Namespace, ArgumentParser from slither import Slither from fuzz_utils.utils.crytic_print import CryticPrint @@ -7,6 +7,10 @@ from fuzz_utils.generate.fuzzers.Medusa import Medusa from fuzz_utils.generate.fuzzers.Echidna import Echidna from fuzz_utils.utils.error_handler import handle_exit +from fuzz_utils.parsing.parser_util import check_config_and_set_default_values, open_config +from fuzz_utils.utils.slither_utils import get_target_contract + +COMMAND: str = "generate" def generate_flags(parser: ArgumentParser) -> None: @@ -57,15 +61,13 @@ def generate_flags(parser: ArgumentParser) -> None: ) +# pylint: disable=too-many-branches def generate_command(args: Namespace) -> None: """The execution logic of the `generate` command""" config: dict = {} # If the config file is defined, read it if args.config: - with open(args.config, "r", encoding="utf-8") as readFile: - complete_config = json.load(readFile) - if "generate" in complete_config: - config = complete_config["generate"] + config = open_config(args.config, COMMAND) # Override the config with the CLI values if args.compilation_path: config["compilationPath"] = args.compilation_path @@ -90,10 +92,37 @@ def generate_command(args: Namespace) -> None: if "allSequences" not in config: config["allSequences"] = False + check_config_and_set_default_values( + config, + ["compilationPath", "testsDir", "fuzzer", "corpusDir"], + [".", "test", "medusa", "corpus"], + ) + CryticPrint().print_information("Running Slither...") slither = Slither(args.compilation_path) fuzzer: Echidna | Medusa + # Derive target if it is not defined but the compilationPath only contains one contract + if "targetContract" not in config or len(config["targetContract"]) == 0: + if len(slither.contracts_derived) == 1: + config["targetContract"] = slither.contracts_derived[0].name + CryticPrint().print_information( + f"Target contract not specified. Using derived target: {config['targetContract']}." + ) + else: + handle_exit( + "Target contract cannot be determined. Please specify the target with `-c targetName`" + ) + + # Derive inheritance path if it is not defined + if "inheritancePath" not in config or len(config["inheritancePath"]) == 0: + contract = get_target_contract(slither, config["targetContract"]) + contract_path = Path(contract.source_mapping.filename.relative) + tests_path = Path(config["testsDir"]) + config["inheritancePath"] = str( + Path(*([".." * len(tests_path.parts)])).joinpath(contract_path) + ) + match config["fuzzer"]: case "echidna": fuzzer = Echidna( diff --git a/fuzz_utils/parsing/commands/template.py b/fuzz_utils/parsing/commands/template.py index 3de8a3d..46dd165 100644 --- a/fuzz_utils/parsing/commands/template.py +++ b/fuzz_utils/parsing/commands/template.py @@ -1,12 +1,17 @@ """Defines the flags and logic associated with the `template` command""" import os -import json from argparse import Namespace, ArgumentParser from slither import Slither from fuzz_utils.template.HarnessGenerator import HarnessGenerator from fuzz_utils.utils.crytic_print import CryticPrint from fuzz_utils.utils.remappings import find_remappings from fuzz_utils.utils.error_handler import handle_exit +from fuzz_utils.parsing.parser_util import ( + check_configuration_field_exists_and_non_empty, + open_config, +) + +COMMAND: str = "template" def template_flags(parser: ArgumentParser) -> None: @@ -42,10 +47,7 @@ def template_command(args: Namespace) -> None: else: output_dir = os.path.join("./test", "fuzzing") if args.config: - with open(args.config, "r", encoding="utf-8") as readFile: - complete_config = json.load(readFile) - if "template" in complete_config: - config = complete_config["template"] + config = open_config(args.config, COMMAND) if args.target_contracts: config["targets"] = args.target_contracts @@ -72,15 +74,9 @@ def check_configuration(config: dict) -> None: """Checks the configuration""" mandatory_configuration_fields = ["mode", "targets", "compilationPath"] for field in mandatory_configuration_fields: - check_configuration_field_exists_and_non_empty(config, field) + check_configuration_field_exists_and_non_empty(config, COMMAND, field) if config["mode"].lower() not in ("simple", "prank", "actor"): handle_exit( f"The selected mode {config['mode']} is not a valid harness generation strategy." ) - - -def check_configuration_field_exists_and_non_empty(config: dict, field: str) -> None: - """Checks that the configuration dictionary contains a non-empty field""" - if field not in config or len(config[field]) == 0: - handle_exit(f"The template configuration field {field} is not configured.") diff --git a/fuzz_utils/parsing/parser_util.py b/fuzz_utils/parsing/parser_util.py new file mode 100644 index 0000000..992aa35 --- /dev/null +++ b/fuzz_utils/parsing/parser_util.py @@ -0,0 +1,31 @@ +"""Utility functions used in the command parsers""" +import json +from fuzz_utils.utils.error_handler import handle_exit + + +def check_config_and_set_default_values( + config: dict, fields: list[str], defaults: list[str] +) -> None: + """Checks that the configuration dictionary contains a non-empty field""" + assert len(fields) == len(defaults) + for idx, field in enumerate(fields): + if field not in config or len(config[field]) == 0: + config[field] = defaults[idx] + + +def check_configuration_field_exists_and_non_empty(config: dict, command: str, field: str) -> None: + """Checks that the configuration dictionary contains a non-empty field""" + if field not in config or len(config[field]) == 0: + handle_exit(f"The {command} configuration field {field} is not configured.") + + +def open_config(cli_config: str, command: str) -> dict: + """Open config file if provided return its contents""" + with open(cli_config, "r", encoding="utf-8") as readFile: + complete_config = json.load(readFile) + if command in complete_config: + return complete_config[command] + + handle_exit( + f"The provided configuration file does not contain the `{command}` command configuration field." + ) diff --git a/fuzz_utils/template/HarnessGenerator.py b/fuzz_utils/template/HarnessGenerator.py index 69e69c5..3c85909 100644 --- a/fuzz_utils/template/HarnessGenerator.py +++ b/fuzz_utils/template/HarnessGenerator.py @@ -13,6 +13,7 @@ from fuzz_utils.utils.crytic_print import CryticPrint from fuzz_utils.utils.file_manager import check_and_create_dirs, save_file from fuzz_utils.utils.error_handler import handle_exit +from fuzz_utils.utils.slither_utils import get_target_contract from fuzz_utils.templates.harness_templates import templates # pylint: disable=too-many-instance-attributes @@ -69,6 +70,7 @@ def set_path(self, path: str) -> None: self.path = path +# pylint: disable=too-few-public-methods class HarnessGenerator: """ Handles the generation of Foundry test files from Echidna reproducers @@ -140,7 +142,7 @@ def __init__( self.slither = slither self.targets = [ - self.get_target_contract(slither, contract) for contract in self.config["targets"] + get_target_contract(slither, contract) for contract in self.config["targets"] ] self.output_dir = self.config["outputDir"] @@ -321,7 +323,7 @@ def _generate_attacks(self) -> list[Actor]: attack.set_path(path) attack_slither = Slither(f"{attack_output_path}/Attack{name}.sol") - attack.set_contract(self.get_target_contract(attack_slither, f"{name}Attack")) + attack.set_contract(get_target_contract(attack_slither, f"{name}Attack")) attacks.append(attack) else: @@ -392,8 +394,7 @@ def _generate_actors(self) -> list[Actor]: for actor_config in self.config["actors"]: name = actor_config["name"] target_contracts: list[Contract] = [ - self.get_target_contract(self.slither, contract) - for contract in actor_config["targets"] + get_target_contract(self.slither, contract) for contract in actor_config["targets"] ] CryticPrint().print_information(f" Actor: {name}Actor...") @@ -409,7 +410,7 @@ def _generate_actors(self) -> list[Actor]: actor.set_path(path) actor_slither = Slither(f"{actor_output_path}/Actor{name}.sol") - actor.set_contract(self.get_target_contract(actor_slither, f"Actor{name}")) + actor.set_contract(get_target_contract(actor_slither, f"Actor{name}")) actor_contracts.append(actor) @@ -519,17 +520,6 @@ def _render_template( return content, f"../{directory_name}/{file_name}.sol" - # pylint: disable=no-self-use - def get_target_contract(self, slither: Slither, target_name: str) -> Contract: - """Finds and returns Slither Contract""" - contracts = slither.get_contract_from_name(target_name) - # Loop in case slither fetches multiple contracts for some reason (e.g., similar names?) - for contract in contracts: - if contract.name == target_name: - return contract - - handle_exit(f"\n* Slither could not find the specified contract `{target_name}`.") - # Utility functions def should_skip_contract_functions(contract: Contract) -> bool: diff --git a/fuzz_utils/utils/slither_utils.py b/fuzz_utils/utils/slither_utils.py new file mode 100644 index 0000000..dd402d2 --- /dev/null +++ b/fuzz_utils/utils/slither_utils.py @@ -0,0 +1,15 @@ +"""Common utilities for Slither""" +from slither import Slither +from slither.core.declarations.contract import Contract +from fuzz_utils.utils.error_handler import handle_exit + + +def get_target_contract(slither: Slither, target_name: str) -> Contract: + """Gets the Slither Contract object for the specified contract file""" + contracts = slither.get_contract_from_name(target_name) + # Loop in case slither fetches multiple contracts for some reason (e.g., similar names?) + for contract in contracts: + if contract.name == target_name: + return contract + + handle_exit(f"\n* Slither could not find the specified contract `{target_name}`.") diff --git a/tests/test_harness.py b/tests/test_harness.py index 9b8876e..be05cee 100644 --- a/tests/test_harness.py +++ b/tests/test_harness.py @@ -7,6 +7,7 @@ from slither.core.declarations.contract import Contract from slither.core.declarations.function_contract import FunctionContract from fuzz_utils.utils.remappings import find_remappings +from fuzz_utils.utils.slither_utils import get_target_contract from fuzz_utils.template.HarnessGenerator import HarnessGenerator @@ -208,7 +209,7 @@ def run_harness( # Ensure the harness only contains the functions we're expecting slither = Slither(f"./test/fuzzing/harnesses/{harness_name}.sol") - target: Contract = generator.get_target_contract(slither, harness_name) + target: Contract = get_target_contract(slither, harness_name) compare_with_declared_functions(target, set(expected_functions)) From 4c5077c6640972075a3e90c288d3da93672beab9 Mon Sep 17 00:00:00 2001 From: tuturu-tech Date: Tue, 2 Apr 2024 13:08:33 +0200 Subject: [PATCH 2/4] fix inheritance path derivation, update readme --- README.md | 20 ++++++++++---------- fuzz_utils/generate/FoundryTest.py | 4 ++-- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index d0064dd..ac949e6 100644 --- a/README.md +++ b/README.md @@ -40,21 +40,21 @@ The available tool commands are: The `generate` command is used to generate Foundry unit tests from Echidna or Medusa corpus call sequences. **Command-line options:** -- `compilation_path`: The path to the Solidity file or Foundry directory -- `-cd`/`--corpus-dir` `path_to_corpus_dir`: The path to the corpus directory relative to the working directory. -- `-c`/`--contract` `contract_name`: The name of the target contract. -- `-td`/`--test-directory` `path_to_test_directory`: The path to the test directory relative to the working directory. -- `-i`/`--inheritance-path` `relative_path_to_contract`: The relative path from the test directory to the contract (used for inheritance). -- `-f`/`--fuzzer` `fuzzer_name`: The name of the fuzzer, currently supported: `echidna` and `medusa` -- `--named-inputs`: Includes function input names when making calls -- `--config`: Path to the fuzz-utils config JSON file -- `--all-sequences`: Include all corpus sequences when generating unit tests. +- `compilation_path`: The path to the Solidity file or Foundry directory. By default `.` +- `-cd`/`--corpus-dir` `path_to_corpus_dir`: The path to the corpus directory relative to the working directory. By default `corpus` +- `-c`/`--contract` `contract_name`: The name of the target contract. If the compilation path only contains one contract the target will be automatically derived. +- `-td`/`--test-directory` `path_to_test_directory`: The path to the test directory relative to the working directory. By default `test` +- `-i`/`--inheritance-path` `relative_path_to_contract`: The relative path from the test directory to the contract (used for overriding inheritance). If this configuration option is not provided the inheritance path will be automatically derived. +- `-f`/`--fuzzer` `fuzzer_name`: The name of the fuzzer, currently supported: `echidna` and `medusa`. By default `medusa` +- `--named-inputs`: Includes function input names when making calls. By default`false` +- `--config`: Path to the fuzz-utils config JSON file. Empty by default. +- `--all-sequences`: Include all corpus sequences when generating unit tests. By default `false` **Example** In order to generate a test file for the [BasicTypes.sol](tests/test_data/src/BasicTypes.sol) contract, based on the Echidna corpus reproducers for this contract ([corpus-basic](tests/test_data/echidna-corpora/corpus-basic/)), we need to `cd` into the `tests/test_data` directory which contains the Foundry project and run the command: ```bash -fuzz-utils generate ./src/BasicTypes.sol --corpus-dir echidna-corpora/corpus-basic --contract "BasicTypes" --test-directory "./test/" --inheritance-path "../src/" --fuzzer echidna +fuzz-utils generate ./src/BasicTypes.sol --corpus-dir echidna-corpora/corpus-basic --contract "BasicTypes" --fuzzer echidna ``` Running this command should generate a `BasicTypes_Echidna_Test.sol` file in the [test](/tests/test_data/test/) directory of the Foundry project. diff --git a/fuzz_utils/generate/FoundryTest.py b/fuzz_utils/generate/FoundryTest.py index d9b05c9..020d15e 100644 --- a/fuzz_utils/generate/FoundryTest.py +++ b/fuzz_utils/generate/FoundryTest.py @@ -69,8 +69,8 @@ def create_poc(self) -> str: # 4. Generate the test file template = jinja2.Template(templates["CONTRACT"]) write_path = os.path.join(self.test_dir, self.target_name) - inheritance_path = os.path.join(self.inheritance_path, self.target_file_name) - + inheritance_path = os.path.join(self.inheritance_path) + print("INHERITANCE PATH", inheritance_path) # 5. Save the test file test_file_str = template.render( file_path=inheritance_path, From 61f1fe801bbe2e5945b224a13f2aebd0e030f0cd Mon Sep 17 00:00:00 2001 From: tuturu-tech Date: Tue, 2 Apr 2024 13:11:10 +0200 Subject: [PATCH 3/4] update default config and readme --- README.md | 6 +++--- fuzz_utils/templates/default_config.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index ac949e6..cf8979f 100644 --- a/README.md +++ b/README.md @@ -65,9 +65,9 @@ The `template` command is used to generate a fuzzing harness. The harness can in **Command-line options:** - `compilation_path`: The path to the Solidity file or Foundry directory -- `-n`/`--name` `name: str`: The name of the fuzzing harness. -- `-c`/`--contracts` `target_contracts: list`: The name of the target contract. -- `-o`/`--output-dir` `output_directory: str`: Output directory name. By default it is `fuzzing` +- `-n`/`--name` `name: str`: The name of the fuzzing harness. By default `DefaultHarness` +- `-c`/`--contracts` `target_contracts: list`: The name of the target contract. Empty by default. +- `-o`/`--output-dir` `output_directory: str`: Output directory name. By default `fuzzing` - `--config`: Path to the `fuzz-utils` config JSON file - `--mode`: The strategy to use when generating the harnesses. Valid options: `simple`, `prank`, `actor` diff --git a/fuzz_utils/templates/default_config.py b/fuzz_utils/templates/default_config.py index db41bb4..779a559 100644 --- a/fuzz_utils/templates/default_config.py +++ b/fuzz_utils/templates/default_config.py @@ -3,9 +3,9 @@ "generate": { "targetContract": "", "compilationPath": ".", - "corpusDir": "", - "fuzzer": "", - "testsDir": "", + "corpusDir": "corpus", + "fuzzer": "medusa", + "testsDir": "test", "inheritancePath": "", "namedInputs": False, "allSequences": False, From 3bb314176d7e0522c80de5d438c618001ac3ceb2 Mon Sep 17 00:00:00 2001 From: tuturu-tech Date: Tue, 2 Apr 2024 14:21:16 +0200 Subject: [PATCH 4/4] fix test setup, refactor default config for classes --- fuzz_utils/generate/FoundryTest.py | 24 +++++++------ fuzz_utils/parsing/commands/generate.py | 45 ++++++++++++++----------- fuzz_utils/template/HarnessGenerator.py | 22 ++---------- tests/conftest.py | 2 +- 4 files changed, 41 insertions(+), 52 deletions(-) diff --git a/fuzz_utils/generate/FoundryTest.py b/fuzz_utils/generate/FoundryTest.py index 020d15e..cfa8b77 100644 --- a/fuzz_utils/generate/FoundryTest.py +++ b/fuzz_utils/generate/FoundryTest.py @@ -1,12 +1,14 @@ """The FoundryTest class that handles generation of unit tests from call sequences""" import os import json +import copy from typing import Any import jinja2 from slither import Slither from fuzz_utils.utils.crytic_print import CryticPrint from fuzz_utils.utils.slither_utils import get_target_contract +from fuzz_utils.templates.default_config import default_config from fuzz_utils.generate.fuzzers.Medusa import Medusa from fuzz_utils.generate.fuzzers.Echidna import Echidna @@ -18,19 +20,20 @@ class FoundryTest: Handles the generation of Foundry test files """ + config: dict = copy.deepcopy(default_config["generate"]) + def __init__( self, config: dict, slither: Slither, fuzzer: Echidna | Medusa, ) -> None: - self.inheritance_path = config["inheritancePath"] - self.target_name = config["targetContract"] - self.corpus_path = config["corpusDir"] - self.test_dir = config["testsDir"] - self.all_sequences = config["allSequences"] self.slither = slither - self.target = get_target_contract(self.slither, self.target_name) + for key, value in config.items(): + if key in self.config: + self.config[key] = value + + self.target = get_target_contract(self.slither, self.config["targetContract"]) self.target_file_name = self.target.source_mapping.filename.relative.split("/")[-1] self.fuzzer = fuzzer @@ -40,7 +43,7 @@ def create_poc(self) -> str: file_list: list[dict[str, Any]] = [] tests_list = [] dir_list = [] - if self.all_sequences: + if self.config["allSequences"]: dir_list = self.fuzzer.corpus_dirs else: dir_list = [self.fuzzer.reproducer_dir] @@ -68,13 +71,12 @@ def create_poc(self) -> str: # 4. Generate the test file template = jinja2.Template(templates["CONTRACT"]) - write_path = os.path.join(self.test_dir, self.target_name) - inheritance_path = os.path.join(self.inheritance_path) - print("INHERITANCE PATH", inheritance_path) + write_path = os.path.join(self.config["testsDir"], self.config["targetContract"]) + inheritance_path = os.path.join(self.config["inheritancePath"]) # 5. Save the test file test_file_str = template.render( file_path=inheritance_path, - target_name=self.target_name, + target_name=self.config["targetContract"], amount=0, tests=tests_list, fuzzer=self.fuzzer.name, diff --git a/fuzz_utils/parsing/commands/generate.py b/fuzz_utils/parsing/commands/generate.py index 4dfa967..29141b7 100644 --- a/fuzz_utils/parsing/commands/generate.py +++ b/fuzz_utils/parsing/commands/generate.py @@ -102,26 +102,7 @@ def generate_command(args: Namespace) -> None: slither = Slither(args.compilation_path) fuzzer: Echidna | Medusa - # Derive target if it is not defined but the compilationPath only contains one contract - if "targetContract" not in config or len(config["targetContract"]) == 0: - if len(slither.contracts_derived) == 1: - config["targetContract"] = slither.contracts_derived[0].name - CryticPrint().print_information( - f"Target contract not specified. Using derived target: {config['targetContract']}." - ) - else: - handle_exit( - "Target contract cannot be determined. Please specify the target with `-c targetName`" - ) - - # Derive inheritance path if it is not defined - if "inheritancePath" not in config or len(config["inheritancePath"]) == 0: - contract = get_target_contract(slither, config["targetContract"]) - contract_path = Path(contract.source_mapping.filename.relative) - tests_path = Path(config["testsDir"]) - config["inheritancePath"] = str( - Path(*([".." * len(tests_path.parts)])).joinpath(contract_path) - ) + derive_config(slither, config) match config["fuzzer"]: case "echidna": @@ -143,3 +124,27 @@ def generate_command(args: Namespace) -> None: foundry_test = FoundryTest(config, slither, fuzzer) foundry_test.create_poc() CryticPrint().print_success("Done!") + + +def derive_config(slither: Slither, config: dict) -> None: + """Derive values for the target contract and inheritance path""" + # Derive target if it is not defined but the compilationPath only contains one contract + if "targetContract" not in config or len(config["targetContract"]) == 0: + if len(slither.contracts_derived) == 1: + config["targetContract"] = slither.contracts_derived[0].name + CryticPrint().print_information( + f"Target contract not specified. Using derived target: {config['targetContract']}." + ) + else: + handle_exit( + "Target contract cannot be determined. Please specify the target with `-c targetName`" + ) + + # Derive inheritance path if it is not defined + if "inheritancePath" not in config or len(config["inheritancePath"]) == 0: + contract = get_target_contract(slither, config["targetContract"]) + contract_path = Path(contract.source_mapping.filename.relative) + tests_path = Path(config["testsDir"]) + config["inheritancePath"] = str( + Path(*([".." * len(tests_path.parts)])).joinpath(contract_path) + ) diff --git a/fuzz_utils/template/HarnessGenerator.py b/fuzz_utils/template/HarnessGenerator.py index 3c85909..5eef4c3 100644 --- a/fuzz_utils/template/HarnessGenerator.py +++ b/fuzz_utils/template/HarnessGenerator.py @@ -15,6 +15,7 @@ from fuzz_utils.utils.error_handler import handle_exit from fuzz_utils.utils.slither_utils import get_target_contract from fuzz_utils.templates.harness_templates import templates +from fuzz_utils.templates.default_config import default_config # pylint: disable=too-many-instance-attributes @dataclass @@ -76,26 +77,7 @@ class HarnessGenerator: Handles the generation of Foundry test files from Echidna reproducers """ - config: dict = { - "name": "DefaultHarness", - "compilationPath": ".", - "targets": [], - "outputDir": "./test/fuzzing", - "actors": [ - { - "name": "Default", - "targets": [], - "number": 3, - "filters": { - "strict": False, - "onlyModifiers": [], - "onlyPayable": False, - "onlyExternalCalls": [], - }, - } - ], - "attacks": [], - } + config: dict = copy.deepcopy(default_config["template"]) def __init__( self, diff --git a/tests/conftest.py b/tests/conftest.py index dbc9064..5afa030 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,7 +23,7 @@ def __init__(self, target: str, target_path: str, corpus_dir: str): medusa = Medusa(target, f"medusa-corpora/{corpus_dir}", slither, False) config = { "targetContract": target, - "inheritancePath": "../src/", + "inheritancePath": f"../src/{target}.sol", "corpusDir": f"echidna-corpora/{corpus_dir}", "testsDir": "./test/", "allSequences": False,