From fc0ab9f9648adcb8651f702c6fa7a664be1bb327 Mon Sep 17 00:00:00 2001
From: jnsbck-uni
Date: Wed, 6 Nov 2024 13:44:38 +0100
Subject: [PATCH] enh: better regression tests. works locally, need to update baseline

---
 .github/workflows/regression_tests.yml        |  36 ++--
 .../workflows/update_regression_baseline.yml  |  17 +-
 tests/regression_test_runner.py               |  95 ---------
 tests/test_regression.py                      | 196 ++++++++++++++----
 4 files changed, 185 insertions(+), 159 deletions(-)
 delete mode 100644 tests/regression_test_runner.py

diff --git a/.github/workflows/regression_tests.yml b/.github/workflows/regression_tests.yml
index 317dc3e9..69dfb87b 100644
--- a/.github/workflows/regression_tests.yml
+++ b/.github/workflows/regression_tests.yml
@@ -31,26 +31,26 @@ jobs:
         if: github.event.pull_request.base.ref == 'main'
         run: |
           # Check if regression test results exist in main branch
-          if [ -f 'git cat-file -e main:tests/regression_test_results.json' ]; then
-            git checkout main tests/regression_test_results.json
+          if git cat-file -e main:tests/regression_test_baselines.json; then
+            git checkout main tests/regression_test_baselines.json
           else
             echo "No regression test results found in main branch"
           fi
-          python tests/regression_test_runner.py > regression_test_report.txt
-          git checkout .
+          PYTHONHASHSEED=0 pytest tests/test_regression.py
+          git checkout -
 
-      - name: Comment PR
-        if: github.event.pull_request.base.ref == 'main'
-        uses: actions/github-script@v7
-        with:
-          github-token: ${{ secrets.GITHUB_TOKEN }}
-          script: |
-            const fs = require('fs');
-            const TestReport = fs.readFileSync('regression_test_report.txt', 'utf8');
+      # - name: Comment PR
+      #   if: github.event.pull_request.base.ref == 'main'
+      #   uses: actions/github-script@v7
+      #   with:
+      #     github-token: ${{ secrets.GITHUB_TOKEN }}
+      #     script: |
+      #       const fs = require('fs');
+      #       const TestReport = fs.readFileSync('regression_test_report.txt', 'utf8');
 
-            await github.rest.issues.createComment({
-              issue_number: context.issue.number,
-              owner: context.repo.owner,
-              repo: context.repo.repo,
-              body: `## Regression Test Results\n\`\`\`\n${TestReport}\n\`\`\``
-            });
\ No newline at end of file
+      #       await github.rest.issues.createComment({
+      #         issue_number: context.issue.number,
+      #         owner: context.repo.owner,
+      #         repo: context.repo.repo,
+      #         body: `## Regression Test Results\n\`\`\`\n${TestReport}\n\`\`\``
+      #       });
\ No newline at end of file
diff --git a/.github/workflows/update_regression_baseline.yml b/.github/workflows/update_regression_baseline.yml
index 122f6d40..afe5202d 100644
--- a/.github/workflows/update_regression_baseline.yml
+++ b/.github/workflows/update_regression_baseline.yml
@@ -2,7 +2,9 @@
 name: Regression Tests
 
 on:
-  workflow_dispatch:
+  pull_request:
+    branches:
+      - main
 
 jobs:
   regression_tests:
@@ -24,13 +26,8 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
           pip install -e ".[dev]"
-
-      - name: Run benchmarks on PR branch
+
+      - name: Update baseline
+        if: github.event.pull_request.base.ref == 'main'
         run: |
-          python tests/regression_test_runner.py
-
-      - name: Save new regression baseline
-        uses: actions/upload-artifact@v3
-        with:
-          name: regression-test-results
-          path: tests/regression_test_results.json
\ No newline at end of file
+          PYTHONHASHSEED=0 UPDATE_BASELINE=1 pytest tests/test_regression.py
\ No newline at end of file
diff --git a/tests/regression_test_runner.py b/tests/regression_test_runner.py
deleted file mode 100644
index 19c792ed..00000000
--- a/tests/regression_test_runner.py
+++ /dev/null
@@ -1,95 +0,0 @@
-# tests/benchmark_runner.py
-import importlib.util
-import json
-import timeit
-from pathlib import Path
-
-from pytest import MarkDecorator
-
-
-def load_module(file_path):
-    """Dynamically load a Python file as a module."""
-    spec = importlib.util.spec_from_file_location("module", file_path)
-    module = importlib.util.module_from_spec(spec)
-    spec.loader.exec_module(module)
-    return module
-
-
-def get_test_functions(module):
-    """Get all functions in the module that start with test."""
-
-    def is_test(value):
-        if not isinstance(value, MarkDecorator):
-            return callable(value) and value.__name__.startswith("test")
-        return False
-
-    return [value for value in vars(module).values() if is_test(value)]
-
-
-def run_benchmarks(functions_to_test):
-    """Run benchmarks for the specified functions."""
-    results = {}
-    for func in functions_to_test:
-        # Run each function 1000 times and take the average
-        time_taken = timeit.timeit(lambda: func(), number=1)
-        # Use the function's qualified name (module.function) as the key
-        results[f"{func.__module__}.{func.__name__}"] = time_taken
-    return results
-
-
-def save_results(results, output_file):
-    """Save benchmark results to JSON file."""
-    with open(output_file, "w") as f:
-        json.dump(results, f, indent=2)
-
-
-def compare_results(base_results, new_results):
-    """Compare two sets of benchmark results and generate a diff report."""
-    report = []
-    for func_name in new_results:
-        new_time = new_results[func_name]
-        base_time = base_results.get(func_name)
-
-        if base_time is None:
-            report.append(f"🆕 {func_name}: {new_time:.6f}s (new function)")
-            continue
-
-        diff_pct = ((new_time - base_time) / base_time) * 100
-        if diff_pct > 0:
-            emoji = "🔴"
-        elif diff_pct < 0:
-            emoji = "🟢"
-        else:
-            emoji = "⚪"
-
-        report.append(
-            f"{emoji} {func_name}: {new_time:.6f}s "
-            f"({diff_pct:+.1f}% vs {base_time:.6f}s)"
-        )
-
-    return "\n".join(report)
-
-
-if __name__ == "__main__":
-    regression_test_file = "tests/test_regression.py"
-    output_file = "tests/regression_test_results.json"
-    base_results_file = (
-        "tests/regression_test_results.json"
-        if Path("tests/regression_test_results.json").exists()
-        else None
-    )
-
-    # Load the regression test module
-    test_module = load_module(regression_test_file)
-    test_functions = get_test_functions(test_module)
-    results = run_benchmarks(test_functions)
-
-    if base_results_file:
-        with open(base_results_file) as f:
-            base_results = json.load(f)
-    else:
-        base_results = {}
-    print(compare_results(base_results, results))
-
-    # save new results
-    # save_results(results, output_file)
diff --git a/tests/test_regression.py b/tests/test_regression.py
index e5347fa9..0f0ae2ef 100644
--- a/tests/test_regression.py
+++ b/tests/test_regression.py
@@ -12,9 +12,134 @@
 from jaxley.channels import HH
 from jaxley.connect import sparse_connect
 from jaxley.synapses import IonotropicSynapse
+from functools import wraps
+import json
 
-# mark all tests as runtime tests in this file
-pytestmark = pytest.mark.runtime
+
+
+# Due to the use of hash functions, `test_regression.py` has to be run with the
+# environment variable PYTHONHASHSEED=0, i.e. `PYTHONHASHSEED=0 pytest tests/test_regression.py`.
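+# (Without a fixed seed, Python salts the `hash()` of strings per process, so the keys that
+# `generate_unique_key` derives from each test's name and kwargs would not match the stored baselines.)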
+# For details see https://stackoverflow.com/questions/30585108/disable-hash-randomization-from-within-python-program
+
+
+# Every runtime test needs to have the following structure:
+#
+# @compare_to_baseline()
+# def test_runtime_of_x(**kwargs) -> Dict:
+#     t1 = time.time()
+#     time.sleep(0.1)
+#     # do something
+#     t2 = time.time()
+#     # do something else
+#     t3 = time.time()
+#     return {"dt1": t2-t1, "dt2": t3-t2}
+
+pytestmark = pytest.mark.runtime  # mark all tests as runtime tests in this file
+UPDATE_BASELINE = (
+    os.environ["UPDATE_BASELINE"] if "UPDATE_BASELINE" in os.environ else 0
+)
+tolerance = 0.2
+
+BASELINES = {}
+if os.path.exists("tests/regression_test_baselines.json"):
+    with open("tests/regression_test_baselines.json") as f:
+        BASELINES = json.load(f)
+
+
+def generate_unique_key(d):
+    return str(hash(json.dumps(d, sort_keys=True)))
+
+
+def append_to_json(fpath, test_name, input_kwargs, runtimes):
+    header = {"test_name": test_name, "input_kwargs": input_kwargs}
+    data = {generate_unique_key(header): {**header, "runtimes": runtimes}}
+
+    # Save data to a JSON file
+    if os.path.exists(fpath):
+        with open(fpath, "r") as f:
+            result_data = json.load(f)
+        result_data.update(data)
+    else:
+        result_data = data
+
+    with open(fpath, "w") as f:
+        json.dump(result_data, f, indent=2)
+
+
+class compare_to_baseline:
+    def __init__(self, baseline_iters=1, test_iters=1):
+        self.baseline_iters = baseline_iters
+        self.test_iters = test_iters
+
+    def __call__(self, func):
+        @wraps(func)  # ensures kwargs exposed to pytest
+        def test_wrapper(**kwargs):
+            header = {"test_name": func.__name__, "input_kwargs": kwargs}
+            key = generate_unique_key(header)
+
+            runs = []
+            num_iters = self.baseline_iters if UPDATE_BASELINE else self.test_iters
+            for _ in range(num_iters):
+                runtimes = func(**kwargs)
+                runs.append(runtimes)
+            runtimes = {k: np.mean([d[k] for d in runs]) for k in runs[0]}
+
+            fpath = (
+                "tests/regression_test_results.json"
+                if not UPDATE_BASELINE
+                else "tests/regression_test_baselines.json"
+            )
+            append_to_json(fpath, header["test_name"], header["input_kwargs"], runtimes)
+
+            if not UPDATE_BASELINE:
+                assert key in BASELINES, f"No baseline found for {header}"
+                func_baselines = BASELINES[key]["runtimes"]
+                for metric, baseline in func_baselines.items():
+                    diff = (
+                        float("nan")
+                        if np.isclose(baseline, 0)
+                        else (runtimes[metric] - baseline) / baseline
+                    )
+                    assert runtimes[metric] < baseline * (
+                        1 + tolerance
+                    ), f"{metric} is {diff:.2%} slower than the baseline."
+
+        return test_wrapper
+
+
+def generate_report(base_results, new_results):
+    """Compare two sets of benchmark results and generate a diff report."""
+    report = []
+    for key in new_results:
+        new_data = new_results[key]
+        base_data = base_results.get(key)
+        kwargs = ", ".join([f"{k}={v}" for k, v in new_data["input_kwargs"].items()])
+        func_name = new_data["test_name"]
+        func_signature = f"{func_name}({kwargs})"
+
+        new_runtimes = new_data["runtimes"]
+        base_runtimes = (
+            {k: None for k in new_runtimes}
+            if base_data is None
+            else base_data["runtimes"]
+        )
+
+        report.append(func_signature)
+        for metric, new_time in new_runtimes.items():
+            base_time = base_runtimes.get(metric)
+            diff = None if base_time is None else (new_time - base_time) / base_time
+
+            emoji = "🆕" if diff is None else "⚪"
+            emoji = "🔴" if diff is not None and diff > tolerance else emoji
+            emoji = "🟢" if diff is not None and diff < 0 else emoji
+
+            time_str = (
+                f"({new_time:.6f}s)"
+                if diff is None
+                else f"({diff:+.1%} vs {base_time:.6f}s)"
+            )
+            report.append(f"{emoji} {metric}: {time_str}.")
+        report.append("")
+
+    return "\n".join(report)
 
 
 def build_net(num_cells, artificial=True, connect=True, connection_prob=0.0):
@@ -52,24 +177,31 @@ def build_net(num_cells, artificial=True, connect=True, connection_prob=0.0):
     return net, params
 
 
-def setup_runtime(
+@pytest.mark.parametrize(
+    "num_cells, artificial, connect, connection_prob, voltage_solver",
+    (
+        # Test a single SWC cell with both solvers.
+        pytest.param(1, False, False, 0.0, "jaxley.stone"),
+        pytest.param(1, False, False, 0.0, "jax.sparse"),
+        # Test a network of SWC cells with both solvers.
+        pytest.param(10, False, True, 0.1, "jaxley.stone"),
+        pytest.param(10, False, True, 0.1, "jax.sparse"),
+        # Test a larger network of smaller neurons with both solvers.
+        pytest.param(1000, True, True, 0.001, "jaxley.stone"),
+        pytest.param(1000, True, True, 0.001, "jax.sparse"),
+    ),
+)
+@compare_to_baseline(baseline_iters=3)
+def test_runtime(
     num_cells: int,
     artificial: bool,
     connect: bool,
     connection_prob: float,
     voltage_solver: str,
-    identifier: int,
 ):
     delta_t = 0.025
     t_max = 100.0
 
-    net, params = build_net(
-        num_cells,
-        artificial=artificial,
-        connect=connect,
-        connection_prob=connection_prob,
-    )
-
     def simulate(params):
         return jx.integrate(
             net,
@@ -79,33 +211,25 @@ def simulate(params):
             voltage_solver=voltage_solver,
         )
 
+    runtimes = {}
+
+    start_time = time.time()
+    net, params = build_net(
+        num_cells,
+        artificial=artificial,
+        connect=connect,
+        connection_prob=connection_prob,
+    )
+    runtimes["build_time"] = time.time() - start_time
+
     jitted_simulate = jit(simulate)
+
+    start_time = time.time()
     _ = jitted_simulate(params).block_until_ready()
-
+    runtimes["compile_time"] = time.time() - start_time
     params[0]["radius"] = params[0]["radius"].at[0].set(0.5)
-    _ = jitted_simulate(params).block_until_ready()
-
-def test_runtime1():
-    setup_runtime(1, False, False, 0.0, "jaxley.stone", 0)
-
-
-def test_runtime2():
-    setup_runtime(1, False, False, 0.0, "jax.sparse", 1)
-
-
-def test_runtime3():
-    setup_runtime(10, False, True, 0.1, "jaxley.stone", 2)
-
-
-def test_runtime4():
-    setup_runtime(10, False, True, 0.1, "jax.sparse", 3)
-
-
-def test_runtime5():
-    setup_runtime(1000, True, True, 0.001, "jaxley.stone", 4)
-
-
-def test_runtime6():
-    setup_runtime(1000, True, True, 0.001, "jax.sparse", 5)
+    start_time = time.time()
+    _ = jitted_simulate(params).block_until_ready()
+    runtimes["run_time"] = time.time() - start_time
+
+    return runtimes
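
A minimal sketch of how a further runtime test would plug into this scheme (the test name, kwargs, and the single `build_time` metric below are hypothetical; `compare_to_baseline`, `build_net`, and the environment variables are the ones introduced above):

    @pytest.mark.parametrize("num_cells", [1, 10])
    @compare_to_baseline(baseline_iters=3)
    def test_runtime_of_build(num_cells: int):
        # Time only the network construction. The decorator runs the test once
        # (or `baseline_iters` times when UPDATE_BASELINE=1), averages the returned
        # runtimes, writes them to the results/baselines JSON, and asserts that no
        # metric is more than `tolerance` slower than its stored baseline.
        start_time = time.time()
        net, params = build_net(num_cells, artificial=True, connect=False)
        return {"build_time": time.time() - start_time}

`PYTHONHASHSEED=0 pytest tests/test_regression.py` then checks the returned timings against `tests/regression_test_baselines.json`, while `PYTHONHASHSEED=0 UPDATE_BASELINE=1 pytest tests/test_regression.py` refreshes the baseline entries instead of asserting.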