Skip to content

Commit

Permalink
Add write_pyramid function to address bug #97 in illumination_correction
Browse files Browse the repository at this point in the history
  • Loading branch information
tcompa committed Jul 12, 2022
1 parent 3556aca commit 9ba643e
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 31 deletions.
38 changes: 10 additions & 28 deletions fractal/tasks/illumination_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
import numpy as np
from skimage.io import imread

from fractal.tasks.lib_pyramid_creation import create_pyramid
from fractal.tasks.lib_to_zarr_custom import to_zarr_custom
from fractal.tasks.lib_pyramid_creation import write_pyramid


def correct(
Expand Down Expand Up @@ -50,11 +49,6 @@ def correct(
"""

chunk_location = block_info[None]["chunk-location"]

with open("LOG_illum", "a") as out:
out.write(f"[{chunk_location}] START illumination correction")

# Check shapes
if img.shape != (1, img_size_y, img_size_x):
raise Exception(
Expand Down Expand Up @@ -87,9 +81,6 @@ def correct(
)
img_corr[img_corr > np.iinfo(img.dtype).max] = np.iinfo(img.dtype).max

with open("LOG_illum", "a") as out:
out.write(f"[{chunk_location}] END illumination correction")

return img_corr.astype(img.dtype)


Expand Down Expand Up @@ -117,7 +108,7 @@ def illumination_correction(
:param path_dict_corr: path of JSON file with info on illumination matrices
:type path_dict_corr: str
:param coarsening_xy: coarsening factor in XY (optional, default 2)
:type coarsening_z: xy
:type coarsening_xy: xy
:param background: value for background subtraction (optional, default 110)
:type background: int
Expand All @@ -135,9 +126,6 @@ def illumination_correction(
f"overwrite={overwrite} and newzarrurl={newzarrurl}."
)

with open("LOG_illum", "w") as out:
out.write("init")

# Sanitize zarr paths
if not zarrurl.endswith("/"):
zarrurl += "/"
Expand All @@ -155,6 +143,8 @@ def illumination_correction(
with open(path_dict_corr, "r") as jsonfile:
dict_corr = json.load(jsonfile)
root_path_corr = dict_corr.pop("root_path_corr")
if not root_path_corr.endswith("/"):
root_path_corr += "/"

# Assemble dictionary of matrices and check their shapes
corrections = {}
Expand Down Expand Up @@ -207,7 +197,7 @@ def illumination_correction(
data_zyx = data_czyx[ind_ch]
illum_img = corrections[ch]

# Map "correct" function onto each block
# Map correct(..) function onto each block
data_zyx_new = data_zyx.map_blocks(
correct,
chunks=(1, img_size_y, img_size_x),
Expand All @@ -218,27 +208,19 @@ def illumination_correction(
img_size_x=img_size_x,
)
data_czyx_new.append(data_zyx_new)
accumulated_data = da.stack(data_czyx_new, axis=0)

# Construct resolution pyramid
pyramid = create_pyramid(
da.stack(data_czyx_new, axis=0),
coarsening_z=1,
write_pyramid(
accumulated_data,
newzarrurl=newzarrurl,
overwrite=overwrite,
coarsening_xy=coarsening_xy,
num_levels=num_levels,
chunk_size_x=img_size_x,
chunk_size_y=img_size_y,
num_channels=len(chl_list),
)

# Write data into output zarr
for ind_level in range(num_levels):
to_zarr_custom(
newzarrurl,
component=f"{ind_level}",
array=pyramid[ind_level],
overwrite=overwrite,
)


if __name__ == "__main__":
from argparse import ArgumentParser
Expand Down
89 changes: 89 additions & 0 deletions fractal/tasks/lib_pyramid_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import dask.array as da
import numpy as np

from fractal.tasks.lib_to_zarr_custom import to_zarr_custom


def create_pyramid(
data_czyx,
Expand Down Expand Up @@ -167,3 +169,90 @@ def create_pyramid_3D(
pyramid.append(data_zyx_final.astype(data_zyx.dtype))

return pyramid


def write_pyramid(
data,
overwrite=False,
newzarrurl=None,
coarsening_xy=2,
num_levels=2,
chunk_size_x=None,
chunk_size_y=None,
aggregation_function=None,
):

"""
Take a four-dimensional array and build a pyramid of coarsened levels
:param data_czyx: input data
:type data_czyx: dask array
:param coarsening_xy: coarsening factor along X and Y
:type coarsening_xy: int
:param num_levels: number of levels in the zarr pyramid
:type num_levels: int
:param chunk_size_x: chunk size along X
:type chunk_size_x: int
:param chunk_size_y: chunk size along Y
:type chunk_size_y: int
:param aggregation_function: FIXME
:type aggregation_function: FIXME
"""

# Check the number of axes and identify YX dimensions
ndims = len(data.shape)
if ndims not in [2, 3, 4]:
raise Exception(
"ERROR: data has shape {data.shape}, ndims not in [2,3,4]"
)
y_axis = ndims - 2
x_axis = ndims - 1

# Set rechunking options, if needed
if chunk_size_x is None or chunk_size_y is None:
apply_rechunking = False
else:
apply_rechunking = True
chunking = {y_axis: chunk_size_y, x_axis: chunk_size_x}

# Set aggregation_function
if aggregation_function is None:
aggregation_function = np.mean

# Create pyramid of XY-coarser levels

# Highest-resolution level
level0 = to_zarr_custom(
newzarrurl=newzarrurl, array=data, component="0", overwrite=overwrite
)
if apply_rechunking:
levels = [level0.rechunk(chunking)]
else:
levels = [level0]

# Lower-resolution levels
for ind_level in range(1, num_levels):
# Verify that coarsening is possible
if min(levels[-1].shape[-2:]) < coarsening_xy:
raise Exception(
f"ERROR: at {ind_level}-th level, "
f"coarsening_xy={coarsening_xy} "
f"but {ind_level-1}-th level has shape {levels[-1].shape}"
)
# Apply coarsening
newlevel = da.coarsen(
aggregation_function,
levels[ind_level - 1],
{y_axis: coarsening_xy, x_axis: coarsening_xy},
trim_excess=True,
).astype(data.dtype)
written_level = to_zarr_custom(
newzarrurl=newzarrurl,
array=newlevel,
component=f"{ind_level}",
overwrite=overwrite,
)
if apply_rechunking:
levels.append(written_level.rechunk(chunking))
else:
levels.append(written_level)
17 changes: 14 additions & 3 deletions fractal/tasks/lib_to_zarr_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
Institute for Biomedical Research and Pelkmans Lab from the University of
Zurich.
"""

import shutil

import dask.array as da


def to_zarr_custom(newzarrurl=None, array=None, component="", overwrite=False):

Expand All @@ -23,7 +24,7 @@ def to_zarr_custom(newzarrurl=None, array=None, component="", overwrite=False):
(https://github.com/dask/dask/issues/5942), where a dask array loaded with
from_zarr cannot be written with to_zarr(..., overwrite=True).
:param newzarrurl: ouput zarr file
:param newzarrurl: output zarr file
:type newzarrurl: str
:param array: dask array to be stored
:type array: dask array
Expand Down Expand Up @@ -52,10 +53,20 @@ def to_zarr_custom(newzarrurl=None, array=None, component="", overwrite=False):
newzarrurl,
component=component + tmp_suffix,
dimension_separator="/",
compute=True,
)
shutil.rmtree(newzarrurl + component)
shutil.move(
newzarrurl + component + tmp_suffix, newzarrurl + component
)
output = da.from_zarr(newzarrurl, component=component)
else:
array.to_zarr(newzarrurl, component=component, dimension_separator="/")
output = array.to_zarr(
newzarrurl,
component=component,
dimension_separator="/",
compute=True,
return_stored=True,
)

return output

0 comments on commit 9ba643e

Please sign in to comment.