Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
dweindl committed Dec 18, 2024
1 parent fc17260 commit b098ea0
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 54 deletions.
117 changes: 67 additions & 50 deletions petabtests/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import re
import sys
from pathlib import Path
from collections.abc import Iterable

from petab.v1.calculate import calculate_chi2, calculate_llh

Expand All @@ -18,7 +19,7 @@
write_solution,
)

__all__ = ["get_cases", "create", "clear", "get_cases_dir"]
__all__ = ["get_cases", "create_all", "clear", "get_cases_dir"]

test_formats = ("sbml", "pysb")
test_versions = ("v1.0.0", "v2.0.0")
Expand All @@ -27,10 +28,14 @@


def get_cases_dir(format_: str, version: str) -> Path:
"""Get the directory of the test cases for the given PEtab version and
model format."""
return CASES_DIR / version / format_


def get_cases(format_: str, version: str):
def get_cases(format_: str, version: str) -> Iterable[str]:
"""Get the list of test case IDs for the given PEtab version and model
format."""
cases_dir = get_cases_dir(format_=format_, version=version)
if not cases_dir.exists():
return []
Expand All @@ -41,13 +46,15 @@ def get_cases(format_: str, version: str):
)


def create():
def create_all():
"""Create all test files."""
for version, format_ in itertools.product(test_versions, test_formats):
case_list = get_cases(format_=format_, version=version)
if not case_list:
continue

# Table of contents markdown string for the current format x version
# directory README
toc = ""

for case_id in case_list:
Expand All @@ -58,55 +65,12 @@ def create():
f"Processing {version}/{format_} #{case_id} at {case_dir}"
)

# load test case module
# directory needs to be removed from path again and the module
# has to be unloaded, as modules from different model format
# suites have the same name
sys.path.append(str(case_dir))
case_module = importlib.import_module(case_id)
sys.path.pop()
# noinspection PyUnresolvedReferences
case: PetabTestCase = case_module.case
del sys.modules[case_id]
case = load_case(case_dir, case_id)

id_str = test_id_str(case.id)
toc += f"# [{id_str}]({id_str}/)\n\n{case.brief}\n\n"

write_info(case, format_, version=version)

write_problem(
test_id=case.id,
parameter_df=case.parameter_df,
condition_dfs=case.condition_dfs,
experiment_dfs=case.experiment_dfs,
observable_dfs=case.observable_dfs,
measurement_dfs=case.measurement_dfs,
model_files=case.model,
format_=format_,
version=version,
mapping_df=case.mapping_df,
)

chi2 = calculate_chi2(
case.measurement_dfs,
case.simulation_dfs,
case.observable_dfs,
case.parameter_df,
)
llh = calculate_llh(
case.measurement_dfs,
case.simulation_dfs,
case.observable_dfs,
case.parameter_df,
)
write_solution(
test_id=case.id,
chi2=chi2,
llh=llh,
simulation_dfs=case.simulation_dfs,
format_=format_,
version=version,
)
create_case(format_, version, case_id)

toc_path = (
get_cases_dir(format_=format_, version=version) / "README.md"
Expand All @@ -115,7 +79,60 @@ def create():
f.write(toc)


def clear():
def load_case(case_dir: Path, case_id: str) -> PetabTestCase:
"""Load a test case definition module."""
sys.path.append(str(case_dir))
case_module = importlib.import_module(case_id)
sys.path.pop()
# noinspection PyUnresolvedReferences
case: PetabTestCase = case_module.case
del sys.modules[case_id]
return case


def create_case(format_: str, version: str, id_: str) -> None:
"""Create a single test case."""
case_dir = get_case_dir(format_=format_, version=version, id_=id_)
case = load_case(case_dir, id_)

write_info(case, format_, version=version)

write_problem(
test_id=case.id,
parameter_df=case.parameter_df,
condition_dfs=case.condition_dfs,
experiment_dfs=case.experiment_dfs,
observable_dfs=case.observable_dfs,
measurement_dfs=case.measurement_dfs,
model_files=case.model,
format_=format_,
version=version,
mapping_df=case.mapping_df,
)

chi2 = calculate_chi2(
case.measurement_dfs,
case.simulation_dfs,
case.observable_dfs,
case.parameter_df,
)
llh = calculate_llh(
case.measurement_dfs,
case.simulation_dfs,
case.observable_dfs,
case.parameter_df,
)
write_solution(
test_id=case.id,
chi2=chi2,
llh=llh,
simulation_dfs=case.simulation_dfs,
format_=format_,
version=version,
)


def clear() -> None:
"""Clear all model folders."""
for version, format_ in itertools.product(test_versions, test_formats):
case_list = get_cases(format_=format_, version=version)
Expand All @@ -134,4 +151,4 @@ def _cli_create():
"""`petabtests_create` entry point."""
# initialize logging
logging.basicConfig(level=logging.INFO)
create()
create_all()
8 changes: 6 additions & 2 deletions petabtests/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,25 +49,29 @@ class PetabTestCase:


def get_case_dir(id_: int | str, format_: str, version: str) -> Path:
"""Get the directory of a test case."""
id_str = test_id_str(id_)
dir_ = CASES_DIR / version / format_ / id_str
dir_.mkdir(parents=True, exist_ok=True)
return dir_


def problem_yaml_name(_id: int | str) -> str:
"""Get the name of the problem yaml file."""
return "_" + test_id_str(_id) + ".yaml"


def solution_yaml_name(_id: int | str) -> str:
"""Get the name of the solution yaml file."""
return "_" + test_id_str(_id) + "_solution.yaml"


def test_id_str(_id: int | str) -> str:
"""Get the test id as a string."""
return f"{_id:0>4}"


def write_info(case: PetabTestCase, format_: str, version: str):
def write_info(case: PetabTestCase, format_: str, version: str) -> None:
"""Write test info markdown file"""
# id to string
dir_ = get_case_dir(id_=case.id, format_=format_, version=version)
Expand All @@ -91,7 +95,7 @@ def write_problem(
mapping_df: pd.DataFrame = None,
format_: str = "sbml",
) -> None:
"""Write problem to files.
"""Write the PEtab problem for a given test to files.
Parameters
----------
Expand Down
4 changes: 2 additions & 2 deletions test/test_generate.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from petabtests.core import create
from petabtests.core import create_all
from petabtests.C import CASES_DIR

import sys
Expand All @@ -7,7 +7,7 @@

def test_check_cases_up_to_date():
sys.path.insert(0, CASES_DIR)
create()
create_all()
res = subprocess.run(
["git", "diff", "--exit-code", CASES_DIR], capture_output=True
)
Expand Down

0 comments on commit b098ea0

Please sign in to comment.