Skip to content

Commit

Permalink
wip: get first baselines again
Browse files Browse the repository at this point in the history
  • Loading branch information
jnsbck committed Nov 22, 2024
1 parent 9517fb8 commit c1b3ccd
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 45 deletions.
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def update_baseline():
results = load_json(results_fname)
with open(baseline_fname, "w") as f:
json.dump(results, f, indent=2)
os.remove(results_fname)
os.remove(results_fname)

def print_regression_report():
baselines = load_json(baseline_fname)
Expand All @@ -239,5 +239,5 @@ def print_regression_report():
print("\n\n\nRegression Test Report\n----------------------\n")
print(report)

request.addfinalizer(print_regression_report)
request.addfinalizer(update_baseline)
request.addfinalizer(print_regression_report)
83 changes: 40 additions & 43 deletions tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,17 +132,18 @@ def test_wrapper(**kwargs):

append_to_json(fpath_results, header["test_name"], header["input_kwargs"], runtimes)

assert key in baselines, f"No basline found for {header}"
func_baselines = baselines[key]["runtimes"]
for key, baseline in func_baselines.items():
diff = (
float("nan")
if np.isclose(baseline, 0)
else (runtimes[key] - baseline) / baseline
)
assert runtimes[key] < baseline * (
1 + tolerance
), f"{key} is {diff:.2%} slower than the baseline."
if not NEW_BASELINE:
assert key in baselines, f"No basline found for {header}"
func_baselines = baselines[key]["runtimes"]
for key, baseline in func_baselines.items():
diff = (
float("nan")
if np.isclose(baseline, 0)
else (runtimes[key] - baseline) / baseline
)
assert runtimes[key] < baseline * (
1 + tolerance
), f"{key} is {diff:.2%} slower than the baseline."

return test_wrapper

Expand Down Expand Up @@ -187,13 +188,13 @@ def build_net(num_cells, artificial=True, connect=True, connection_prob=0.0):
(
# Test a single SWC cell with both solvers.
pytest.param(1, False, False, 0.0, "jaxley.stone"),
# pytest.param(1, False, False, 0.0, "jax.sparse"),
# # Test a network of SWC cells with both solvers.
# pytest.param(10, False, True, 0.1, "jaxley.stone"),
# pytest.param(10, False, True, 0.1, "jax.sparse"),
# # Test a larger network of smaller neurons with both solvers.
# pytest.param(1000, True, True, 0.001, "jaxley.stone"),
# pytest.param(1000, True, True, 0.001, "jax.sparse"),
pytest.param(1, False, False, 0.0, "jax.sparse"),
# Test a network of SWC cells with both solvers.
pytest.param(10, False, True, 0.1, "jaxley.stone"),
pytest.param(10, False, True, 0.1, "jax.sparse"),
# Test a larger network of smaller neurons with both solvers.
pytest.param(1000, True, True, 0.001, "jaxley.stone"),
pytest.param(1000, True, True, 0.001, "jax.sparse"),
),
)
@compare_to_baseline(baseline_iters=3)
Expand All @@ -204,41 +205,37 @@ def test_runtime(
connection_prob: float,
voltage_solver: str,
):
import time
# delta_t = 0.025
# t_max = 100.0

# def simulate(params):
# return jx.integrate(
# net,
# params=params,
# t_max=t_max,
# delta_t=delta_t,
# voltage_solver=voltage_solver,
# )
delta_t = 0.025
t_max = 100.0

def simulate(params):
return jx.integrate(
net,
params=params,
t_max=t_max,
delta_t=delta_t,
voltage_solver=voltage_solver,
)

runtimes = {}

start_time = time.time()
# net, params = build_net(
# num_cells,
# artificial=artificial,
# connect=connect,
# connection_prob=connection_prob,
# )
time.sleep(0.1)
net, params = build_net(
num_cells,
artificial=artificial,
connect=connect,
connection_prob=connection_prob,
)
runtimes["build_time"] = time.time() - start_time

# jitted_simulate = jit(simulate)
jitted_simulate = jit(simulate)

start_time = time.time()
time.sleep(0.2)
# _ = jitted_simulate(params).block_until_ready()
_ = jitted_simulate(params).block_until_ready()
runtimes["compile_time"] = time.time() - start_time
# params[0]["radius"] = params[0]["radius"].at[0].set(0.5)
params[0]["radius"] = params[0]["radius"].at[0].set(0.5)

start_time = time.time()
# _ = jitted_simulate(params).block_until_ready()
time.sleep(0.31)
_ = jitted_simulate(params).block_until_ready()
runtimes["run_time"] = time.time() - start_time
return runtimes # @compare_to_baseline decorator will compare this to the baseline

0 comments on commit c1b3ccd

Please sign in to comment.