diff --git a/.github/workflows/regression_tests.yml b/.github/workflows/regression_tests.yml
new file mode 100644
index 00000000..78fabd09
--- /dev/null
+++ b/.github/workflows/regression_tests.yml
@@ -0,0 +1,40 @@
+# .github/workflows/regression_tests.yml
+name: Regression Tests
+
+on:
+  pull_request:
+    branches:
+      - main
+
+jobs:
+ regression_tests:
+ name: regression_tests
+ runs-on: ubuntu-20.04
+
+ steps:
+ - uses: actions/checkout@v3
+ with:
+ lfs: true
+ fetch-depth: 0 # This ensures we can checkout main branch too
+
+ - uses: actions/setup-python@v4
+ with:
+ python-version: '3.10'
+ architecture: 'x64'
+
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install -e ".[dev]"
+
+ - name: Run benchmarks and compare to baseline
+ if: github.event.pull_request.base.ref == 'main'
+ run: |
+          # # Check if regression test baselines exist in the main branch
+          # if git cat-file -e main:tests/regression_test_baselines.json 2>/dev/null; then
+          #   git checkout main -- tests/regression_test_baselines.json
+          # else
+          #   echo "No regression test baselines found in main branch"
+          # fi
+          pytest -m regression
\ No newline at end of file
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 3eb90b0a..5b5ebf07 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -39,4 +39,4 @@ jobs:
- name: Test with pytest
run: |
pip install pytest pytest-cov
- pytest tests/ --cov=jaxley --cov-report=xml
+ pytest tests/ -m "not regression" --cov=jaxley --cov-report=xml
diff --git a/.github/workflows/update_regression_baseline.yml b/.github/workflows/update_regression_baseline.yml
new file mode 100644
index 00000000..98ada263
--- /dev/null
+++ b/.github/workflows/update_regression_baseline.yml
@@ -0,0 +1,75 @@
+# .github/workflows/update_regression_baseline.yml
+name: Update Regression Baseline
+
+on:
+ # issue_comment: # event runs on the default branch
+ # types: [created]
+ pull_request:
+ branches:
+ - main
+
+jobs:
+ update_regression_tests:
+ name: update_regression_tests
+ runs-on: ubuntu-20.04
+ # if: github.event.issue.pull_request && contains(github.event.comment.body, '/update_baseline')
+ permissions:
+ contents: write
+ pull-requests: write
+ env:
+ username: ${{ github.event.pull_request.user.login }} # ${{ github.actor }}
+
+ steps:
+ # - name: Get PR branch
+ # uses: xt0rted/pull-request-comment-branch@v1
+ # id: comment-branch
+
+ - name: Checkout PR branch
+ uses: actions/checkout@v3
+ with:
+ # ref: ${{ steps.comment-branch.outputs.head_sha }} # using head_sha vs. head_ref makes this work for forks
+ lfs: true
+ fetch-depth: 0 # This ensures we can checkout main branch too
+
+ - uses: actions/setup-python@v4
+ with:
+ python-version: '3.10'
+ architecture: 'x64'
+
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install -e ".[dev]"
+
+ - name: Update baseline
+ if: github.event.pull_request.base.ref == 'main'
+ run: |
+          git config --global user.name "$username"
+          git config --global user.email "$username@users.noreply.github.com"
+ mv tests/regression_test_baselines.json tests/regression_test_baselines.json.bak
+ NEW_BASELINE=1 pytest -m regression
+
+ - name: Add Baseline update report to PR comment
+ uses: actions/github-script@v7
+ if: always()
+ with:
+ github-token: ${{ secrets.GITHUB_TOKEN }}
+ script: |
+ const fs = require('fs');
+ const TestReport = fs.readFileSync('tests/regression_test_report.txt', 'utf8');
+
+ await github.rest.issues.createComment({
+ issue_number: context.issue.number,
+ owner: context.repo.owner,
+ repo: context.repo.repo,
+ body: `## New Baselines \n\`\`\`\n${TestReport}\n\`\`\``
+ });
+
+
+ - name: Commit and push
+ if: github.event.pull_request.base.ref == 'main'
+ run: |
+ git add -f tests/regression_test_baselines.json # since it's in .gitignore
+ git commit -m "Update regression test baselines"
+ # git push origin HEAD:${{ steps.comment-branch.outputs.head_sha }}
+ git push origin HEAD:${{ github.head_ref }}
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 6162a95b..d5638eb6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -55,6 +55,8 @@ coverage.xml
*.py,cover
.hypothesis/
.pytest_cache/
+tests/regression_test_results.json
+tests/regression_test_baselines.json
# Translations
*.mo
diff --git a/pyproject.toml b/pyproject.toml
index d37e2bda..d1072768 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -69,6 +69,7 @@ dev = [
[tool.pytest.ini_options]
markers = [
"slow: marks tests as slow (T > 10s)",
+ "regression: marks regression tests",
]
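+# Run only the regression benchmarks with `pytest -m regression`, or exclude them
+# with `pytest -m "not regression"` (as the Tests workflow does).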
[tool.isort]
diff --git a/tests/conftest.py b/tests/conftest.py
index dad1c4a5..15f01295 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,6 +1,7 @@
# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
+import json
import os
from copy import deepcopy
from typing import Optional
@@ -9,6 +10,7 @@
import jaxley as jx
from jaxley.synapses import IonotropicSynapse
+from tests.test_regression import generate_regression_report
@pytest.fixture(scope="session")
@@ -202,3 +204,36 @@ def get_or_compute_swc2jaxley_params(
yield get_or_compute_swc2jaxley_params
params = {}
+
+
+@pytest.fixture(scope="session", autouse=True)
+def print_session_report(request):
+ """Cleanup a testing directory once we are finished."""
+
+ def print_regression_report():
+ dirname = os.path.dirname(__file__)
+ baseline_fname = os.path.join(dirname, "regression_test_baselines.json")
+ results_fname = os.path.join(dirname, "regression_test_results.json")
+
+ baselines = {}
+ if os.path.exists(baseline_fname):
+ with open(baseline_fname) as f:
+ baselines = json.load(f)
+
+ results = {}
+ if os.path.exists(results_fname):
+ with open(results_fname) as f:
+ results = json.load(f)
+
+        # The following prints the report to the console even though pytest captures
+        # output, without requiring the "-s" flag.
+ capmanager = request.config.pluginmanager.getplugin("capturemanager")
+ with capmanager.global_and_fixture_disabled():
+ print("\n\n\nRegression Test Report\n----------------------\n")
+            if not baselines:
+                print(
+                    "No baselines found. Run `git checkout main; NEW_BASELINE=1 pytest -m regression; git checkout -`"
+                )
+ print(generate_regression_report(baselines, results))
+
+ request.addfinalizer(print_regression_report)
diff --git a/tests/regression_test_baselines.json b/tests/regression_test_baselines.json
new file mode 100644
index 00000000..cdf9e30b
--- /dev/null
+++ b/tests/regression_test_baselines.json
@@ -0,0 +1,17 @@
+{
+ "ec3a4fad11d2bfb1bc5f8f10529cb06f2ff9919b377e9c0a3419c7f7f237f06e": {
+ "test_name": "test_runtime",
+ "input_kwargs": {
+ "num_cells": 1,
+ "artificial": false,
+ "connect": false,
+ "connection_prob": 0.0,
+ "voltage_solver": "jaxley.stone"
+ },
+ "runtimes": {
+ "build_time": 0.10014088948567708,
+ "compile_time": 0.3103648026784261,
+ "run_time": 0.2102543512980143
+ }
+ }
+}
\ No newline at end of file
diff --git a/tests/test_regression.py b/tests/test_regression.py
new file mode 100644
index 00000000..1b14a5f3
--- /dev/null
+++ b/tests/test_regression.py
@@ -0,0 +1,259 @@
+# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
+# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
+
+import hashlib
+import json
+import os
+import time
+from functools import wraps
+
+import numpy as np
+import pytest
+from jax import jit
+
+import jaxley as jx
+from jaxley.channels import HH
+from jaxley.connect import sparse_connect
+from jaxley.synapses import IonotropicSynapse
+
+# Every runtime test needs to have the following structure:
+#
+# @compare_to_baseline()
+# def test_runtime_of_x(**kwargs) -> Dict:
+#     t1 = time.time()
+#     time.sleep(0.1)
+#     # do something
+#     t2 = time.time()
+#     # do something else
+#     t3 = time.time()
+#     return {"sth": t2 - t1, "sth_else": t3 - t2}
+
+def load_json(fpath):
+ dct = {}
+ if os.path.exists(fpath):
+ with open(fpath, "r") as f:
+ dct = json.load(f)
+ return dct
+
+
+pytestmark = pytest.mark.regression  # mark all tests in this file as regression tests
+NEW_BASELINE = int(os.environ.get("NEW_BASELINE", 0))
+dirname = os.path.dirname(__file__)
+fpath_baselines = os.path.join(dirname, "regression_test_baselines.json")
+fpath_results = os.path.join(dirname, "regression_test_results.json")
+
+tolerance = 0.2
+
+baselines = load_json(fpath_baselines)
+
+
+
+def generate_regression_report(base_results, new_results):
+ """Compare two sets of benchmark results and generate a diff report."""
+ report = []
+ for key in new_results:
+ new_data = new_results[key]
+ base_data = base_results.get(key)
+ kwargs = ", ".join([f"{k}={v}" for k, v in new_data["input_kwargs"].items()])
+ func_name = new_data["test_name"]
+ func_signature = f"{func_name}({kwargs})"
+
+ new_runtimes = new_data["runtimes"]
+        base_runtimes = (
+            {k: None for k in new_runtimes}
+            if base_data is None
+            else base_data["runtimes"]
+        )
+
+ report.append(func_signature)
+ for key, new_time in new_runtimes.items():
+ base_time = base_runtimes.get(key)
+ diff = None if base_time is None else ((new_time - base_time) / base_time)
+
+ emoji = ""
+ if diff is None:
+ emoji = "🆕"
+ elif diff > tolerance:
+ emoji = "🔴"
+ elif diff < 0:
+ emoji = "🟢"
+ else:
+ emoji = "⚪"
+
+            time_str = (
+                f"({new_time:.6f}s)"
+                if diff is None
+                else f"({diff:+.1%} vs {base_time:.6f}s)"
+            )
+ report.append(f"{emoji} {key}: {time_str}.")
+ report.append("")
+
+ return "\n".join(report)
+
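+# Illustrative example of the report format produced above (runtimes taken from the
+# committed baselines; the percentage diffs are made up):
+#
+#     test_runtime(num_cells=1, artificial=False, connect=False, connection_prob=0.0, voltage_solver=jaxley.stone)
+#     🟢 build_time: (-3.2% vs 0.100141s).
+#     ⚪ compile_time: (+1.0% vs 0.310365s).
+#     🔴 run_time: (+25.0% vs 0.210254s).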
+
+def generate_unique_key(d):
+ # Generate a unique key for each test case. Makes it possible to compare tests
+ # with different input_kwargs.
+    hash_obj = hashlib.sha256(json.dumps(d, sort_keys=True).encode("utf-8"))
+    return hash_obj.hexdigest()
+
+
+def append_to_json(fpath, test_name, input_kwargs, runtimes):
+ header = {"test_name": test_name, "input_kwargs": input_kwargs}
+ data = {generate_unique_key(header): {**header, "runtimes": runtimes}}
+
+ # Save data to a JSON file
+ result_data = load_json(fpath)
+ result_data.update(data)
+
+ with open(fpath, "w") as f:
+ json.dump(result_data, f, indent=2)
+
+
+class compare_to_baseline:
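+    """Decorator that averages the runtimes returned by a test and stores them.
+
+    With NEW_BASELINE set, the measured runtimes become the new baselines; otherwise
+    they must be within `tolerance` of the stored baselines. In both cases a report
+    is written to tests/regression_test_report.txt.
+    """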
+ def __init__(self, baseline_iters=1, test_iters=1):
+ self.baseline_iters = baseline_iters
+ self.test_iters = test_iters
+
+ def __call__(self, func):
+ @wraps(func) # ensures kwargs exposed to pytest
+ def test_wrapper(**kwargs):
+ header = {"test_name": func.__name__, "input_kwargs": kwargs}
+ key = generate_unique_key(header)
+
+ runs = []
+ num_iters = self.baseline_iters if NEW_BASELINE else self.test_iters
+ for _ in range(num_iters):
+ runtimes = func(**kwargs)
+ runs.append(runtimes)
+ runtimes = {k: np.mean([d[k] for d in runs]) for k in runs[0]}
+
+ fpath = fpath_results if not NEW_BASELINE else fpath_baselines
+ append_to_json(fpath, header["test_name"], header["input_kwargs"], runtimes)
+
+ if not NEW_BASELINE:
+            assert key in baselines, f"No baseline found for {header}"
+ func_baselines = baselines[key]["runtimes"]
+ for key, baseline in func_baselines.items():
+ diff = (
+ float("nan")
+ if np.isclose(baseline, 0)
+ else (runtimes[key] - baseline) / baseline
+ )
+ assert runtimes[key] < baseline * (
+ 1 + tolerance
+ ), f"{key} is {diff:.2%} slower than the baseline."
+
+ # save report
+ if NEW_BASELINE:
+ new_data = load_json(fpath_baselines)
+ old_data = load_json(fpath_baselines + ".bak")
+ else:
+ new_data = load_json(fpath_results)
+ old_data = load_json(fpath_baselines)
+ report = generate_regression_report(old_data, new_data)
+
+ with open(dirname + "/regression_test_report.txt", "w") as f:
+ f.write(report)
+
+ return test_wrapper
+
+
+def build_net(num_cells, artificial=True, connect=True, connection_prob=0.0):
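+    """Build a network of `num_cells` HH cells (SWC morphology or small artificial
+    cells), optionally connected by IonotropicSynapses with `connection_prob`."""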
+    np.random.seed(1)  # For sparse connectivity matrix.
+
+ if artificial:
+ comp = jx.Compartment()
+ branch = jx.Branch(comp, 2)
+ depth = 3
+ parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)]
+ cell = jx.Cell(branch, parents=parents)
+ else:
+ dirname = os.path.dirname(__file__)
+ fname = os.path.join(dirname, "swc_files", "morph.swc")
+ cell = jx.read_swc(fname, nseg=4)
+ net = jx.Network([cell for _ in range(num_cells)])
+
+ # Channels.
+ net.insert(HH())
+
+ # Synapses.
+ if connect:
+ sparse_connect(
+ net.cell("all"), net.cell("all"), IonotropicSynapse(), connection_prob
+ )
+
+ # Recordings.
+ net[0, 1, 0].record(verbose=False)
+
+ # Trainables.
+ net.make_trainable("radius", verbose=False)
+ params = net.get_parameters()
+
+ net.to_jax()
+ return net, params
+
+
+@pytest.mark.parametrize(
+ "num_cells, artificial, connect, connection_prob, voltage_solver",
+ (
+ # Test a single SWC cell with both solvers.
+ pytest.param(1, False, False, 0.0, "jaxley.stone"),
+ # pytest.param(1, False, False, 0.0, "jax.sparse"),
+ # # Test a network of SWC cells with both solvers.
+ # pytest.param(10, False, True, 0.1, "jaxley.stone"),
+ # pytest.param(10, False, True, 0.1, "jax.sparse"),
+ # # Test a larger network of smaller neurons with both solvers.
+ # pytest.param(1000, True, True, 0.001, "jaxley.stone"),
+ # pytest.param(1000, True, True, 0.001, "jax.sparse"),
+ ),
+)
+@compare_to_baseline(baseline_iters=3)
+def test_runtime(
+ num_cells: int,
+ artificial: bool,
+ connect: bool,
+ connection_prob: float,
+ voltage_solver: str,
+):
+ delta_t = 0.025
+ t_max = 100.0
+
+ # def simulate(params):
+ # return jx.integrate(
+ # net,
+ # params=params,
+ # t_max=t_max,
+ # delta_t=delta_t,
+ # voltage_solver=voltage_solver,
+ # )
+
+ runtimes = {}
+
+ start_time = time.time()
+ # net, params = build_net(
+ # num_cells,
+ # artificial=artificial,
+ # connect=connect,
+ # connection_prob=connection_prob,
+ # )
+ time.sleep(0.1)
+ runtimes["build_time"] = time.time() - start_time
+
+ # jitted_simulate = jit(simulate)
+
+ start_time = time.time()
+ time.sleep(0.31)
+ # _ = jitted_simulate(params).block_until_ready()
+ runtimes["compile_time"] = time.time() - start_time
+ # params[0]["radius"] = params[0]["radius"].at[0].set(0.5)
+
+ start_time = time.time()
+ # _ = jitted_simulate(params).block_until_ready()
+ time.sleep(0.21)
+ runtimes["run_time"] = time.time() - start_time
+ return runtimes # @compare_to_baseline decorator will compare this to the baseline