From ff13273aaac7e01d284fdd12ad45e500b04189a7 Mon Sep 17 00:00:00 2001 From: Tommaso Comparin <3862206+tcompa@users.noreply.github.com> Date: Mon, 25 Jul 2022 06:56:19 +0000 Subject: [PATCH] Implement ROIs within image_labeling (ref #115) * Generalize image_labeling to work with ROIs; * Generalize image_labeling to work at arbitrary resolution level; * Add lib_zattrs_utils.py, also with rescale_datasets; * Rename extract_zyx_pixel_sizes_from_zattrs to extract_zyx_pixel_sizes; --- fractal/tasks/illumination_correction.py | 10 +-- fractal/tasks/image_labeling.py | 85 ++++++++---------- fractal/tasks/image_labeling_whole_well.py | 35 ++------ fractal/tasks/lib_regions_of_interest.py | 88 ++++++------------- fractal/tasks/lib_zattrs_utils.py | 81 +++++++++++++++++ fractal/tasks/replicate_zarr_structure_mip.py | 6 +- 6 files changed, 157 insertions(+), 148 deletions(-) create mode 100644 fractal/tasks/lib_zattrs_utils.py diff --git a/fractal/tasks/illumination_correction.py b/fractal/tasks/illumination_correction.py index 35d0ab893..03d92d5c1 100644 --- a/fractal/tasks/illumination_correction.py +++ b/fractal/tasks/illumination_correction.py @@ -22,12 +22,10 @@ from fractal.tasks.lib_pyramid_creation import write_pyramid from fractal.tasks.lib_regions_of_interest import convert_ROI_table_to_indices -from fractal.tasks.lib_regions_of_interest import ( - extract_zyx_pixel_sizes_from_zattrs, -) from fractal.tasks.lib_regions_of_interest import ( split_3D_indices_into_z_layers, ) +from fractal.tasks.lib_zattrs_utils import extract_zyx_pixel_sizes def correct( @@ -141,14 +139,16 @@ def illumination_correction( FOV_ROI_table = ad.read_zarr(f"{zarrurl}tables/FOV_ROI_table") # Read pixel sizes from zattrs file - pixel_sizes_zyx = extract_zyx_pixel_sizes_from_zattrs(zarrurl + ".zattrs") + full_res_pxl_sizes_zyx = extract_zyx_pixel_sizes( + zarrurl + ".zattrs", level=0 + ) # Create list of indices for 3D FOVs spanning the entire Z direction list_indices = 
convert_ROI_table_to_indices( FOV_ROI_table, level=0, coarsening_xy=coarsening_xy, - pixel_sizes_zyx=pixel_sizes_zyx, + full_res_pxl_sizes_zyx=full_res_pxl_sizes_zyx, ) # Extract image size from FOV-ROI indices diff --git a/fractal/tasks/image_labeling.py b/fractal/tasks/image_labeling.py index c44e41bdb..7544c3775 100644 --- a/fractal/tasks/image_labeling.py +++ b/fractal/tasks/image_labeling.py @@ -27,9 +27,8 @@ from fractal.tasks.lib_pyramid_creation import write_pyramid from fractal.tasks.lib_regions_of_interest import convert_ROI_table_to_indices -from fractal.tasks.lib_regions_of_interest import ( - extract_zyx_pixel_sizes_from_zattrs, -) +from fractal.tasks.lib_zattrs_utils import extract_zyx_pixel_sizes +from fractal.tasks.lib_zattrs_utils import rescale_datasets def segment_FOV( @@ -44,7 +43,8 @@ def segment_FOV( logfile="LOG_image_labeling", ): - chunk_location = block_info[None]["chunk-location"] + # chunk_location = block_info[None]["chunk-location"] + chunk_location = "dummy" # Write some debugging info with open(logfile, "a") as out: @@ -66,9 +66,6 @@ def segment_FOV( anisotropy=anisotropy, cellprob_threshold=cellprob_threshold, ) - """ - mask = np.zeros_like(column) - """ if not do_3D: mask = np.expand_dims(mask, axis=0) t1 = time.perf_counter() @@ -89,13 +86,13 @@ def segment_FOV( def image_labeling( zarrurl, coarsening_xy=2, - labeling_level=0, + labeling_level=1, labeling_channel=None, chl_list=None, num_threads=1, relabeling=True, anisotropy=None, - diameter=None, + diameter_level0=80.0, cellprob_threshold=None, model_type="nuclei", ): @@ -118,12 +115,6 @@ def image_labeling( raise Exception(f"ERROR: {labeling_channel} not in {chl_list}") ind_channel = chl_list.index(labeling_channel) - # Check that level=0 - if labeling_level > 0: - raise NotImplementedError( - "By now we can only segment the highest-resolution level" - ) - # Set labels dtype label_dtype = np.uint32 @@ -134,28 +125,36 @@ def image_labeling( FOV_ROI_table = 
ad.read_zarr(f"{zarrurl}tables/FOV_ROI_table") # Read pixel sizes from zattrs file - pixel_sizes_zyx = extract_zyx_pixel_sizes_from_zattrs(zarrurl + ".zattrs") + full_res_pxl_sizes_zyx = extract_zyx_pixel_sizes( + zarrurl + ".zattrs", level=0 + ) # Create list of indices for 3D FOVs spanning the entire Z direction list_indices = convert_ROI_table_to_indices( FOV_ROI_table, - level=0, + level=labeling_level, coarsening_xy=coarsening_xy, - pixel_sizes_zyx=pixel_sizes_zyx, + full_res_pxl_sizes_zyx=full_res_pxl_sizes_zyx, ) # Extract image size from FOV-ROI indices # Note: this works at level=0, where FOVs should all be of the exact same # size (in pixels) + list_indices_level0 = convert_ROI_table_to_indices( + FOV_ROI_table, + level=0, + full_res_pxl_sizes_zyx=full_res_pxl_sizes_zyx, + ) ref_img_size = None - for indices in list_indices: + for indices in list_indices_level0: img_size = (indices[3] - indices[2], indices[5] - indices[4]) if ref_img_size is None: ref_img_size = img_size else: if img_size != ref_img_size: raise Exception( - "ERROR: inconsistent image sizes in list_indices" + "ERROR: inconsistent image sizes in list_indices", + list_indices, ) img_size_y, img_size_x = img_size[:] @@ -164,7 +163,7 @@ def image_labeling( if do_3D: if anisotropy is None: # Read pixel sizes from zattrs file - pxl_zyx = extract_zyx_pixel_sizes_from_zattrs( + pxl_zyx = extract_zyx_pixel_sizes( zarrurl + ".zattrs", level=labeling_level ) pixel_size_z, pixel_size_y, pixel_size_x = pxl_zyx[:] @@ -177,7 +176,7 @@ def image_labeling( anisotropy = pixel_size_z / pixel_size_x else: raise NotImplementedError( - "TODO: check the integration of 2D labeling with ROIs" + "TODO: The integration of 2D labeling with ROIs is not ready yet" ) # Check model_type @@ -201,8 +200,6 @@ def image_labeling( # Extract num_levels num_levels = len(multiscales[0]["datasets"]) - print("num_levels", num_levels) - print() # Extract axes, and remove channel new_axes = [ax for ax in multiscales[0]["axes"] if 
ax["type"] != "channel"] @@ -216,6 +213,7 @@ def image_labeling( # Check that input array is made of images (in terms of shape/chunks) nz, ny, nx = data_zyx.shape + """ if (ny % img_size_y != 0) or (nx % img_size_x != 0): raise Exception( "Error in image_labeling, data_zyx.shape: {data_zyx.shape}" @@ -227,6 +225,7 @@ def image_labeling( raise Exception(f"Error in image_labeling, chunks_y: {chunks_y}") if len(set(chunks_x)) != 1 or chunks_x[0] != img_size_x: raise Exception(f"Error in image_labeling, chunks_x: {chunks_x}") + """ # Initialize cellpose use_gpu = core.use_gpu() @@ -246,12 +245,14 @@ def image_labeling( # Prepare delayed function delayed_segment_FOV = dask.delayed(segment_FOV) - # Prepare empty mask + # Prepare empty mask with correct chunks mask = da.empty( - data_zyx.shape, dtype=label_dtype, chunks=(1, img_size_y, img_size_x) + data_zyx.shape, + dtype=label_dtype, + chunks=(1, img_size_y, img_size_x), ) - # Map labeling function onto all FOVs + # Map labeling function onto all FOV ROIs for indices in list_indices: s_z, e_z, s_y, e_y, s_x, e_x = indices[:] shape = [e_z - s_z, e_y - s_y, e_x - s_x] @@ -263,7 +264,7 @@ def image_labeling( do_3D=do_3D, anisotropy=anisotropy, label_dtype=label_dtype, - diameter=diameter, + diameter=diameter_level0 / coarsening_xy**labeling_level, cellprob_threshold=cellprob_threshold, logfile=logfile, ) @@ -271,31 +272,19 @@ def image_labeling( FOV_mask, shape, label_dtype ) - # Map labeling function onto all chunks (i.e., FOV columns) - """ - mask = ( - data_zyx.rechunk((nz, img_size_y, img_size_x)) - .map_blocks( - segment_FOV, - meta=np.array((), dtype=label_dtype), - model=model, - do_3D=do_3D, - anisotropy=anisotropy, - label_dtype=label_dtype, - diameter=diameter, - cellprob_threshold=cellprob_threshold, - logfile=logfile, - ) - .rechunk((1, img_size_y, img_size_x)) - ) - """ - with open(logfile, "a") as out: out.write( f"After map_block, mask will have shape {mask.shape} " f"and chunks {mask.chunks}\n\n" ) + # 
Rescale datasets (only relevant for labeling_level>0) + new_datasets = rescale_datasets( + datasets=multiscales[0]["datasets"], + coarsening_xy=coarsening_xy, + reference_level=labeling_level, + ) + # Write zattrs for labels and for specific label # FIXME deal with: (1) many channels, (2) overwriting labels_group = zarr.group(f"{zarrurl}labels") @@ -307,7 +296,7 @@ def image_labeling( "name": label_name, "version": "0.4", "axes": new_axes, - "datasets": multiscales[0]["datasets"], + "datasets": new_datasets, } ] diff --git a/fractal/tasks/image_labeling_whole_well.py b/fractal/tasks/image_labeling_whole_well.py index 846a5de47..3f77fa804 100644 --- a/fractal/tasks/image_labeling_whole_well.py +++ b/fractal/tasks/image_labeling_whole_well.py @@ -21,6 +21,7 @@ from cellpose import models from fractal.tasks.lib_pyramid_creation import write_pyramid +from fractal.tasks.lib_zattrs_utils import rescale_datasets def image_labeling_whole_well( @@ -155,34 +156,12 @@ def image_labeling_whole_well( f"da.from_array(upscaled_mask) [with rechunking]: {mask_da}\n\n" ) - # Construct rescaled datasets - datasets = multiscales[0]["datasets"] - new_datasets = [] - for ds in datasets: - new_ds = {} - - # Copy all keys that are not coordinateTransformations (e.g. 
path) - for key in ds.keys(): - if key != "coordinateTransformations": - new_ds[key] = ds[key] - - # Update coordinateTransformations - old_transformations = ds["coordinateTransformations"] - new_transformations = [] - for t in old_transformations: - if t["type"] == "scale": - new_t = {"type": "scale"} - new_t["scale"] = [ - t["scale"][0], - t["scale"][1] * coarsening_xy**labeling_level, - t["scale"][2] * coarsening_xy**labeling_level, - ] - new_transformations.append(new_t) - else: - new_transformations.append(t) - new_ds["coordinateTransformations"] = new_transformations - - new_datasets.append(new_ds) + # Rescale datasets (only relevant for labeling_level>0) + new_datasets = rescale_datasets( + datasets=multiscales[0]["datasets"], + coarsening_xy=coarsening_xy, + reference_level=labeling_level, + ) # Write zattrs for labels and for specific label # FIXME deal with: (1) many channels, (2) overwriting diff --git a/fractal/tasks/lib_regions_of_interest.py b/fractal/tasks/lib_regions_of_interest.py index 8eb62ec6b..5d491574d 100644 --- a/fractal/tasks/lib_regions_of_interest.py +++ b/fractal/tasks/lib_regions_of_interest.py @@ -1,4 +1,3 @@ -import json import math from typing import List from typing import Tuple @@ -79,13 +78,16 @@ def convert_ROI_table_to_indices( ROI: ad.AnnData, level: int = 0, coarsening_xy: int = 2, - pixel_sizes_zyx: Union[List[float], Tuple[float]] = None, + full_res_pxl_sizes_zyx: Union[List[float], Tuple[float]] = None, ) -> List[List[int]]: - list_indices = [] - - pixel_size_z, pixel_size_y, pixel_size_x = pixel_sizes_zyx + # Set pyramid-level pixel sizes + pxl_size_z, pxl_size_y, pxl_size_x = full_res_pxl_sizes_zyx + prefactor = coarsening_xy**level + pxl_size_x *= prefactor + pxl_size_y *= prefactor + list_indices = [] for FOV in sorted(ROI.obs_names): # Extract data from anndata table @@ -96,18 +98,13 @@ def convert_ROI_table_to_indices( len_y_micrometer = ROI[FOV, "len_y_micrometer"].X[0, 0] len_z_micrometer = ROI[FOV, 
"len_z_micrometer"].X[0, 0] - # Set pyramid-level pixel sizes - prefactor = coarsening_xy**level - pixel_size_x *= prefactor - pixel_size_y *= prefactor - # Identify indices along the three dimensions - start_x = x_micrometer / pixel_size_x - end_x = (x_micrometer + len_x_micrometer) / pixel_size_x - start_y = y_micrometer / pixel_size_y - end_y = (y_micrometer + len_y_micrometer) / pixel_size_y - start_z = z_micrometer / pixel_size_z - end_z = (z_micrometer + len_z_micrometer) / pixel_size_z + start_x = x_micrometer / pxl_size_x + end_x = (x_micrometer + len_x_micrometer) / pxl_size_x + start_y = y_micrometer / pxl_size_y + end_y = (y_micrometer + len_y_micrometer) / pxl_size_y + start_z = z_micrometer / pxl_size_z + end_z = (z_micrometer + len_z_micrometer) / pxl_size_z indices = [start_z, end_z, start_y, end_y, start_x, end_x] # Round indices to lower integer @@ -115,7 +112,7 @@ def convert_ROI_table_to_indices( indices = list(map(math.floor, indices)) # Append ROI indices to to list - list_indices.append(indices) + list_indices.append(indices[:]) return list_indices @@ -145,7 +142,7 @@ def _inspect_ROI_table( path: str = None, level: int = 0, coarsening_xy: int = 2, - pixel_sizes_zyx=[1.0, 0.1625, 0.1625], + full_res_pxl_sizes_zyx=[1.0, 0.1625, 0.1625], ) -> None: adata = ad.read_zarr(path) @@ -158,7 +155,7 @@ def _inspect_ROI_table( adata, level=level, coarsening_xy=coarsening_xy, - pixel_sizes_zyx=pixel_sizes_zyx, + full_res_pxl_sizes_zyx=full_res_pxl_sizes_zyx, ) list_indices = split_3D_indices_into_z_layers(list_indices) @@ -215,9 +212,12 @@ def temporary_test(): print() print("Indices 3D") - pixel_sizes_zyx = [pixel_size_z, pixel_size_y, pixel_size_x] + full_res_pxl_sizes_zyx = [pixel_size_z, pixel_size_y, pixel_size_x] list_indices = convert_ROI_table_to_indices( - adata, level=0, coarsening_xy=2, pixel_sizes_zyx=pixel_sizes_zyx + adata, + level=0, + coarsening_xy=2, + full_res_pxl_sizes_zyx=full_res_pxl_sizes_zyx, ) for indices in list_indices: 
         print(indices)
 
