Skip to content

Commit

Permalink
fix test setup, refactor default config for classes
Browse files Browse the repository at this point in the history
  • Loading branch information
tuturu-tech committed Apr 2, 2024
1 parent 61f1fe8 commit 3bb3141
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 52 deletions.
24 changes: 13 additions & 11 deletions fuzz_utils/generate/FoundryTest.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""The FoundryTest class that handles generation of unit tests from call sequences"""
import os
import json
import copy
from typing import Any
import jinja2

from slither import Slither
from fuzz_utils.utils.crytic_print import CryticPrint
from fuzz_utils.utils.slither_utils import get_target_contract
from fuzz_utils.templates.default_config import default_config

from fuzz_utils.generate.fuzzers.Medusa import Medusa
from fuzz_utils.generate.fuzzers.Echidna import Echidna
Expand All @@ -18,19 +20,20 @@ class FoundryTest:
Handles the generation of Foundry test files
"""

config: dict = copy.deepcopy(default_config["generate"])

def __init__(
self,
config: dict,
slither: Slither,
fuzzer: Echidna | Medusa,
) -> None:
self.inheritance_path = config["inheritancePath"]
self.target_name = config["targetContract"]
self.corpus_path = config["corpusDir"]
self.test_dir = config["testsDir"]
self.all_sequences = config["allSequences"]
self.slither = slither
self.target = get_target_contract(self.slither, self.target_name)
for key, value in config.items():
if key in self.config:
self.config[key] = value

self.target = get_target_contract(self.slither, self.config["targetContract"])
self.target_file_name = self.target.source_mapping.filename.relative.split("/")[-1]
self.fuzzer = fuzzer

Expand All @@ -40,7 +43,7 @@ def create_poc(self) -> str:
file_list: list[dict[str, Any]] = []
tests_list = []
dir_list = []
if self.all_sequences:
if self.config["allSequences"]:
dir_list = self.fuzzer.corpus_dirs
else:
dir_list = [self.fuzzer.reproducer_dir]
Expand Down Expand Up @@ -68,13 +71,12 @@ def create_poc(self) -> str:

# 4. Generate the test file
template = jinja2.Template(templates["CONTRACT"])
write_path = os.path.join(self.test_dir, self.target_name)
inheritance_path = os.path.join(self.inheritance_path)
print("INHERITANCE PATH", inheritance_path)
write_path = os.path.join(self.config["testsDir"], self.config["targetContract"])
inheritance_path = os.path.join(self.config["inheritancePath"])
# 5. Save the test file
test_file_str = template.render(
file_path=inheritance_path,
target_name=self.target_name,
target_name=self.config["targetContract"],
amount=0,
tests=tests_list,
fuzzer=self.fuzzer.name,
Expand Down
45 changes: 25 additions & 20 deletions fuzz_utils/parsing/commands/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,26 +102,7 @@ def generate_command(args: Namespace) -> None:
slither = Slither(args.compilation_path)
fuzzer: Echidna | Medusa

# Derive target if it is not defined but the compilationPath only contains one contract
if "targetContract" not in config or len(config["targetContract"]) == 0:
if len(slither.contracts_derived) == 1:
config["targetContract"] = slither.contracts_derived[0].name
CryticPrint().print_information(
f"Target contract not specified. Using derived target: {config['targetContract']}."
)
else:
handle_exit(
"Target contract cannot be determined. Please specify the target with `-c targetName`"
)

# Derive inheritance path if it is not defined
if "inheritancePath" not in config or len(config["inheritancePath"]) == 0:
contract = get_target_contract(slither, config["targetContract"])
contract_path = Path(contract.source_mapping.filename.relative)
tests_path = Path(config["testsDir"])
config["inheritancePath"] = str(
Path(*([".." * len(tests_path.parts)])).joinpath(contract_path)
)
derive_config(slither, config)

match config["fuzzer"]:
case "echidna":
Expand All @@ -143,3 +124,27 @@ def generate_command(args: Namespace) -> None:
foundry_test = FoundryTest(config, slither, fuzzer)
foundry_test.create_poc()
CryticPrint().print_success("Done!")


def derive_config(slither: Slither, config: dict) -> None:
"""Derive values for the target contract and inheritance path"""
# Derive target if it is not defined but the compilationPath only contains one contract
if "targetContract" not in config or len(config["targetContract"]) == 0:
if len(slither.contracts_derived) == 1:
config["targetContract"] = slither.contracts_derived[0].name
CryticPrint().print_information(
f"Target contract not specified. Using derived target: {config['targetContract']}."
)
else:
handle_exit(
"Target contract cannot be determined. Please specify the target with `-c targetName`"
)

# Derive inheritance path if it is not defined
if "inheritancePath" not in config or len(config["inheritancePath"]) == 0:
contract = get_target_contract(slither, config["targetContract"])
contract_path = Path(contract.source_mapping.filename.relative)
tests_path = Path(config["testsDir"])
config["inheritancePath"] = str(
Path(*([".." * len(tests_path.parts)])).joinpath(contract_path)
)
22 changes: 2 additions & 20 deletions fuzz_utils/template/HarnessGenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from fuzz_utils.utils.error_handler import handle_exit
from fuzz_utils.utils.slither_utils import get_target_contract
from fuzz_utils.templates.harness_templates import templates
from fuzz_utils.templates.default_config import default_config

# pylint: disable=too-many-instance-attributes
@dataclass
Expand Down Expand Up @@ -76,26 +77,7 @@ class HarnessGenerator:
Handles the generation of Foundry test files from Echidna reproducers
"""

config: dict = {
"name": "DefaultHarness",
"compilationPath": ".",
"targets": [],
"outputDir": "./test/fuzzing",
"actors": [
{
"name": "Default",
"targets": [],
"number": 3,
"filters": {
"strict": False,
"onlyModifiers": [],
"onlyPayable": False,
"onlyExternalCalls": [],
},
}
],
"attacks": [],
}
config: dict = copy.deepcopy(default_config["template"])

def __init__(
self,
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(self, target: str, target_path: str, corpus_dir: str):
medusa = Medusa(target, f"medusa-corpora/{corpus_dir}", slither, False)
config = {
"targetContract": target,
"inheritancePath": "../src/",
"inheritancePath": f"../src/{target}.sol",
"corpusDir": f"echidna-corpora/{corpus_dir}",
"testsDir": "./test/",
"allSequences": False,
Expand Down

0 comments on commit 3bb3141

Please sign in to comment.