diff --git a/.github/workflows/regression_tests.yml b/.github/workflows/regression_tests.yml index 08a67907..ebfd4b64 100644 --- a/.github/workflows/regression_tests.yml +++ b/.github/workflows/regression_tests.yml @@ -1,29 +1,33 @@ +# .github/workflows/regression_tests.yml name: Regression Tests on: pull_request: - branches: [ main ] + branches: + - main jobs: - Regression-Test: - runs-on: ubuntu-latest + regression_tests: + name: regression_tests + runs-on: ubuntu-20.04 steps: - uses: actions/checkout@v3 with: + lfs: true fetch-depth: 0 # This ensures we can checkout main branch too - - name: Set up Python - uses: actions/setup-python@v4 + - uses: actions/setup-python@v4 with: - python-version: '3.x' + python-version: '3.10' + architecture: 'x64' - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -r requirements.txt + pip install -e ".[dev]" - - name: Run regression tests on PR branch + - name: Run benchmarks on PR branch run: | python tests/regression_test_runner.py @@ -41,20 +45,17 @@ jobs: git checkout - - name: Comment PR - uses: peter-evans/create-or-update-comment@v2 if: github.event.pull_request.base.ref == 'main' - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + uses: actions/github-script@v7 with: - issue-number: ${{ github.event.pull_request.number }} - body: | - ## Regression Test Results - ``` - $(cat regression_test_report.txt) - ``` - - - name: Save regression test results - uses: actions/upload-artifact@v3 - with: - name: regression-test-results - path: tests/regression_test_results.json \ No newline at end of file + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + const fs = require('fs'); + const TestReport = fs.readFileSync('regression_test_report.txt', 'utf8'); + + await github.rest.issues.createComment({ + issue_number: context.issue.number, + owner: context.repo.owner, + repo: context.repo.repo, + body: `## Regression Test Results\n\`\`\`\n${TestReport}\n\`\`\`` + }); \ No newline at end of file diff --git a/.github/workflows/update_regression_baseline.yml b/.github/workflows/update_regression_baseline.yml new file mode 100644 index 00000000..19c92f2c --- /dev/null +++ b/.github/workflows/update_regression_baseline.yml @@ -0,0 +1,37 @@ +# .github/workflows/regression_tests.yml +name: Regression Tests + +on: + workflow_dispatch: + +jobs: + regression_tests: + name: regression_tests + runs-on: ubuntu-20.04 + + steps: + - uses: actions/checkout@v3 + with: + lfs: true + fetch-depth: 0 # This ensures we can checkout main branch too + + - uses: actions/setup-python@v4 + with: + python-version: '3.10' + architecture: 'x64' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e ".[dev]" + + - name: Run benchmarks on PR branch + run: | + python tests/regression_test_runner.py + + + - name: Save new regression baseline + uses: actions/upload-artifact@v3 + with: + name: regression-test-results + path: tests/regression_test_results.json \ No newline at end of file diff --git a/tests/regression_test_results.json b/tests/regression_test_results.json index 8484c477..40925f57 100644 --- a/tests/regression_test_results.json +++ b/tests/regression_test_results.json @@ -1,4 +1,9 @@ { - "module.example_function_1": 0.0003229847801849246, - "module.example_function_2": 0.0011760416850447656 + "module.test_runtime1": 1.0000774711370468, + "module.test_runtime2": 1.0000772399362177, + "module.test_runtime3": 1.000094958813861, + "module.test_runtime4": 1.0000748871825635, + "module.test_runtime5": 1.0000761719420552, + "module.test_runtime6": 1.0001780251041055, + "module.test_runtime7": 1.0000949839595705 } \ No newline at end of file diff --git a/tests/regression_test_runner.py b/tests/regression_test_runner.py index c02e6df9..bc0e5f5b 100644 --- a/tests/regression_test_runner.py +++ b/tests/regression_test_runner.py @@ -12,17 +12,17 @@ def load_module(file_path): return module def get_test_functions(module): - """Get all functions in the module (assumes they are test functions).""" - return [ - value for value in vars(module).values() if callable(value) - ] + """Get all functions in the module that start with test.""" + is_test = lambda value: callable(value) and value.__name__.startswith("test") + return [value for value in vars(module).values() if is_test(value)] def run_benchmarks(functions_to_test): """Run benchmarks for the specified functions.""" results = {} for func in functions_to_test: + print(f"Running benchmark for {func.__name__}...") # Run each function 1000 times and take the average - time_taken = timeit.timeit(lambda: func(), number=1000) / 1000 + time_taken = timeit.timeit(lambda: func(), number=1) # Use the function's qualified name (module.function) as the key results[f"{func.__module__}.{func.__name__}"] = time_taken return results @@ -60,8 +60,8 @@ def compare_results(base_results, new_results): if __name__ == "__main__": regression_test_file = 'tests/test_regression.py' - output_file = 'tests/benchmark_results.json' - base_results_file = 'tests/benchmark_results.json' if Path('tests/benchmark_results.json').exists() else None + output_file = 'tests/regression_test_results.json' + base_results_file = 'tests/regression_test_results.json' if Path('tests/regression_test_results.json').exists() else None # Load the regression test module test_module = load_module(regression_test_file) diff --git a/tests/test_regression.py b/tests/test_regression.py index 5637941d..b3364555 100644 --- a/tests/test_regression.py +++ b/tests/test_regression.py @@ -1,14 +1,103 @@ -# your_module.py +# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is +# licensed under the Apache License Version 2.0, see + +import os import time -def example_function_1(): - """Example function that does some work.""" - result = 0 - for i in range(1000): - result += i - return result - -def example_function_2(): - """Another example function with different performance characteristics.""" - time.sleep(0.001) # Simulate some work - return sum(range(1000)) \ No newline at end of file +import numpy as np +import pytest +from jax import jit + +import jaxley as jx +from jaxley.channels import HH +from jaxley.connect import sparse_connect +from jaxley.synapses import IonotropicSynapse + +# mark all tests as runtime tests in this file +pytestmark = pytest.mark.runtime + +def build_net(num_cells, artificial=True, connect=True, connection_prob=0.0): + _ = np.random.seed(1) # For sparse connectivity matrix. + + if artificial: + comp = jx.Compartment() + branch = jx.Branch(comp, 2) + depth = 3 + parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)] + cell = jx.Cell(branch, parents=parents) + else: + dirname = os.path.dirname(__file__) + fname = os.path.join(dirname, "swc_files", "morph.swc") + cell = jx.read_swc(fname, nseg=4) + net = jx.Network([cell for _ in range(num_cells)]) + + # Channels. + net.insert(HH()) + + # Synapses. + if connect: + sparse_connect( + net.cell("all"), net.cell("all"), IonotropicSynapse(), connection_prob + ) + + # Recordings. + net[0, 1, 0].record(verbose=False) + + # Trainables. + net.make_trainable("radius", verbose=False) + params = net.get_parameters() + + net.to_jax() + return net, params + +def setup_runtime( + num_cells: int, + artificial: bool, + connect: bool, + connection_prob: float, + voltage_solver: str, + identifier: int, +): + delta_t = 0.025 + t_max = 100.0 + + net, params = build_net( + num_cells, + artificial=artificial, + connect=connect, + connection_prob=connection_prob, + ) + + def simulate(params): + return jx.integrate( + net, + params=params, + t_max=t_max, + delta_t=delta_t, + voltage_solver=voltage_solver, + ) + + jitted_simulate = jit(simulate) + + _ = jitted_simulate(params).block_until_ready() + + params[0]["radius"] = params[0]["radius"].at[0].set(0.5) + _ = jitted_simulate(params).block_until_ready() + +def test_runtime1(): + setup_runtime(1, False, False, 0.0, "jaxley.stone", 0) + +def test_runtime2(): + setup_runtime(1, False, False, 0.0, "jax.sparse", 1) + +def test_runtime3(): + setup_runtime(10, False, True, 0.1, "jaxley.stone", 2) + +def test_runtime4(): + setup_runtime(10, False, True, 0.1, "jax.sparse", 3) + +def test_runtime5(): + setup_runtime(1000, True, True, 0.001, "jaxley.stone", 4) + +def test_runtime6(): + setup_runtime(1000, True, True, 0.001, "jax.sparse", 5) \ No newline at end of file