Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bentaculum committed May 28, 2024
1 parent a2defb9 commit d84a150
Show file tree
Hide file tree
Showing 8 changed files with 77 additions and 57 deletions.
57 changes: 20 additions & 37 deletions tests/test_inference_api.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
import os

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
from pathlib import Path

import napari
import numpy as np
import pytest
from trackastra.data import load_tiff_timeseries
import torch
from trackastra.model import Trackastra
from trackastra.tracking import graph_to_ctc, graph_to_napari_tracks
from trackastra.utils import normalize


@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize(
"data_path",
[
Expand All @@ -19,39 +17,34 @@
],
ids=["2d", "3d"],
)
def test_api(
data_path: str,
device: str,
):
# if __name__ == "__main__":
def test_api(data_path: str):
# data_path = "data/ctc_2024/DIC-C2DH-HeLa/train/01"
# device = "cuda"

# TODO download/commit datasets

data_path = Path(data_path)

imgs = load_tiff_timeseries(data_path, dtype=float)
imgs = np.stack([normalize(x) for x in imgs])
masks = load_tiff_timeseries(f"{data_path}_ST/SEG", dtype=int)
# imgs = load_tiff_timeseries(data_path, dtype=float)
# imgs = np.stack([normalize(x) for x in imgs])
# masks = load_tiff_timeseries(f"{data_path}_ST/SEG", dtype=int)

model = Trackastra.load_pretrained(
name="ctc",
from trackastra.data import example_data_hela

imgs, masks = example_data_hela()
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Trackastra.from_pretrained(
name="general_2d",
device=device,
)
# model = Trackastra.load_from_folder(
# # Path(__file__).parent.parent.resolve() / "scripts/runs/ctc_3d_new_3",
# device=device,
# )

# Steps
# TODO it would probably make sense to already store the prediction as a trackastra.TrackGraph
# TODO store predictions already on trackastra.TrackGraph
predictions = model._predict(imgs, masks)

track_graph = model._track_from_predictions(predictions)

# TODO: TrackGraph class that wraps a networkx graph
track_graph = model.track(imgs, masks, mode="ilp", ilp_config="deepcell_gt")
track_graph = model.track(
imgs,
masks,
mode="greedy",
ilp_config="gt",
)

# track_graph = model.track_from_disk(
# imgs_path=...,
Expand All @@ -65,13 +58,3 @@ def test_api(
)

napari_tracks, napari_tracks_graph, _ = graph_to_napari_tracks(track_graph)

if "DISPLAY" in os.environ:
if napari.current_viewer() is None:
v = napari.Viewer()

v.add_image(imgs)
v.add_labels(masks_tracked)
v.add_tracks(data=napari_tracks, graph=napari_tracks_graph)
else:
print("No display available.")
1 change: 0 additions & 1 deletion tests/test_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ def repeat_tile(x, repeats=(4, 4)):
return x


# if __name__ == "__main__":
def test_matching():
np.random.seed(42)

Expand Down
1 change: 0 additions & 1 deletion tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from trackastra.model import TrackingTransformer


# if __name__ == "__main__":
def test_model():
torch.manual_seed(0)
coords = torch.randint(0, 400, (1, 100, 3)).float()
Expand Down
43 changes: 40 additions & 3 deletions tests/test_pretrained.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,42 @@
from trackastra.model.pretrained import download_pretrained
import os

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

if __name__ == "__main__":
folder=download_pretrained("general_2d")
import pytest
import torch
from trackastra.data import example_data_hela
from trackastra.model import Trackastra


@pytest.mark.parametrize("name", ["ctc", "general_2d"])
@pytest.mark.parametrize("device", ["cpu", "mps", "cuda"])
def test_pretrained(name, device):
"""Each pretrained model should run on all (available) device."""
if device == "cuda":
if torch.cuda.is_available():
run_predictions(name, "cuda")
else:
pytest.skip("cuda not available")
elif device == "mps":
if torch.backends.mps.is_available():
run_predictions(name, "mps")
else:
pytest.skip("mps not available")
elif device == "cpu":
# pytest.skip("cpu not needed")
run_predictions(name, "cpu")
else:
raise ValueError()

assert True


def run_predictions(name, device):
model = Trackastra.from_pretrained(
name=name,
device=device,
)
imgs, masks = example_data_hela()

_ = model._predict(imgs, masks)
assert True
2 changes: 1 addition & 1 deletion trackastra/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
collate_sequence_padding,
extract_features_regionprops,
)
from .example_data import example_data_bacteria, example_data_hela
from .sampler import (
BalancedBatchSampler,
BalancedDataModule,
BalancedDistributedSampler,
)
from .utils import filter_track_df, load_tiff_timeseries, load_tracklet_links
from .wrfeat import WRFeatures, build_windows, get_features
from .example_data import test_data_bacteria, test_data_hela
27 changes: 14 additions & 13 deletions trackastra/data/example_data.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,29 @@
import tifffile
from pathlib import Path

import tifffile

root = Path(__file__).parent/'resources'
root = Path(__file__).parent / "resources"

def test_data_bacteria():
""" Bacteria images and masks from

def example_data_bacteria():
"""Bacteria images and masks from.
Van Vliet et al. Local interactions lead to spatially correlated gene expression levels in bacterial group (2018)
subset of timelapse trpL/150310-11
"""
img = tifffile.imread(root/'trpL_150310-11_img.tif')
mask = tifffile.imread(root/'trpL_150310-11_mask.tif')
return img, mask
"""
img = tifffile.imread(root / "trpL_150310-11_img.tif")
mask = tifffile.imread(root / "trpL_150310-11_mask.tif")
return img, mask


def test_data_hela():
""" Hela data from the cell tracking challenge
def example_data_hela():
"""Hela data from the cell tracking challenge.
Neumann et al. Phenotypic profiling of the human genome by time-lapse microscopy reveals cell division genes (2010)
subset of Fluo-N2DL-HeLa/train/02
"""
img = tifffile.imread(root/'Fluo_Hela_02_img.tif')
mask = tifffile.imread(root/'Fluo_Hela_02_ERR_SEG.tif')
return img, mask
img = tifffile.imread(root / "Fluo_Hela_02_img.tif")
mask = tifffile.imread(root / "Fluo_Hela_02_ERR_SEG.tif")
return img, mask
1 change: 1 addition & 0 deletions trackastra/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# ruff: noqa: F401

from .model import TrackingTransformer
from .model_api import Trackastra
2 changes: 1 addition & 1 deletion trackastra/tracking/ilp.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
def track_ilp(
candidate_graph,
allow_divisions: bool = True,
ilp_config: str = "ilp_gt",
ilp_config: str = "gt",
params_file: str | None = None,
**kwargs,
):
Expand Down

0 comments on commit d84a150

Please sign in to comment.