Skip to content

Commit

Permalink
Update image_labeling task (ref #64)
Browse files Browse the repository at this point in the history
* Temporarily remove relabeling;
* Switch to dask.array.map_blocks;
* Limit number of simultaneous cellpose executions;
* Make sure that cellpose is not called more than once per site during pyramid creation (ref #97);
* Use np.max as aggregation function (ref #97).
  • Loading branch information
tcompa committed Jul 7, 2022
1 parent 271deb5 commit 03a146a
Showing 1 changed file with 91 additions and 76 deletions.
167 changes: 91 additions & 76 deletions fractal/tasks/image_labeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
Institute for Biomedical Research and Pelkmans Lab from the University of
Zurich.
"""
import itertools
import json
import time
from concurrent.futures import ThreadPoolExecutor

import dask
import dask.array as da
import numpy as np
import zarr
Expand All @@ -26,13 +27,27 @@

def apply_label_to_single_FOV_column(
column,
block_info=None,
model=None,
do_3D=True,
anisotropy=None,
diameter=40.0,
cellprob_threshold=0.0,
label_dtype=None,
):

chunk_location = block_info[None]["chunk-location"]

# Write some debugging info
with open("LOG_image_labeling", "a") as out:
out.write(
f"[{chunk_location}] START Cellpose |"
f" column: {type(column)}, {column.shape} |"
f" do_3D: {do_3D}\n"
)

# Actual labeling
t0 = time.perf_counter()
mask, flows, styles, diams = model.eval(
column,
channels=[0, 0],
Expand All @@ -43,8 +58,21 @@ def apply_label_to_single_FOV_column(
anisotropy=anisotropy,
cellprob_threshold=cellprob_threshold,
)
if not do_3D:
mask = np.expand_dims(mask, axis=0)
t1 = time.perf_counter()

# Write some debugging info
with open("LOG_image_labeling", "a") as out:
out.write(
f"[{chunk_location}] END Cellpose |"
f" Elapsed: {t1-t0:.4f} seconds |"
f" mask shape: {mask.shape},"
f" mask dtype: {mask.dtype} (before recast to {label_dtype}),"
f" max(mask): {np.max(mask)}\n"
)

return mask
return mask.astype(label_dtype)


def image_labeling(
Expand All @@ -53,6 +81,7 @@ def image_labeling(
labeling_level=0,
labeling_channel=None,
chl_list=None,
num_threads=2,
# More parameters
anisotropy=None,
diameter=None,
Expand Down Expand Up @@ -135,93 +164,33 @@ def image_labeling(
model = models.Cellpose(gpu=use_gpu, model_type="nuclei")

# Initialize other things
num_labels_tot = 0
num_labels_column = 0
mask_rechunked = da.empty(
data_zyx_rechunked.shape,
chunks=data_zyx_rechunked.chunks,
dtype=label_dtype,
)

with open("LOG_image_labeling", "w") as out:
out.write(f"Start image_labeling task for {zarrurl}\n")
out.write(f"use_gpu: {use_gpu}\n")
out.write("Total well shape/chunks:\n")
out.write(f"{data_zyx_rechunked.shape}\n")
out.write(f"{data_zyx_rechunked.chunks}\n\n")

# Sequential labeling (and relabeling)
# https://stackoverflow.com/a/72018364/19085332
for inds in itertools.product(*map(range, mask_rechunked.blocks.shape)):

# Select a specific chunk (=column in 3D, =image in 2D)
column_data = data_zyx_rechunked.blocks[inds].compute()

# Write some debugging info
t0 = time.perf_counter()
with open("LOG_image_labeling", "a") as out:
out.write(f"Selected chunk: {inds}\n")
out.write("Now running cellpose\n")
out.write(
f"column_data: {type(column_data)}, {column_data.shape}\n"
)

# Perform segmentation
column_mask = apply_label_to_single_FOV_column(
column_data, model=model, do_3D=do_3D, anisotropy=anisotropy
)
if not do_3D:
column_mask = np.expand_dims(column_mask, axis=0)
num_labels_column = np.max(column_mask)
column_mask_recast = column_mask.astype(label_dtype)

# Apply re-labeling and update total number of labels
column_mask_recast[column_mask_recast > 0] += num_labels_tot
num_labels_tot += num_labels_column

# Check that total number of labels is under control
if num_labels_tot > np.iinfo(label_dtype).max - 1000:
raise Exception(
"ERROR in re-labeling:\n"
f"Reached {num_labels_tot} labels, "
f"but dtype={label_dtype}"
)

# Write some debugging info
t1 = time.perf_counter()
with open("LOG_image_labeling", "a") as out:
out.write(
f"End, dtype={column_mask_recast.dtype} "
f"shape={column_mask_recast.shape}\n"
)
out.write(f"Elapsed: {t1-t0:.4f} seconds\n")
out.write(
f"num_labels_column: {num_labels_column}\n"
f"num_labels_tot: {num_labels_tot}\n\n"
)

# Put data into the main array
# FIXME: is this out-of-memory?? I guess not!
start_z = inds[0] * nz
end_z = (inds[0] + 1) * nz
start_y = inds[1] * img_size_y
end_y = (inds[1] + 1) * img_size_y
start_x = inds[2] * img_size_x
end_x = (inds[2] + 1) * img_size_x
mask_rechunked[
start_z:end_z, start_y:end_y, start_x:end_x
] = column_mask_recast[:, :, :]
# Map labeling function onto all chunks (i.e., FOV columns)
mask_rechunked = data_zyx_rechunked.map_blocks(
apply_label_to_single_FOV_column,
chunks=data_zyx_rechunked.chunks,
meta=np.array((), dtype=label_dtype),
model=model,
do_3D=do_3D,
anisotropy=anisotropy,
label_dtype=label_dtype,
)

# Rechunk to get back to the original chunking (with separate Z planes)
mask = mask_rechunked.rechunk(data_zyx.chunks)

# Construct resolution pyramid
pyramid = create_pyramid_3D(
mask,
coarsening_z=1,
coarsening_xy=coarsening_xy,
num_levels=num_levels,
chunk_size_x=img_size_x,
chunk_size_y=img_size_y,
)

# Write zattrs for labels and for specific label
# FIXME deal with: (1) many channels, (2) overwriting
labels_group = zarr.group(f"{zarrurl}labels")
Expand All @@ -242,14 +211,60 @@ def image_labeling(
}
]

with dask.config.set(pool=ThreadPoolExecutor(num_threads)):
level0 = mask.to_zarr(
zarrurl,
component=f"labels/{label_name}/{0}",
dimension_separator="/",
return_stored=True,
)

# Construct resolution pyramid
pyramid = create_pyramid_3D(
level0,
coarsening_z=1,
coarsening_xy=coarsening_xy,
num_levels=num_levels,
chunk_size_x=img_size_x,
chunk_size_y=img_size_y,
aggregation_function=np.max,
)

# Write data into output zarr
for ind_level in range(num_levels):
for ind_level in range(1, num_levels):
pyramid[ind_level].astype(label_dtype).to_zarr(
zarrurl,
component=f"labels/{label_name}/{ind_level}",
dimension_separator="/",
)

"""
# APPLY RELABELING
# Load level-0 labels
newmask_rechunked = da.from_zarr(.f"{zarrurl}labels/{label_name}/{0}")
.rechunk(mask_rechunked.chunks)
# Sequential relabeling
# https://stackoverflow.com/a/72018364/19085332
num_labels_tot = 0
num_labels_column = 0
for inds in itertools.product(*map(range, newmask_rechunked.blocks.shape)):
# Select a specific chunk (=column in 3D, =image in 2D)
column_mask = newmask_rechunked.blocks[inds].compute()
num_labels_column = np.max(column_mask)
# Apply re-labeling and update total number of labels
column_mask[column_mask > 0] += num_labels_tot
num_labels_tot += num_labels_column
# Check that total number of labels is under control
if num_labels_tot > np.iinfo(label_dtype).max - 1000:
raise Exception(
"ERROR in re-labeling:\n"
f"Reached {num_labels_tot} labels, "
f"but dtype={label_dtype}"
)
"""


if __name__ == "__main__":
from argparse import ArgumentParser
Expand Down

0 comments on commit 03a146a

Please sign in to comment.