@@ -232,54 +232,16 @@ def temporary_test():
     print("Indices 2D")
     adata = convert_FOV_ROIs_3D_to_2D(adata, pixel_size_z)
     list_indices = convert_ROI_table_to_indices(
-        adata, level=0, coarsening_xy=2, pixel_sizes_zyx=pixel_sizes_zyx
+        adata,
+        level=0,
+        coarsening_xy=2,
+        full_res_pxl_sizes_zyx=full_res_pxl_sizes_zyx,
     )
     for indices in list_indices:
         print(indices)
     print()
 
 
-def extract_zyx_pixel_sizes_from_zattrs(zattrs_path: str, level: int = 0):
-    with open(zattrs_path, "r") as jsonfile:
-        zattrs = json.load(jsonfile)
-
-    try:
-
-        # Identify multiscales
-        multiscales = zattrs["multiscales"]
-
-        # Check that there is a single multiscale
-        if len(multiscales) > 1:
-            raise Exception(f"ERROR: There are {len(multiscales)} multiscales")
-
-        # Check that there are no datasets-global transformations
-        if "coordinateTransformations" in multiscales[0].keys():
-            raise Exception(
-                "ERROR: coordinateTransformations at the multiscales "
-                "level are not currently supported"
-            )
-
-        # Identify all datasets (AKA pyramid levels)
-        datasets = multiscales[0]["datasets"]
-
-        # Select highest-resolution dataset
-        transformations = datasets[level]["coordinateTransformations"]
-        for t in transformations:
-            if t["type"] == "scale":
-                return t["scale"]
-        raise Exception(
-            "ERROR:"
-            f" no scale transformation found for level {level}"
-            f" in {zattrs_path}"
-        )
-
-    except KeyError as e:
-        raise KeyError(
-            "extract_zyx_pixel_sizes_from_zattrs failed, for {zattrs_path}\n",
-            e,
-        )
-
-
 if __name__ == "__main__":
     # import sys
     # args = sys.argv[1:]
