Skip to content

Commit

Permalink
Implement ROIs within image_labeling (ref #115)
Browse files Browse the repository at this point in the history
* Generalize image_labeling to work with ROIs;
* Generalize image_labeling to work at arbitrary resolution level;
* Add lib_zattrs_utils.py, also with rescale_datasets;
* Rename extract_zyx_pixel_sizes_from_zattrs to extract_zyx_pixel_sizes;
  • Loading branch information
tcompa committed Jul 25, 2022
1 parent 2194f15 commit ff13273
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 148 deletions.
10 changes: 5 additions & 5 deletions fractal/tasks/illumination_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,10 @@

from fractal.tasks.lib_pyramid_creation import write_pyramid
from fractal.tasks.lib_regions_of_interest import convert_ROI_table_to_indices
from fractal.tasks.lib_regions_of_interest import (
extract_zyx_pixel_sizes_from_zattrs,
)
from fractal.tasks.lib_regions_of_interest import (
split_3D_indices_into_z_layers,
)
from fractal.tasks.lib_zattrs_utils import extract_zyx_pixel_sizes


def correct(
Expand Down Expand Up @@ -141,14 +139,16 @@ def illumination_correction(
FOV_ROI_table = ad.read_zarr(f"{zarrurl}tables/FOV_ROI_table")

# Read pixel sizes from zattrs file
pixel_sizes_zyx = extract_zyx_pixel_sizes_from_zattrs(zarrurl + ".zattrs")
full_res_pxl_sizes_zyx = extract_zyx_pixel_sizes(
zarrurl + ".zattrs", level=0
)

# Create list of indices for 3D FOVs spanning the entire Z direction
list_indices = convert_ROI_table_to_indices(
FOV_ROI_table,
level=0,
coarsening_xy=coarsening_xy,
pixel_sizes_zyx=pixel_sizes_zyx,
full_res_pxl_sizes_zyx=full_res_pxl_sizes_zyx,
)

# Extract image size from FOV-ROI indices
Expand Down
85 changes: 37 additions & 48 deletions fractal/tasks/image_labeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,8 @@

from fractal.tasks.lib_pyramid_creation import write_pyramid
from fractal.tasks.lib_regions_of_interest import convert_ROI_table_to_indices
from fractal.tasks.lib_regions_of_interest import (
extract_zyx_pixel_sizes_from_zattrs,
)
from fractal.tasks.lib_zattrs_utils import extract_zyx_pixel_sizes
from fractal.tasks.lib_zattrs_utils import rescale_datasets


def segment_FOV(
Expand All @@ -44,7 +43,8 @@ def segment_FOV(
logfile="LOG_image_labeling",
):

chunk_location = block_info[None]["chunk-location"]
# chunk_location = block_info[None]["chunk-location"]
chunk_location = "dummy"

# Write some debugging info
with open(logfile, "a") as out:
Expand All @@ -66,9 +66,6 @@ def segment_FOV(
anisotropy=anisotropy,
cellprob_threshold=cellprob_threshold,
)
"""
mask = np.zeros_like(column)
"""
if not do_3D:
mask = np.expand_dims(mask, axis=0)
t1 = time.perf_counter()
Expand All @@ -89,13 +86,13 @@ def segment_FOV(
def image_labeling(
zarrurl,
coarsening_xy=2,
labeling_level=0,
labeling_level=1,
labeling_channel=None,
chl_list=None,
num_threads=1,
relabeling=True,
anisotropy=None,
diameter=None,
diameter_level0=80.0,
cellprob_threshold=None,
model_type="nuclei",
):
Expand All @@ -118,12 +115,6 @@ def image_labeling(
raise Exception(f"ERROR: {labeling_channel} not in {chl_list}")
ind_channel = chl_list.index(labeling_channel)

# Check that level=0
if labeling_level > 0:
raise NotImplementedError(
"By now we can only segment the highest-resolution level"
)

# Set labels dtype
label_dtype = np.uint32

Expand All @@ -134,28 +125,36 @@ def image_labeling(
FOV_ROI_table = ad.read_zarr(f"{zarrurl}tables/FOV_ROI_table")

# Read pixel sizes from zattrs file
pixel_sizes_zyx = extract_zyx_pixel_sizes_from_zattrs(zarrurl + ".zattrs")
full_res_pxl_sizes_zyx = extract_zyx_pixel_sizes(
zarrurl + ".zattrs", level=0
)

# Create list of indices for 3D FOVs spanning the entire Z direction
list_indices = convert_ROI_table_to_indices(
FOV_ROI_table,
level=0,
level=labeling_level,
coarsening_xy=coarsening_xy,
pixel_sizes_zyx=pixel_sizes_zyx,
full_res_pxl_sizes_zyx=full_res_pxl_sizes_zyx,
)

# Extract image size from FOV-ROI indices
# Note: this works at level=0, where FOVs should all be of the exact same
# size (in pixels)
list_indices_level0 = convert_ROI_table_to_indices(
FOV_ROI_table,
level=0,
full_res_pxl_sizes_zyx=full_res_pxl_sizes_zyx,
)
ref_img_size = None
for indices in list_indices:
for indices in list_indices_level0:
img_size = (indices[3] - indices[2], indices[5] - indices[4])
if ref_img_size is None:
ref_img_size = img_size
else:
if img_size != ref_img_size:
raise Exception(
"ERROR: inconsistent image sizes in list_indices"
"ERROR: inconsistent image sizes in list_indices",
list_indices,
)
img_size_y, img_size_x = img_size[:]

Expand All @@ -164,7 +163,7 @@ def image_labeling(
if do_3D:
if anisotropy is None:
# Read pixel sizes from zattrs file
pxl_zyx = extract_zyx_pixel_sizes_from_zattrs(
pxl_zyx = extract_zyx_pixel_sizes(
zarrurl + ".zattrs", level=labeling_level
)
pixel_size_z, pixel_size_y, pixel_size_x = pxl_zyx[:]
Expand All @@ -177,7 +176,7 @@ def image_labeling(
anisotropy = pixel_size_z / pixel_size_x
else:
raise NotImplementedError(
"TODO: check the integration of 2D labeling with ROIs"
"TODO: The integration of 2D labeling with ROIs is not ready yet"
)

# Check model_type
Expand All @@ -201,8 +200,6 @@ def image_labeling(

# Extract num_levels
num_levels = len(multiscales[0]["datasets"])
print("num_levels", num_levels)
print()

# Extract axes, and remove channel
new_axes = [ax for ax in multiscales[0]["axes"] if ax["type"] != "channel"]
Expand All @@ -216,6 +213,7 @@ def image_labeling(

# Check that input array is made of images (in terms of shape/chunks)
nz, ny, nx = data_zyx.shape
"""
if (ny % img_size_y != 0) or (nx % img_size_x != 0):
raise Exception(
"Error in image_labeling, data_zyx.shape: {data_zyx.shape}"
Expand All @@ -227,6 +225,7 @@ def image_labeling(
raise Exception(f"Error in image_labeling, chunks_y: {chunks_y}")
if len(set(chunks_x)) != 1 or chunks_x[0] != img_size_x:
raise Exception(f"Error in image_labeling, chunks_x: {chunks_x}")
"""

# Initialize cellpose
use_gpu = core.use_gpu()
Expand All @@ -246,12 +245,14 @@ def image_labeling(
# Prepare delayed function
delayed_segment_FOV = dask.delayed(segment_FOV)

# Prepare empty mask
# Prepare empty mask with correct chunks
mask = da.empty(
data_zyx.shape, dtype=label_dtype, chunks=(1, img_size_y, img_size_x)
data_zyx.shape,
dtype=label_dtype,
chunks=(1, img_size_y, img_size_x),
)

# Map labeling function onto all FOVs
# Map labeling function onto all FOV ROIs
for indices in list_indices:
s_z, e_z, s_y, e_y, s_x, e_x = indices[:]
shape = [e_z - s_z, e_y - s_y, e_x - s_x]
Expand All @@ -263,39 +264,27 @@ def image_labeling(
do_3D=do_3D,
anisotropy=anisotropy,
label_dtype=label_dtype,
diameter=diameter,
diameter=diameter_level0 / coarsening_xy**labeling_level,
cellprob_threshold=cellprob_threshold,
logfile=logfile,
)
mask[s_z:e_z, s_y:e_y, s_x:e_x] = da.from_delayed(
FOV_mask, shape, label_dtype
)

# Map labeling function onto all chunks (i.e., FOV columns)
"""
mask = (
data_zyx.rechunk((nz, img_size_y, img_size_x))
.map_blocks(
segment_FOV,
meta=np.array((), dtype=label_dtype),
model=model,
do_3D=do_3D,
anisotropy=anisotropy,
label_dtype=label_dtype,
diameter=diameter,
cellprob_threshold=cellprob_threshold,
logfile=logfile,
)
.rechunk((1, img_size_y, img_size_x))
)
"""

with open(logfile, "a") as out:
out.write(
f"After map_block, mask will have shape {mask.shape} "
f"and chunks {mask.chunks}\n\n"
)

# Rescale datasets (only relevant for labeling_level>0)
new_datasets = rescale_datasets(
datasets=multiscales[0]["datasets"],
coarsening_xy=coarsening_xy,
reference_level=labeling_level,
)

# Write zattrs for labels and for specific label
# FIXME deal with: (1) many channels, (2) overwriting
labels_group = zarr.group(f"{zarrurl}labels")
Expand All @@ -307,7 +296,7 @@ def image_labeling(
"name": label_name,
"version": "0.4",
"axes": new_axes,
"datasets": multiscales[0]["datasets"],
"datasets": new_datasets,
}
]

Expand Down
35 changes: 7 additions & 28 deletions fractal/tasks/image_labeling_whole_well.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from cellpose import models

from fractal.tasks.lib_pyramid_creation import write_pyramid
from fractal.tasks.lib_zattrs_utils import rescale_datasets


def image_labeling_whole_well(
Expand Down Expand Up @@ -155,34 +156,12 @@ def image_labeling_whole_well(
f"da.from_array(upscaled_mask) [with rechunking]: {mask_da}\n\n"
)

# Construct rescaled datasets
datasets = multiscales[0]["datasets"]
new_datasets = []
for ds in datasets:
new_ds = {}

# Copy all keys that are not coordinateTransformations (e.g. path)
for key in ds.keys():
if key != "coordinateTransformations":
new_ds[key] = ds[key]

# Update coordinateTransformations
old_transformations = ds["coordinateTransformations"]
new_transformations = []
for t in old_transformations:
if t["type"] == "scale":
new_t = {"type": "scale"}
new_t["scale"] = [
t["scale"][0],
t["scale"][1] * coarsening_xy**labeling_level,
t["scale"][2] * coarsening_xy**labeling_level,
]
new_transformations.append(new_t)
else:
new_transformations.append(t)
new_ds["coordinateTransformations"] = new_transformations

new_datasets.append(new_ds)
# Rescale datasets (only relevant for labeling_level>0)
new_datasets = rescale_datasets(
datasets=multiscales[0]["datasets"],
coarsening_xy=coarsening_xy,
reference_level=labeling_level,
)

# Write zattrs for labels and for specific label
# FIXME deal with: (1) many channels, (2) overwriting
Expand Down
Loading

0 comments on commit ff13273

Please sign in to comment.