Skip to content

Commit

Permalink
Add 3d data for tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bentaculum committed May 28, 2024
1 parent d84a150 commit b5f07f9
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 40 deletions.
35 changes: 15 additions & 20 deletions tests/test_inference_api.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,29 @@
import os
import tempfile

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
from pathlib import Path

import pytest
import torch
from trackastra.data import example_data_fluo_3d, example_data_hela
from trackastra.model import Trackastra
from trackastra.tracking import graph_to_ctc, graph_to_napari_tracks


@pytest.mark.parametrize(
"data_path",
"example_data",
[
"data/ctc_2024/DIC-C2DH-HeLa/train/01",
"data/ctc_2024/Fluo-C3DL-MDA231/train/01",
example_data_hela,
example_data_fluo_3d,
],
ids=["2d", "3d"],
)
def test_api(data_path: str):
# data_path = "data/ctc_2024/DIC-C2DH-HeLa/train/01"
data_path = Path(data_path)

# imgs = load_tiff_timeseries(data_path, dtype=float)
# imgs = np.stack([normalize(x) for x in imgs])
# masks = load_tiff_timeseries(f"{data_path}_ST/SEG", dtype=int)

from trackastra.data import example_data_hela

imgs, masks = example_data_hela()
def test_api(example_data):
imgs, masks = example_data()
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Trackastra.from_pretrained(
name="general_2d",
name="ctc",
device=device,
)

Expand All @@ -51,10 +44,12 @@ def test_api(data_path: str):
# masks_path=...,
# )

_, masks_tracked = graph_to_ctc(
track_graph,
masks,
outdir=Path(__file__).parent.resolve() / "tmp" / data_path.name,
)
with tempfile.TemporaryDirectory() as tmp:
tmp = Path(tmp)
_, masks_tracked = graph_to_ctc(
track_graph,
masks,
outdir=tmp,
)

napari_tracks, napari_tracks_graph, _ = graph_to_napari_tracks(track_graph)
2 changes: 1 addition & 1 deletion trackastra/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
collate_sequence_padding,
extract_features_regionprops,
)
from .example_data import example_data_bacteria, example_data_hela
from .example_data import example_data_bacteria, example_data_fluo_3d, example_data_hela
from .sampler import (
BalancedBatchSampler,
BalancedDataModule,
Expand Down
12 changes: 12 additions & 0 deletions trackastra/data/example_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,15 @@ def example_data_hela():
img = tifffile.imread(root / "Fluo_Hela_02_img.tif")
mask = tifffile.imread(root / "Fluo_Hela_02_ERR_SEG.tif")
return img, mask


def example_data_fluo_3d():
"""Fluo-N3DH-CHO data from the cell tracking challenge.
Dzyubachyk et al. Advanced Level-Set-Based Cell Tracking in Time-Lapse Fluorescence Microscopy (2010)
subset of Fluo-N3DH-CHO/train/02
"""
img = tifffile.imread(root / "Fluo-N3DH-CHO_02_img.tif")
mask = tifffile.imread(root / "Fluo-N3DH-CHO_02_ERR_SEG.tif")
return img, mask
Binary file not shown.
Binary file not shown.
4 changes: 3 additions & 1 deletion trackastra/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ def load_tiff_timeseries(
leave=False,
desc=f"Loading [{start_frame}:{end_frame}:{downscale[0]}]",
):
_x = tifffile.imread(f).astype(dtype)
_x = tifffile.imread(f)
if dtype:
_x = _x.astype(dtype)
assert _x.shape == shape
slices = tuple(slice(None, None, d) for d in downscale[1:])
_x = _x[slices]
Expand Down
52 changes: 34 additions & 18 deletions trackastra/model/model_api.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
from tqdm import tqdm
import yaml
import logging
import numpy as np
import logging
from pathlib import Path
from pydantic import validate_call
from typing import Literal
from typing import Literal, Optional

import numpy as np
import yaml
from pydantic import validate_call
from tqdm import tqdm

from .predict import predict_windows
from .pretrained import download_pretrained
from .model import TrackingTransformer
from ..data import build_windows, get_features
from ..tracking import build_graph, track_greedy
from ..utils import normalize
from .model import TrackingTransformer
from .predict import predict_windows
from .pretrained import download_pretrained

logger = logging.getLogger(__name__)


class Trackastra:
def __init__(self, transformer, train_args, device="cpu"):
# Hack: to(device) for some more submodules that map_location does cover
Expand All @@ -33,32 +35,46 @@ def from_folder(cls, dir: Path, device: str = "cpu"):
@classmethod
@validate_call
def from_pretrained(
cls, name: str, device: str = "cpu", download_dir: Path=None):
cls, name: str, device: str = "cpu", download_dir: Optional[Path] = None
):
folder = download_pretrained(name, download_dir)
# download zip from github to location/name, then unzip
return cls.from_folder(folder, device=device)

def _predict(
self, imgs: np.ndarray, masks: np.ndarray, edge_threshold: float = 0.05, n_workers: int = 0,
self,
imgs: np.ndarray,
masks: np.ndarray,
edge_threshold: float = 0.05,
n_workers: int = 0,
progbar_class=tqdm,
):
logger.info("Predicting weights for candidate graph")
imgs = normalize(imgs)
self.transformer.eval()

features = get_features(
detections=masks, imgs=imgs, ndim=self.transformer.config["coord_dim"], n_workers=n_workers, progbar_class=progbar_class
detections=masks,
imgs=imgs,
ndim=self.transformer.config["coord_dim"],
n_workers=n_workers,
progbar_class=progbar_class,
)
logger.info("Building windows")
windows = build_windows(features, window_size=self.transformer.config["window"],progbar_class=progbar_class)

windows = build_windows(
features,
window_size=self.transformer.config["window"],
progbar_class=progbar_class,
)

logger.info("Predicting windows")
predictions = predict_windows(
windows=windows,
features=features,
model=self.transformer,
edge_threshold=edge_threshold,
spatial_dim=masks.ndim - 1,
progbar_class=progbar_class
progbar_class=progbar_class,
)

return predictions
Expand All @@ -77,7 +93,7 @@ def _track_from_predictions(
logger.info("Running greedy tracker")
nodes = predictions["nodes"]
weights = predictions["weights"]

candidate_graph = build_graph(
nodes=nodes,
weights=weights,
Expand All @@ -93,7 +109,7 @@ def _track_from_predictions(
elif mode == "ilp":
from trackastra.tracking.ilp import track_ilp

return track_ilp(candidate_graph, ilp_config='gt', **kwargs)
return track_ilp(candidate_graph, ilp_config="gt", **kwargs)
else:
raise ValueError(f"Tracking mode {mode} does not exist.")

Expand Down

0 comments on commit b5f07f9

Please sign in to comment.