Skip to content

Commit

Permalink
Merge pull request #8 from tinevez/main
Browse files Browse the repository at this point in the history
I/O utilities to facilitate TrackMate interoperability

Co-Authored-By: Jean-Yves Tinevez <[email protected]>
  • Loading branch information
bentaculum and tinevez authored Jun 17, 2024
2 parents c2178f0 + 58671f4 commit 4ec19dc
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 36 deletions.
2 changes: 0 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ ilp =
motile >= 0.2
dev =
pytest
shell
ruff
black
mypy
Expand All @@ -56,7 +55,6 @@ dev =
build
test =
pytest
shell

[options.entry_points]
console_scripts =
Expand Down
35 changes: 27 additions & 8 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,35 @@
from shell import shell
import os
from pathlib import Path

from test_data import example_dataset


def test_cli_parser():
result = shell("trackastra")
assert result.code == 0
result = os.system("trackastra")
assert result == 0


def test_cli_tracking():
def test_cli_tracking_from_folder():
example_dataset()
result = shell(
"trackastra track -i test_data/img -m test_data/TRA --model-pretrained general_2d" # noqa: RUF100
)
assert result.code == 0
cmd = "trackastra track -i test_data/img -m test_data/TRA --output-ctc test_data/tracked --output-edge-table test_data/tracked.csv --model-pretrained general_2d" # noqa: RUF100
print(cmd)
result = os.system(cmd)
assert Path("test_data/tracked").exists()
assert Path("test_data/tracked.csv").exists()
assert result == 0


def test_cli_tracking_from_file():
root = Path(__file__).parent.parent / "trackastra" / "data" / "resources"
print(root)
assert root.exists()
output_ctc = Path("test_data") / "tracked_bacteria"
output_edge_table = Path("test_data") / "tracked_bacteria.csv"

cmd = f"trackastra track -i {root / 'trpL_150310-11_img.tif'} -m {root / 'trpL_150310-11_mask.tif'} --output-ctc {output_ctc} --output-edge-table {output_edge_table} --model-pretrained general_2d" # noqa: RUF100
print(cmd)
result = os.system(cmd)

assert output_ctc.exists()
assert output_edge_table.exists()
assert result == 0
42 changes: 24 additions & 18 deletions trackastra/cli.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import argparse
import sys
from pathlib import Path

import torch

from .model import Trackastra
from .tracking.utils import graph_to_ctc
from .tracking.utils import graph_to_ctc, graph_to_edge_table
from .utils import str2path


Expand All @@ -31,14 +30,16 @@ def cli():
help="Directory with series of .tif files.",
)
p_track.add_argument(
"-o",
"--outdir",
"--output-ctc",
type=str2path,
default=None,
help=(
"Directory for writing results (optional). Default writes to"
" `{masks}_tracked`."
),
help="If set, write results in CTC format to this directory.",
)
p_track.add_argument(
"--output-edge-table",
type=str2path,
default=None,
help="If set, write results as an edge table in CSV format to the given file.",
)
p_track.add_argument(
"--model-pretrained",
Expand Down Expand Up @@ -93,14 +94,19 @@ def _track_from_disk(args):
mode=args.mode,
)

if args.outdir is None:
outdir = Path(f"{args.masks}_tracked")
else:
outdir = args.outdir
if args.output_ctc:
outdir = args.output_ctc
outdir.mkdir(parents=True, exist_ok=True)
graph_to_ctc(
track_graph,
masks,
outdir=outdir,
)

outdir.mkdir(parents=True, exist_ok=True)
graph_to_ctc(
track_graph,
masks,
outdir=outdir,
)
if args.output_edge_table:
outpath = args.output_edge_table
outpath.parent.mkdir(parents=True, exist_ok=True)
graph_to_edge_table(
graph=track_graph,
outpath=outpath,
)
23 changes: 15 additions & 8 deletions trackastra/model/model_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Literal

import numpy as np
import tifffile
import torch
import yaml
from tqdm import tqdm
Expand Down Expand Up @@ -159,11 +160,13 @@ def track_from_disk(
Args:
imgs_path:
Directory containing a series of numbered tiff files.
Each file contains an image of shape (C),(Z),Y,X.
Options
- Directory containing a series of numbered tiff files. Each file contains an image of shape (C),(Z),Y,X.
- Single tiff file with time series of shape T,(C),(Z),Y,X.
masks_path:
Directory containing a series of numbered tiff files.
Each file contains an image of shape (Z), Y, X.
Options
- Directory containing a series of numbered tiff files. Each file contains an image of shape (C),(Z),Y,X.
- Single tiff file with time series of shape T,(Z),Y,X.
mode (optional):
Mode for candidate graph pruning.
"""
Expand All @@ -172,11 +175,15 @@ def track_from_disk(
if not masks_path.exists():
raise FileNotFoundError(f"{masks_path=} does not exist.")

if not imgs_path.is_dir() or not masks_path.is_dir():
raise NotImplementedError("Currently only tiff sequences are supported.")
if imgs_path.is_dir():
imgs = load_tiff_timeseries(imgs_path)
else:
imgs = tifffile.imread(imgs_path)

imgs = load_tiff_timeseries(imgs_path)
masks = load_tiff_timeseries(masks_path)
if masks_path.is_dir():
masks = load_tiff_timeseries(masks_path)
else:
masks = tifffile.imread(masks_path)

if len(imgs) != len(masks):
raise RuntimeError(
Expand Down
1 change: 1 addition & 0 deletions trackastra/tracking/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .utils import (
ctc_to_napari_tracks,
graph_to_ctc,
graph_to_edge_table,
graph_to_napari_tracks,
linear_chains,
)
61 changes: 61 additions & 0 deletions trackastra/tracking/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,67 @@ def _check_ctc_df(df: pd.DataFrame, masks: np.ndarray):
return True


def graph_to_edge_table(
graph: nx.DiGraph,
frame_attribute: str = "time",
edge_attribute: str = "weight",
outpath: Path | None = None,
) -> pd.DataFrame:
"""Write edges of a graph to a table.
The table has columns `source_frame`, `source_label`, `target_frame`, `target_label`, and `weight`.
The first line is a header. The source and target are the labels of the objects in the
input masks in the designated frames (0-indexed).
Args:
graph: With node attributes `frame_attribute`, `edge_attribute` and 'label'.
frame_attribute: Name of the frame attribute 'graph`.
edge_attribute: Name of the score attribute in `graph`.
outpath: If given, save the edges in CSV file format.
Returns:
pd.DataFrame: Edges DataFrame with columns ['source_frame', 'source', 'target_frame', 'target', 'weight']
"""
rows = []
for edge in graph.edges:
source = graph.nodes[edge[0]]
target = graph.nodes[edge[1]]

source_label = int(source["label"])
source_frame = int(source[frame_attribute])
target_label = int(target["label"])
target_frame = int(target[frame_attribute])
weight = float(graph.edges[edge][edge_attribute])

rows.append([source_frame, source_label, target_frame, target_label, weight])

df = pd.DataFrame(
rows,
columns=[
"source_frame",
"source_label",
"target_frame",
"target_label",
"weight",
],
)
df = df.sort_values(
by=["source_frame", "source_label", "target_frame", "target_label"],
ascending=True,
)

if outpath is not None:
outpath = Path(outpath)
outpath.parent.mkdir(
parents=True,
exist_ok=True,
)

df.to_csv(outpath, index=False, header=True, sep=",")

return df


def graph_to_ctc(
graph: nx.DiGraph,
masks_original: np.ndarray,
Expand Down

0 comments on commit 4ec19dc

Please sign in to comment.