diff --git a/setup.cfg b/setup.cfg index 7349308..eda287e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -47,7 +47,6 @@ ilp = motile >= 0.2 dev = pytest - shell ruff black mypy @@ -56,7 +55,6 @@ dev = build test = pytest - shell [options.entry_points] console_scripts = diff --git a/tests/test_cli.py b/tests/test_cli.py index 0010c20..d995366 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,16 +1,35 @@ -from shell import shell +import os +from pathlib import Path from test_data import example_dataset def test_cli_parser(): - result = shell("trackastra") - assert result.code == 0 + result = os.system("trackastra") + assert result == 0 -def test_cli_tracking(): +def test_cli_tracking_from_folder(): example_dataset() - result = shell( - "trackastra track -i test_data/img -m test_data/TRA --model-pretrained general_2d" # noqa: RUF100 - ) - assert result.code == 0 + cmd = "trackastra track -i test_data/img -m test_data/TRA --output-ctc test_data/tracked --output-edge-table test_data/tracked.csv --model-pretrained general_2d" # noqa: RUF100 + print(cmd) + result = os.system(cmd) + assert Path("test_data/tracked").exists() + assert Path("test_data/tracked.csv").exists() + assert result == 0 + + +def test_cli_tracking_from_file(): + root = Path(__file__).parent.parent / "trackastra" / "data" / "resources" + print(root) + assert root.exists() + output_ctc = Path("test_data") / "tracked_bacteria" + output_edge_table = Path("test_data") / "tracked_bacteria.csv" + + cmd = f"trackastra track -i {root / 'trpL_150310-11_img.tif'} -m {root / 'trpL_150310-11_mask.tif'} --output-ctc {output_ctc} --output-edge-table {output_edge_table} --model-pretrained general_2d" # noqa: RUF100 + print(cmd) + result = os.system(cmd) + + assert output_ctc.exists() + assert output_edge_table.exists() + assert result == 0 diff --git a/trackastra/cli.py b/trackastra/cli.py index 55614fc..d9db966 100644 --- a/trackastra/cli.py +++ b/trackastra/cli.py @@ -1,11 +1,10 @@ import argparse import sys -from pathlib import Path import torch from .model import Trackastra -from .tracking.utils import graph_to_ctc +from .tracking.utils import graph_to_ctc, graph_to_edge_table from .utils import str2path @@ -31,14 +30,16 @@ def cli(): help="Directory with series of .tif files.", ) p_track.add_argument( - "-o", - "--outdir", + "--output-ctc", type=str2path, default=None, - help=( - "Directory for writing results (optional). Default writes to" - " `{masks}_tracked`." - ), + help="If set, write results in CTC format to this directory.", + ) + p_track.add_argument( + "--output-edge-table", + type=str2path, + default=None, + help="If set, write results as an edge table in CSV format to the given file.", ) p_track.add_argument( "--model-pretrained", @@ -93,14 +94,19 @@ def _track_from_disk(args): mode=args.mode, ) - if args.outdir is None: - outdir = Path(f"{args.masks}_tracked") - else: - outdir = args.outdir + if args.output_ctc: + outdir = args.output_ctc + outdir.mkdir(parents=True, exist_ok=True) + graph_to_ctc( + track_graph, + masks, + outdir=outdir, + ) - outdir.mkdir(parents=True, exist_ok=True) - graph_to_ctc( - track_graph, - masks, - outdir=outdir, - ) + if args.output_edge_table: + outpath = args.output_edge_table + outpath.parent.mkdir(parents=True, exist_ok=True) + graph_to_edge_table( + graph=track_graph, + outpath=outpath, + ) diff --git a/trackastra/model/model_api.py b/trackastra/model/model_api.py index b9952a3..0dda834 100644 --- a/trackastra/model/model_api.py +++ b/trackastra/model/model_api.py @@ -4,6 +4,7 @@ from typing import Literal import numpy as np +import tifffile import torch import yaml from tqdm import tqdm @@ -159,11 +160,13 @@ def track_from_disk( Args: imgs_path: - Directory containing a series of numbered tiff files. - Each file contains an image of shape (C),(Z),Y,X. + Options + - Directory containing a series of numbered tiff files. Each file contains an image of shape (C),(Z),Y,X. + - Single tiff file with time series of shape T,(C),(Z),Y,X. masks_path: - Directory containing a series of numbered tiff files. - Each file contains an image of shape (Z), Y, X. + Options + - Directory containing a series of numbered tiff files. Each file contains an image of shape (C),(Z),Y,X. + - Single tiff file with time series of shape T,(Z),Y,X. mode (optional): Mode for candidate graph pruning. """ @@ -172,11 +175,15 @@ def track_from_disk( if not masks_path.exists(): raise FileNotFoundError(f"{masks_path=} does not exist.") - if not imgs_path.is_dir() or not masks_path.is_dir(): - raise NotImplementedError("Currently only tiff sequences are supported.") + if imgs_path.is_dir(): + imgs = load_tiff_timeseries(imgs_path) + else: + imgs = tifffile.imread(imgs_path) - imgs = load_tiff_timeseries(imgs_path) - masks = load_tiff_timeseries(masks_path) + if masks_path.is_dir(): + masks = load_tiff_timeseries(masks_path) + else: + masks = tifffile.imread(masks_path) if len(imgs) != len(masks): raise RuntimeError( diff --git a/trackastra/tracking/__init__.py b/trackastra/tracking/__init__.py index fb868e9..4d11d41 100644 --- a/trackastra/tracking/__init__.py +++ b/trackastra/tracking/__init__.py @@ -8,6 +8,7 @@ from .utils import ( ctc_to_napari_tracks, graph_to_ctc, + graph_to_edge_table, graph_to_napari_tracks, linear_chains, ) diff --git a/trackastra/tracking/utils.py b/trackastra/tracking/utils.py index 20715f2..8061bf4 100644 --- a/trackastra/tracking/utils.py +++ b/trackastra/tracking/utils.py @@ -209,6 +209,67 @@ def _check_ctc_df(df: pd.DataFrame, masks: np.ndarray): return True +def graph_to_edge_table( + graph: nx.DiGraph, + frame_attribute: str = "time", + edge_attribute: str = "weight", + outpath: Path | None = None, +) -> pd.DataFrame: + """Write edges of a graph to a table. + + The table has columns `source_frame`, `source_label`, `target_frame`, `target_label`, and `weight`. + The first line is a header. The source and target are the labels of the objects in the + input masks in the designated frames (0-indexed). + + Args: + graph: With node attributes `frame_attribute`, `edge_attribute` and 'label'. + frame_attribute: Name of the frame attribute 'graph`. + edge_attribute: Name of the score attribute in `graph`. + outpath: If given, save the edges in CSV file format. + + Returns: + pd.DataFrame: Edges DataFrame with columns ['source_frame', 'source', 'target_frame', 'target', 'weight'] + """ + rows = [] + for edge in graph.edges: + source = graph.nodes[edge[0]] + target = graph.nodes[edge[1]] + + source_label = int(source["label"]) + source_frame = int(source[frame_attribute]) + target_label = int(target["label"]) + target_frame = int(target[frame_attribute]) + weight = float(graph.edges[edge][edge_attribute]) + + rows.append([source_frame, source_label, target_frame, target_label, weight]) + + df = pd.DataFrame( + rows, + columns=[ + "source_frame", + "source_label", + "target_frame", + "target_label", + "weight", + ], + ) + df = df.sort_values( + by=["source_frame", "source_label", "target_frame", "target_label"], + ascending=True, + ) + + if outpath is not None: + outpath = Path(outpath) + outpath.parent.mkdir( + parents=True, + exist_ok=True, + ) + + df.to_csv(outpath, index=False, header=True, sep=",") + + return df + + def graph_to_ctc( graph: nx.DiGraph, masks_original: np.ndarray,