diff --git a/fractal/tasks/lib_zattrs_utils.py b/fractal/tasks/lib_zattrs_utils.py
new file mode 100644
index 000000000..8e6283ed7
--- /dev/null
+++ b/fractal/tasks/lib_zattrs_utils.py
@@ -0,0 +1,83 @@
+import json
+from typing import Dict
+from typing import List
+
+
+def extract_zyx_pixel_sizes(zattrs_path: str, level: int = 0):
+
+    with open(zattrs_path, "r") as jsonfile:
+        zattrs = json.load(jsonfile)
+
+    try:
+
+        # Identify multiscales
+        multiscales = zattrs["multiscales"]
+
+        # Check that there is a single multiscale
+        if len(multiscales) > 1:
+            raise Exception(f"ERROR: There are {len(multiscales)} multiscales")
+
+        # Check that there are no datasets-global transformations
+        if "coordinateTransformations" in multiscales[0].keys():
+            raise Exception(
+                "ERROR: coordinateTransformations at the multiscales "
+                "level are not currently supported"
+            )
+
+        # Identify all datasets (AKA pyramid levels)
+        datasets = multiscales[0]["datasets"]
+
+        # Select highest-resolution dataset
+        transformations = datasets[level]["coordinateTransformations"]
+        for t in transformations:
+            if t["type"] == "scale":
+                return t["scale"]
+        raise Exception(
+            "ERROR:"
+            f" no scale transformation found for level {level}"
+            f" in {zattrs_path}"
+        )
+
+    except KeyError as e:
+        raise KeyError(
+            f"extract_zyx_pixel_sizes failed, for {zattrs_path}\n",
+            e,
+        )
+
+
+def rescale_datasets(
+    datasets: List[Dict] = None,
+    coarsening_xy: int = None,
+    reference_level: int = None,
+) -> List[Dict]:
+    if datasets is None or coarsening_xy is None or reference_level is None:
+        raise TypeError("Missing argument in rescale_datasets")
+
+    # Construct rescaled datasets
+    new_datasets = []
+    for ds in datasets:
+        new_ds = {}
+
+        # Copy all keys that are not coordinateTransformations (e.g. path)
+        for key in ds.keys():
+            if key != "coordinateTransformations":
+                new_ds[key] = ds[key]
+
+        # Update coordinateTransformations
+        old_transformations = ds["coordinateTransformations"]
+        new_transformations = []
+        for t in old_transformations:
+            if t["type"] == "scale":
+                new_t = {"type": "scale"}
+                new_t["scale"] = [
+                    t["scale"][0],
+                    t["scale"][1] * coarsening_xy**reference_level,
+                    t["scale"][2] * coarsening_xy**reference_level,
+                ]
+                new_transformations.append(new_t)
+            else:
+                new_transformations.append(t)
+        new_ds["coordinateTransformations"] = new_transformations
+        new_datasets.append(new_ds)
+
+    return new_datasets
diff --git a/fractal/tasks/replicate_zarr_structure_mip.py b/fractal/tasks/replicate_zarr_structure_mip.py
index 0a8ab2df1..b5aec9bef 100644
--- a/fractal/tasks/replicate_zarr_structure_mip.py
+++ b/fractal/tasks/replicate_zarr_structure_mip.py
@@ -19,9 +19,7 @@
 from anndata.experimental import write_elem
 
 from fractal.tasks.lib_regions_of_interest import convert_FOV_ROIs_3D_to_2D
-from fractal.tasks.lib_regions_of_interest import (
-    extract_zyx_pixel_sizes_from_zattrs,
-)
+from fractal.tasks.lib_zattrs_utils import extract_zyx_pixel_sizes
 
 
 def replicate_zarr_structure_mip(zarrurl):
@@ -128,7 +126,7 @@ def replicate_zarr_structure_mip(zarrurl):
     )
 
     # Read pixel sizes from zattrs file
-    pixel_sizes_zyx = extract_zyx_pixel_sizes_from_zattrs(path_FOV_zattrs)
+    pixel_sizes_zyx = extract_zyx_pixel_sizes(path_FOV_zattrs, level=0)
     pixel_size_z = pixel_sizes_zyx[0]
 
     # Convert 3D FOVs to 2D