From e7a7efd36bf0a1c7ba23ada0cc4e3e6dbb677dc5 Mon Sep 17 00:00:00 2001 From: Hanjin Liu Date: Sat, 31 Aug 2024 18:03:33 +0900 Subject: [PATCH] don't use __getattr__ --- impy/array_api.py | 5 +- impy/arrays/_utils/_deconv.py | 17 +- impy/arrays/_utils/_transform.py | 4 +- impy/arrays/bases/metaarray.py | 25 +- impy/arrays/labeledarray.py | 424 +++++++++++++++---------------- 5 files changed, 234 insertions(+), 241 deletions(-) diff --git a/impy/array_api.py b/impy/array_api.py index 9cacd49..9e76044 100644 --- a/impy/array_api.py +++ b/impy/array_api.py @@ -21,9 +21,6 @@ def __init__(self): self._reset_namespace() self.setNumpy() - def __getattr__(self, key: str): - return getattr(self._module, key) - def _reset_namespace(self): self._signal = None self._fft = None @@ -111,6 +108,7 @@ def setNumpy(self) -> None: self.argmin = np.argmin self.pad = np.pad self.isnan = np.isnan + self.eye = np.eye self.state = "numpy" from ._const import Const @@ -181,6 +179,7 @@ def cp_asnumpy(arr, dtype=None): self.argmin = cp.argmin self.pad = cp.pad self.isnan = cp.isnan + self.eye = cp.eye self.state = "cupy" from ._const import Const diff --git a/impy/arrays/_utils/_deconv.py b/impy/arrays/_utils/_deconv.py index 00197d8..7b40340 100644 --- a/impy/arrays/_utils/_deconv.py +++ b/impy/arrays/_utils/_deconv.py @@ -4,15 +4,6 @@ import numpy as np from impy.array_api import xp -try: - gradient = xp.gradient -except AttributeError: - # CUDA <= ver.8 does not have gradient - import numpy - def gradient(a, axis=None): - out = numpy.gradient(a.get(), axis=axis) - return xp.asarray(out) - __all__ = ["wiener", "richardson_lucy", "richardson_lucy_tv", "check_psf"] def wiener(obs, psf_ft, psf_ft_conj, lmd): @@ -30,7 +21,7 @@ def richardson_lucy(obs, psf_ft, psf_ft_conj, niter, eps): obs = xp.asarray(obs) fft = xp.fft.rfftn ifft = partial(xp.fft.irfftn, s=obs.shape) - conv = factor = xp.empty(obs.shape, dtype=xp.float32) # placeholder + conv = factor = xp.empty(obs.shape, dtype=np.float32) # placeholder estimated = xp.real(ifft(fft(obs) * psf_ft)) # initialization for _ in range(niter): @@ -45,16 +36,16 @@ def richardson_lucy_tv(obs, psf_ft, psf_ft_conj, max_iter, lmd, tol, eps): fft = xp.fft.rfftn ifft = partial(xp.fft.irfftn, s=obs.shape) est_old = ifft(fft(obs) * psf_ft).real - est_new = xp.empty(obs.shape, dtype=xp.float32) + est_new = xp.empty(obs.shape, dtype=np.float32) conv = factor = norm = gg = xp.empty(obs.shape, dtype=np.float32) # placeholder for _ in range(max_iter): conv[:] = ifft(fft(est_old) * psf_ft).real factor[:] = ifft(fft(_safe_div(obs, conv, eps=eps)) * psf_ft_conj).real est_new[:] = est_old * factor - grad = gradient(est_old) + grad = xp.gradient(est_old) norm[:] = xp.sqrt(sum(g**2 for g in grad)) - gg[:] = sum(gradient(_safe_div(grad[i], norm, eps=1e-8), axis=i) + gg[:] = sum(xp.gradient(_safe_div(grad[i], norm, eps=1e-8), axis=i) for i in range(obs.ndim)) est_new /= (1 - lmd * gg) gain = xp.sum(xp.abs(est_new - est_old))/xp.sum(xp.abs(est_old)) diff --git a/impy/arrays/_utils/_transform.py b/impy/arrays/_utils/_transform.py index 11f80b5..de88289 100644 --- a/impy/arrays/_utils/_transform.py +++ b/impy/arrays/_utils/_transform.py @@ -343,9 +343,9 @@ def get_fourier_filter(size: int, filter_name: str): cosine_filter = xp.fft.fftshift(xp.sin(freq)) fourier_filter *= cosine_filter elif filter_name == "hamming": - fourier_filter *= xp.fft.fftshift(xp.hamming(size)) + fourier_filter *= xp.fft.fftshift(xp._module.hamming(size)) elif filter_name == "hann": - fourier_filter *= xp.fft.fftshift(xp.hanning(size)) + fourier_filter *= xp.fft.fftshift(xp._module.hanning(size)) elif filter_name is None: fourier_filter[:] = 1 else: diff --git a/impy/arrays/bases/metaarray.py b/impy/arrays/bases/metaarray.py index 2317ace..80ad9e5 100644 --- a/impy/arrays/bases/metaarray.py +++ b/impy/arrays/bases/metaarray.py @@ -329,8 +329,8 @@ def _apply_dask( new_axis: Iterable[int] = None, dtype = np.float32, out_chunks: tuple[int, ...] = None, - args: tuple[Any] = None, - kwargs: dict[str, Any] = None + args: tuple[Any] | None = None, + kwargs: dict[str, Any] | None = None ) -> Self: """ Convert array into dask array and run a batch process in parallel. In many cases batch process @@ -403,15 +403,7 @@ def _apply_dask( else: _args.append(arg) - def _func(*args, **kwargs): - args = list(args) - for i in img_idx: - if args[i].ndim < len(slice_in): - continue - args[i] = args[i][slice_in] - out = func(*args, **kwargs) - return xp.asnumpy(out[slice_out]) - + _func = _make_func(func, img_idx, slice_in, slice_out, xp.asnumpy) out = da.map_blocks( _func, *_args, @@ -935,3 +927,14 @@ def __eq__(self, other): return False _NOTME = NotMe() + +def _make_func(func, img_idx, slice_in, slice_out, as_numpy): + def _func(*args, **kwargs): + args = list(args) + for i in img_idx: + if args[i].ndim < len(slice_in): + continue + args[i] = args[i][slice_in] + out = func(*args, **kwargs) + return as_numpy(out[slice_out]) + return _func diff --git a/impy/arrays/labeledarray.py b/impy/arrays/labeledarray.py index 7aaff54..bb722fa 100644 --- a/impy/arrays/labeledarray.py +++ b/impy/arrays/labeledarray.py @@ -34,16 +34,16 @@ class SupportAxesSlicing(Protocol): @property def axes(self) -> Axes: """Axes object bound to the object.""" - + def _dimension_matches(self, array: MetaArray) -> bool: """Check if self matches array's shape and axes.""" - + def copy(self) -> Self: """Shallow copy of the object.""" - + def __getitem__(self, key) -> SupportAxesSlicing | None: """Slice object.""" - + def _slice_by(self, key) -> SupportAxesSlicing | None: """Slice object.""" @@ -60,7 +60,7 @@ def __init__(self, data: dict[str, SupportAxesSlicing], parent: MetaArray): import weakref self._data = data self._parent_ref = weakref.ref(parent) - + @property def parent(self): """Return the parent MetaArray object""" @@ -76,14 +76,14 @@ def construct_by_copying(self, parent: MetaArray | None) -> Self: for k, value in self.items(): if parent.axes.contains(value.axes): data[k] = value.copy() - + return self.__class__(data, parent) - + def construct_by_slicing(self, key, next_parent: MetaArray | None) -> Self: parent = self.parent if next_parent is None: next_parent = parent - + data: dict[str, SupportAxesSlicing] = {} for k, value in self.items(): if value is not None: @@ -94,10 +94,10 @@ def construct_by_slicing(self, key, next_parent: MetaArray | None) -> Self: _keys = (key,) label_sl = tuple( _fmt_slice(_keys[i], parent.shape[i]) - for i, a in enumerate(parent.axes) + for i, a in enumerate(parent.axes) if (a in value.axes and i < len(_keys)) ) - + if len(label_sl) == 0 or len(label_sl) > len(value.axes): label_sl = () else: @@ -112,7 +112,7 @@ def construct_by_slicing(self, key, next_parent: MetaArray | None) -> Self: def __getitem__(self, key: str) -> SupportAxesSlicing: return self._data[key] - + def __setitem__(self, key: str, value: SupportAxesSlicing) -> None: if value is None: self.pop(key, None) @@ -125,29 +125,29 @@ def __setitem__(self, key: str, value: SupportAxesSlicing) -> None: f"parent array ({parent.shape_info})." ) self._data[key] = value - + def __delitem__(self, key: str) -> None: del self._data[key] - + def __len__(self) -> int: return len(self._data) def __iter__(self): return iter(self._data) - + class LabeledArray(MetaArray): _name: str _source: Path | None _metadata: dict[str, Any] _covariates: ArrayCovariates - + def __new__( - cls: type[LabeledArray], + cls: type[LabeledArray], obj, name: str | None = None, axes: AxesLike | None = None, - source: str | Path | None = None, + source: str | Path | None = None, metadata: dict[str, Any] | None = None, dtype: DTypeLike = None, ) -> Self: @@ -155,9 +155,9 @@ def __new__( cls, obj, name, axes, source, metadata, dtype ) self._covariates = ArrayCovariates({}, self) - + return self - + @MetaArray.axes.setter def axes(self, value: AxesLike): if not hasattr(self, "_axes"): @@ -176,7 +176,7 @@ def axes(self, value: AxesLike): def range(self) -> tuple[float, float]: """Return min/max range of the array.""" return self.min(), self.max() - + @property def covariates(self) -> ArrayCovariates: """Get all the covariates.""" @@ -186,7 +186,7 @@ def covariates(self) -> ArrayCovariates: def labels(self) -> Label | None: """The label of the image.""" return self.covariates.get("labels") - + @labels.setter def labels(self, value: np.ndarray | None): if value is None: @@ -195,7 +195,7 @@ def labels(self, value: np.ndarray | None): if value is self: raise ValueError("Setting labels recursively is not allowed.") - + if not isinstance(value, Label): # convert input arr = np.asarray(value) @@ -207,23 +207,23 @@ def labels(self, value: np.ndarray | None): ) axes = self.axes[-arr.ndim:] value = Label(arr, axes=axes).optimize() - + if not value._dimension_matches(self): raise ValueError( f"Shape of input label ({value.shape_info}) does not match the " f"parent array ({self.shape_info})." ) self.covariates["labels"] = value - + @labels.deleter def labels(self): self.covariates.pop("labels", None) - + @property def rois(self) -> RoiList: """ROIs of the image.""" return self.covariates.get("rois") - + @rois.setter def rois(self, val) -> None: from ..roi import RoiList, POS @@ -232,23 +232,23 @@ def rois(self, val) -> None: import copy val = copy.copy(val) val.axes = self.axes[0] + val.axes[1:] - + self.covariates["rois"] = val else: self.covariates["rois"] = RoiList(self.axes, val) - + @rois.deleter def rois(self) -> None: self.covariates.pop("rois", None) - + def set_scale(self, other=None, unit: str | None = None, **kwargs) -> Self: out = super().set_scale(other, unit=unit, **kwargs) for cov in self.covariates.values(): if hasattr(cov, "set_scale"): cov.set_scale(other, **kwargs) return out - - + + def _repr_dict_(self): if self.labels is not None: labels_shape_info = self.labels.shape_info @@ -264,20 +264,20 @@ def _repr_dict_(self): "source": self.source, "scale": self.scale, } - - + + def imsave( self, - save_path: str | Path, + save_path: str | Path, *, dtype: DTypeLike = None, overwrite: bool = True, ) -> None: """ - Save image at the same directory as the original image by default. - - For tif file format, if the image contains wrong axes for ImageJ (= except for tzcyx), - then it will converted automatically if possible. For mrc file format, only zyx and yx is + Save image at the same directory as the original image by default. + + For tif file format, if the image contains wrong axes for ImageJ (= except for tzcyx), + then it will converted automatically if possible. For mrc file format, only zyx and yx is allowed. zyx-scale is also saved. Parameters @@ -288,8 +288,8 @@ def imsave( In what data type img will be saved. overwrite : bool, default is True Whether to overwrite the file if it already exists. - - """ + + """ save_path = Path(save_path) if self.ndim < 2: raise ValueError("Cannot save <2D array as an image.") @@ -302,7 +302,7 @@ def imsave( else: ext = ".tif" save_path = save_path.parent / (save_path.name + ext) - + if not Path(save_path).is_absolute(): if self.source is None: raise ValueError( @@ -312,32 +312,32 @@ def imsave( " >>> img.imsave(\"/path/to/XXX.tif\")" ) save_path = self.source.parent / save_path - + if not overwrite and save_path.exists(): raise FileExistsError(f"File {save_path!r} already exists.") if self.metadata is None: self.metadata = {} if dtype is None: dtype = self.dtype - + # save image imsave(save_path, self) return None - + # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # Basic Functions # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # - + def __array_finalize__(self, obj): super().__array_finalize__(obj) self._inherit_covariates(obj) - + def _set_info(self, other: Self, new_axes: Any = MetaArray._INHERIT): super()._set_info(other, new_axes) self._inherit_covariates(other) return self - + def _inherit_covariates(self, other: Self): if isinstance(other, LabeledArray): if other is not self: @@ -346,27 +346,27 @@ def _inherit_covariates(self, other: Self): self._covariates = other._covariates else: self._covariates = ArrayCovariates({}, self) - + def _getitem_additional_set_info(self, other: Self, key, new_axes): self._covariates = getattr(self, "covariates", ArrayCovariates({}, self)) super()._set_info(other, new_axes) if isinstance(other, LabeledArray): self._covariates = other.covariates.construct_by_slicing(key, self) - + return None - + def _update(self, out: Self): self.value[:] = out.as_img_type(self.dtype).value[:] return None - + # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # Type Conversions # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # - + def as_uint8(self) -> Self: if self.dtype == np.uint8: return self - + if self.dtype == np.uint16: out = self.value / 256 elif self.dtype == bool: @@ -400,15 +400,15 @@ def as_uint16(self) -> Self: out = out.view(self.__class__) out._set_info(self) return out - + def as_float(self) -> Self: if self.dtype == np.float32: return self out = self.value.astype(np.float32).view(self.__class__) out._set_info(self) return out - - + + def as_img_type(self, dtype=np.uint16) -> Self: dtype = np.dtype(dtype) if self.dtype == dtype: @@ -435,7 +435,7 @@ def as_img_type(self, dtype=np.uint16) -> Self: return self.astype(dtype) else: raise ValueError(f"dtype: {dtype}") - + # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # Simple Visualizations # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # @@ -454,15 +454,15 @@ def imshow(self, label: bool = False, dims = 2, plugin="matplotlib", **kwargs): # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # Interpolation - # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # - + # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # + @_docs.write_docs @same_dtype(asfloat=True) @dims_to_spatial_axes def map_coordinates( self, coordinates: ArrayLike, - *, + *, mode: PaddingMode = "constant", cval: float = 0, order: int = 3, @@ -492,15 +492,15 @@ def map_coordinates( """ coords = xp.asarray(coordinates) c_axes = complement_axes(dims, self.axes) - + if coords.ndim != 2: drop_axis: list[int] = [] else: drop_axis = [self.axisof(a) for a in dims[:-1]] - + if prefilter is None: prefilter = order > 1 - + fn = _map_coordinates_with_nan if has_nan else _map_coordinates out = self._apply_dask( fn, @@ -510,7 +510,7 @@ def map_coordinates( args=(coords,), kwargs=dict(mode=mode, cval=cval, order=order, prefilter=prefilter), ) - + if coords.ndim == len(dims) + 1: if isinstance(coordinates, MetaArray): new_axes = c_axes + coordinates.axes[1:] @@ -521,7 +521,7 @@ def map_coordinates( new_axes = c_axes + coordinates.axes[1:2] else: new_axes = c_axes + ["#"] - + out = out.view(self.__class__) out._set_info(self, new_axes=new_axes) return out @@ -530,7 +530,7 @@ def map_coordinates( def pointprops(self, coords: Coords, *, order: int = 3, squeeze: bool = True) -> PropArray: """ Measure interpolated intensity at points with float coordinates. - + This method is essentially identical to :func:`map_coordinates` but is more straightforward for measuring intensities at points. @@ -546,7 +546,7 @@ def pointprops(self, coords: Coords, *, order: int = 3, squeeze: bool = True) -> ------- PropArray or float Intensities at points. - + Examples -------- Calculate centroids and measure intensities. @@ -558,7 +558,7 @@ def pointprops(self, coords: Coords, *, order: int = 3, squeeze: bool = True) -> npoints, ncol = coords.shape dims = self.axes[-ncol:] out = self.map_coordinates(coords.T, order=order, dims=dims) - + out = PropArray( out, name=out.name, axes=out.axes, source=out.source, metadata=out.metadata, propname="pointprops", @@ -567,7 +567,7 @@ def pointprops(self, coords: Coords, *, order: int = 3, squeeze: bool = True) -> if npoints == 1 and squeeze: out = out[0] return out - + @_docs.write_docs def reslice( self, @@ -578,7 +578,7 @@ def reslice( prefilter: bool | None = None, ) -> PropArray: """ - Measure line profile (kymograph) iteratively for every slice of image. This + Measure line profile (kymograph) iteratively for every slice of image. This function is almost same as `skimage.measure.profile_line`, but can reslice 3D-images. The argument `linewidth` is not implemented here because it is useless. @@ -596,20 +596,20 @@ def reslice( ------- PropArray Line scans. - + Examples -------- 1. Rescile along a line and fit to a model function for every time frame. - + >>> scan = img.reslice([18, 32], [53, 48]) >>> out = scan.curve_fit(func, init, return_fit=True) >>> plt.plot(scan[0]) >>> plt.plot(out.fit[0]) - + 2. Rescile along a path. - + >>> scan = img.reslice([[18, 32], [53,48], [22,45], [28, 32]]) - """ + """ # path = [[y1, x1], [y2, x2], ..., [yn, xn]] if b is not None: a = [list(a), list(b)] @@ -617,7 +617,7 @@ def reslice( _, ndim = a.shape seg = SegmentedLine(a) coords = seg.sample_points().T - + if ndim == self.ndim: dims = self.axes else: @@ -626,22 +626,22 @@ def reslice( result = self.map_coordinates( coords, order=order, mode="constant", prefilter=prefilter, dims=dims, ) - + new_axis = "s" out = PropArray(result, name=self.name, dtype=np.float32, axes=c_axes+[new_axis], propname="reslice") - + out.set_scale(self) out.set_scale({new_axis: self.scale[dims[-1]] * seg.interv}) return out - + @_docs.write_docs @check_input_and_output def pathprops( self, paths: PathFrame | ArrayLike | Sequence[ArrayLike], - properties: str | Callable | Iterable[str | Callable] = "mean", - *, + properties: str | Callable | Iterable[str | Callable] = "mean", + *, order: int = 1, ) -> DataDict[str, PropArray]: """ @@ -654,7 +654,7 @@ def pathprops( properties : str or callable, or their iterable Properties to be analyzed. {order} - + Returns ------- DataDict of PropArray @@ -664,9 +664,9 @@ def pathprops( -------- 1. Time-course measurement of intensities on a path. >>> img.pathprops([[2, 3], [102, 301], [200, 400]]) - """ + """ id_axis = "N" - + # normalize paths if type(paths).__name__ == "PathFrame": paths = [np.asarray(path) for path in paths.split(id_axis)] @@ -674,11 +674,11 @@ def pathprops( paths = [np.asarray(paths)] else: paths = [np.asarray(path) for path in paths] - + ndim = paths[0].shape[1] npaths = len(paths) dims = ["z", "y", "x"][-ndim:] - + # make a function dictionary funcdict = dict() if isinstance(properties, str) or callable(properties): @@ -690,32 +690,32 @@ def pathprops( funcdict[f.__name__] = f else: raise TypeError(f"Cannot interpret property {f}") - + c_axes = complement_axes(dims, self.axes) out_shape = tuple(self.sizeof(a) for a in c_axes) out = DataDict( {k: PropArray( - np.empty((npaths,) + out_shape, dtype=np.float32), - name=self.name, + np.empty((npaths,) + out_shape, dtype=np.float32), + name=self.name, axes=[id_axis] + c_axes, source=self.source, - propname = f"pathprops<{k}>", + propname = f"pathprops<{k}>", dtype=np.float32 ) for k in funcdict.keys() } ) - + if order > 1: self = self.spline_filter(order=order, mode="constant") - + for i, path in enumerate(paths): resliced = self.reslice(path, order=order, prefilter=False) for name, func in funcdict.items(): out[name][i] = np.apply_along_axis(func, axis=-1, arr=resliced.value) - + return out - + @_docs.write_docs @dims_to_spatial_axes @same_dtype(asfloat=True) @@ -723,18 +723,18 @@ def pathprops( def spline_filter( self, order: int = 3, - mode: PaddingMode = "mirror", + mode: PaddingMode = "mirror", *, dims: Dims = None, update: bool = False, ): """ Run spline filter. - + Parameters ---------- {order}{mode}{dims}{update} - + Returns ------- LabeledArray @@ -746,43 +746,43 @@ def spline_filter( from ._utils import _filters return self._apply_dask( _filters.spline_filter, - c_axes=complement_axes(dims, self.axes), + c_axes=complement_axes(dims, self.axes), args=(order, np.float32, mode), ) - + # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # Cropping # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # - + @_docs.write_docs @check_input_and_output @dims_to_spatial_axes def crop_center(self, scale: nDFloat = 0.5, *, dims=2) -> Self: r""" - Crop out the center of an image. - + Crop out the center of an image. + Parameters ---------- scale : float or array-like, default is 0.5 Scale of the cropped image. If an array is given, each axis will be cropped in different scales, using each value respectively. {dims} - + Returns ------- Self CroppedImage - + Examples -------- 1. Create a :math:`512\times512` image from a :math:`1024\times1024` image. - + >>> img_cropped = img.crop_center(scale=0.5) - + 2. Create a :math:`21\times256\times256` image from a :math:`63\times1024\times1024` image. - + >>> img_cropped = img.crop_center(scale=[1/3, 1/2, 1/2]) - + """ # check scale if hasattr(scale, "__iter__") and len(scale) == 3 and dims == "yx": @@ -790,7 +790,7 @@ def crop_center(self, scale: nDFloat = 0.5, *, dims=2) -> Self: scale = np.asarray(check_nd(scale, len(dims))) if np.any((scale <= 0) | (1 < scale)): raise ValueError(f"scale must be (0, 1], but got {scale}") - + # Make axis-targeted slicing string sizes = self.sizesof(dims) fmt = slicer.get_formatter(dims) @@ -805,13 +805,13 @@ def crop_center(self, scale: nDFloat = 0.5, *, dims=2) -> Self: slices.append(slice(x0, x1)) out = self[fmt[tuple(slices)]] - + return out - + @check_input_and_output def crop_kernel(self, radius: nDInt = 2) -> Self: r""" - Make a kernel from an image by cropping out the center region. + Make a kernel from an image by cropping out the center region. This function is useful especially in `ImgArray.defocus()`. Parameters @@ -823,21 +823,21 @@ def crop_kernel(self, radius: nDInt = 2) -> Self: ------- LabeledArray Kernel - + Examples -------- Make a :math:`4\times4\times4` kernel from a point spread function image (suppose the image shapes are all even numbers). - + >>> psf = ip.imread(r".../PSF.tif") >>> psfker = psf.crop_kernel() >>> psfer.shape (4, 4, 4) - """ + """ sizes = self.shape radii = check_nd(radius, len(sizes)) return self[tuple(slice(s//2-r, (s+1)//2+r) for s, r in zip(sizes, radii))] - + @_docs.write_docs @check_input_and_output @dims_to_spatial_axes @@ -856,37 +856,37 @@ def remove_edges(self, pixel: nDInt = 1, *, dims=2) -> Self: ------- LabeledArray Cropped image. - """ + """ if hasattr(pixel, "__iter__") and len(pixel) == 3 and len(dims) == 2: dims = "zyx" pixel = np.asarray(check_nd(pixel, len(dims)), dtype=np.int64) if np.any(pixel < 0): raise ValueError("`pixel` must be positive.") - + fmt = slicer.get_formatter(dims) sl = tuple(slice(px, (-px or None)) for px in pixel) - + out = self[fmt[sl]] return out - + @_docs.write_docs @check_input_and_output @dims_to_spatial_axes def rotated_crop(self, origin, dst1, dst2, dims=2) -> Self: """ - Crop the image at four courners of an rotated rectangle. Currently only supports rotation within + Crop the image at four courners of an rotated rectangle. Currently only supports rotation within yx-plane. An rotated rectangle is specified with positions of a origin and two destinations `dst1` - and `dst2`, i.e., vectors (dst1-origin) and (dst2-origin) represent a rotated rectangle. Let + and `dst2`, i.e., vectors (dst1-origin) and (dst2-origin) represent a rotated rectangle. Let origin be the origin of a xy-plane, the rotation direction from dst1 to dst2 must be counter- clockwise, or the cropped image will be reversed. - + Parameters - ---------- + ---------- origin : (float, float) dst1 : (float, float) dst2 :(float, float) {dims} - + Returns ------- LabeledArray @@ -900,7 +900,7 @@ def rotated_crop(self, origin, dst1, dst2, dims=2) -> Self: all_coords = ax0[:, np.newaxis] + ax1[np.newaxis] - origin all_coords = np.moveaxis(all_coords, -1, 0) cropped_img = self._apply_dask( - ndi.map_coordinates, complement_axes(dims, self.axes), + ndi.map_coordinates, complement_axes(dims, self.axes), dtype=self.dtype, args=(all_coords,), kwargs=dict(prefilter=False, order=1) @@ -917,20 +917,20 @@ def rotated_crop(self, origin, dst1, dst2, dims=2) -> Self: print("cropping labels failed") else: cropped_img.append_label(cropped_labels) - + return cropped_img - + # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # Label handling and others # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # - + @_docs.write_docs @dims_to_spatial_axes - def specify(self, center: Coords, radius: Coords, *, dims: Dims = None, + def specify(self, center: Coords, radius: Coords, *, dims: Dims = None, labeltype: str = "square") -> Label: """ Make rectangle or ellipse labels from points. - + Parameters ---------- center : array like or MarkerFrame @@ -945,7 +945,7 @@ def specify(self, center: Coords, radius: Coords, *, dims: Dims = None, ------- Label Labeled regions. - + Examples -------- Find single molecules, draw circular labels around them if mean values were greater than 100. @@ -959,45 +959,45 @@ def specify(self, center: Coords, radius: Coords, *, dims: Dims = None, from ._utils._process_numba import _specify_circ_2d, _specify_circ_3d, _specify_square_2d, _specify_square_3d ndim = len(dims) radius = np.asarray(check_nd(radius, ndim), dtype=np.float32) - + if labeltype in ("square", "s"): radius = radius.astype(np.uint8) _specify = {2: _specify_square_2d, 3: _specify_square_3d}[ndim] - + elif labeltype in ("circle", "c"): _specify = {2: _specify_circ_2d, 3: _specify_circ_3d}[ndim] - + else: raise ValueError("`labeltype` must be 'square' or 'circle'.") - + label_axes = str(center.col_axes) label_shape = self.sizesof(label_axes) labels = largest_zeros(label_shape) - + n_label = 1 for sl, crds in center.iter(complement_axes(dims, center.col_axes)): _specify(labels[sl], crds.values, radius, n_label) n_label += len(crds) - + if self.labels is not None: warn("Existing labels are updated.", UserWarning) self.labels = Label(labels, axes=label_axes).optimize() self.labels.set_scale(self) - + else: center = np.asarray(center) if center.ndim == 1: center = center.reshape(1, -1) - + cols = {2:"yx", 3:"zyx"}[center.shape[1]] center = MarkerFrame(center, columns=cols, dtype=np.uint16) - return self.specify(center, radius, dims=dims, labeltype=labeltype) - + return self.specify(center, radius, dims=dims, labeltype=labeltype) + return self.labels - + @_docs.write_docs @dims_to_spatial_axes def label( @@ -1010,14 +1010,14 @@ def label( ) -> Label: """ Label image using skimage's label(). - - Label image using `ref_image` as reference image, or image itself. If + + Label image using `ref_image` as reference image, or image itself. If ``filt`` is given, image will be labeled only if certain condition dictated in `filt` is satisfied. `regionprops_table` is called inside every time image is labeled. - + .. code-block:: python - + def filt(img, lbl, area, major_axis_length): return area>10 and major_axis_length>5 @@ -1026,11 +1026,11 @@ def filt(img, lbl, area, major_axis_length): ref_image : array, optional Image to make label, by default self is used. filt : callable, positional argument but not optional - Filter function. The first argument is intensity image sliced from + Filter function. The first argument is intensity image sliced from `self`, the second is label image sliced from labeled `ref_image`, and the rest arguments is properties that will be calculated using `regionprops` function. The property arguments **must be named - exactly same** as the properties in `regionprops`. Number of + exactly same** as the properties in `regionprops`. Number of arguments can be two. {dims} {connectivity} @@ -1039,33 +1039,33 @@ def filt(img, lbl, area, major_axis_length): ------- Label Newly created label. - + Examples -------- 1. Label the image with threshold and visualize with napari. >>> thr = img.threshold() >>> img.label(thr) >>> ip.gui.add(img) - + 2. Label regions if only intensity is high. >>> def high_intensity(img, lbl, slice): >>> return np.mean(img[slice]) > 10000 >>> img.label(lbl, filt) - + 3. Label regions if no hole exists. >>> def no_hole(img, lbl, euler_number): >>> return euler_number > 0 >>> img.label(lbl, filt) - + 4. Label regions if centroids are inside themselves. >>> def no_hole(img, lbl, centroid): >>> yc, xc = map(int, centroid) >>> return lbl[yc, xc] > 0 >>> img.label(lbl, filt) - + """ from skimage.segmentation import relabel_sequential - from skimage.measure import label as skmes_label + from skimage.measure import label as skmes_label from skimage.measure import regionprops # check the shape of label_image if ref_image is None: @@ -1082,20 +1082,20 @@ def filt(img, lbl, area, major_axis_length): f"Shape mismatch. Image is {self.shape_info} but reference is" f"{ref_image.shape_info}." ) - + c_axes = complement_axes(dims, self.axes) labels = largest_zeros(ref_image.shape) - + if filt is None: labels[:] = ref_image._apply_dask( - skmes_label, - c_axes=c_axes, + skmes_label, + c_axes=c_axes, kwargs=dict(background=0, connectivity=connectivity) ).view(np.ndarray) else: if not callable(filt): raise TypeError("`filt` must be callable.") - + import inspect import pandas as pd properties = tuple(inspect.signature(filt).parameters)[2:] @@ -1104,7 +1104,7 @@ def filt(img, lbl, area, major_axis_length): for sl, lbl in ref_image.iter(c_axes): lbl = skmes_label(lbl, background=0, connectivity=connectivity) img = self.value[sl] - # Following lines are essentially doing the same thing as + # Following lines are essentially doing the same thing as # `skmes.regionprops_table`. However, `skmes.regionprops_table` # returns tuples in the separated columns in DataFrame and rename # property names like "centroid-0" and "centroid-1". @@ -1118,7 +1118,7 @@ def filt(img, lbl, area, major_axis_length): offset=offset )[0] offset += labels.max() - + # correct the label numbers of `labels` labels = labels.view(Label) labels._set_info(ref_image) @@ -1138,23 +1138,23 @@ def append_label(self, label_image: np.ndarray, new: bool = False) -> Label: Labeled image. new : bool, default is False If True, existing labels will be removed anyway. - + Returns ------- Label New labels. - + Example ------- Make label from different channels. - + >>> thr0 = img["c=0"].threshold("90%") >>> thr0.label() # binary to label >>> thr1 = img["c=1"].threshold("90%") >>> thr1.label() # binary to label >>> img.append_label(thr0.labels) >>> img.append_label(thr1.labels) - + If `thr0` has 100 labels and `thr1` has 150 labels then `img` will have :math:`100+150=250` labels. """ # check and cast label dtype @@ -1175,7 +1175,7 @@ def append_label(self, label_image: np.ndarray, new: bool = False) -> Label: f"`label_image` has dtype {label_image.dtype}, which is unable " "to be interpreted as an label." ) - + if self.labels is not None and not new: if label_image.shape != self.labels.shape: raise ImageAxesError( @@ -1183,7 +1183,7 @@ def append_label(self, label_image: np.ndarray, new: bool = False) -> Label: f"{self.labels.shape} while labels with shape " f"{label_image.shape} is given." ) - + self.labels = self.labels.add_label(label_image) else: # when label_image is simple ndarray @@ -1201,14 +1201,14 @@ def append_label(self, label_image: np.ndarray, new: bool = False) -> Label: f"Axes mismatch. Image has {self.axes}-axes but " f"{axes} was given." ) - + self.labels = Label(label_image, axes=axes, source=self.source) return self.labels - + @check_input_and_output(need_labels=True) def proj_labels(self, axis=None, forbid_overlap=False) -> Label: """ - Label projection. This function is useful when zyx-labels are drawn but you want to reduce the + Label projection. This function is useful when zyx-labels are drawn but you want to reduce the dimension. Parameters @@ -1225,7 +1225,7 @@ def proj_labels(self, axis=None, forbid_overlap=False) -> Label: """ self.labels = self.labels.proj(axis=axis, forbid_overlap=forbid_overlap) return self.labels - + def split(self, axis=None) -> DataList[Self]: """ Split n-dimensional image into (n-1)-dimensional images. This function is different from @@ -1245,7 +1245,7 @@ def split(self, axis=None) -> DataList[Self]: if axis is None: axis = find_first_appeared(self.axes, include="cztpa") axisint = self.axisof(axis) - + imgs = super().split(axisint) if self.labels is not None: labels = self.labels.split(axisint) @@ -1253,9 +1253,9 @@ def split(self, axis=None) -> DataList[Self]: lbl.axes = self.labels.axes.drop(axisint) lbl.set_scale(self.labels) img.labels = lbl - + return imgs - + def tile( self, shape: tuple[int, int] | None = None, @@ -1273,12 +1273,12 @@ def tile( Axis (Axes) over which will be iterated. order : str, {"r", "c"}, optional Order of iteration. "r" means row-wise and "c" means column-wise. - + row-wise -----> -----> -----> - + column-wise | | | | | | @@ -1288,7 +1288,7 @@ def tile( ------- Labeled Tiled array - """ + """ if along is None: for a in self.axes: l = np.prod(shape) @@ -1310,10 +1310,10 @@ def tile( raise ValueError("`shape` must be specified unless the length of `along` is 2.") else: raise ValueError("`along` must be a string with length 1 or 2.") - + if order is None: order = "r" - + uy_max, ux_max = shape imgy, imgx = self.sizesof("yx") if len(shape) == 2: @@ -1322,27 +1322,27 @@ def tile( outshape = self.sizesof(c_axes) + (uy_max*imgy, ux_max*imgx) else: raise ValueError("Shape mismatch") - + out = np.zeros(outshape, dtype=self.dtype) - + if order == "r": iter_tile = _iter_tile_yx elif order == "c": iter_tile = _iter_tile_xy else: raise ValueError(f"Could not interpret order={repr(order)}.") - + for (_, img), sl in zip(self.iter(along), iter_tile(uy_max, ux_max, imgy, imgx)): out[sl] = img - + out = out.view(self.__class__) out._set_info(self, new_axes=new_axes) - + if self.labels is not None: tiled_label = self.labels.tile(shape, along, order) out.labels = tiled_label return out - + @check_input_and_output def for_each_channel(self, func: str, along: str = "c", **kwargs) -> Self: """ @@ -1360,7 +1360,7 @@ def for_each_channel(self, func: str, along: str = "c", **kwargs) -> Self: ------- LabeledArray output image stack - """ + """ if not hasattr(self, func): raise AttributeError(f"{self.__class__} does not have method {func}") imgs = self.split(along) @@ -1369,7 +1369,7 @@ def for_each_channel(self, func: str, along: str = "c", **kwargs) -> Self: outs.append(getattr(img, func)(**kw)) out = np.stack(outs, axis=along) return out - + @check_input_and_output def for_params(self, func: Callable|str, var: dict[str, Iterable] = None, **kwargs) -> DataList: """ @@ -1379,7 +1379,7 @@ def for_params(self, func: Callable|str, var: dict[str, Iterable] = None, **kwar Parameters ---------- func : callable or str - Function to apply repetitively. If str, then member method will be called. + Function to apply repetitively. If str, then member method will be called. var : dict[str, Iterable], optional Name of variable and the values to try. If you want to try sigma=1,2,3 then you should give `var={"sigma": [1, 2, 3]}`. @@ -1391,24 +1391,24 @@ def for_params(self, func: Callable|str, var: dict[str, Iterable] = None, **kwar ------- DataList List of outputs. - + Example ------- 1. Try LoG filter with different Gaussian kernel size and visualize all of them in napari. - + >>> out = img.for_params("log_filter", var={"sigma":[1, 2, 3, 4]}) # or >>> out = img.for_params("log_filter", sigma=[1, 2, 3, 4]) # then >>> ip.gui.add(out) - """ + """ if isinstance(func, str) and hasattr(self, func): f = getattr(self, func) elif callable(func): f = partial(func, self) elif not callable(func): raise AttributeError(f"{func} is neither {self.__class__}'s' method nor callable object.") - + if isinstance(var, dict): key, values = tuple(var.items())[0] elif var is None and len(kwargs) == 1: @@ -1416,18 +1416,18 @@ def for_params(self, func: Callable|str, var: dict[str, Iterable] = None, **kwar kwargs = dict() else: raise ValueError("Wrong inputs.") - + if key in kwargs.keys(): raise ValueError(f"Keyword {key} exists in `kwargs`.") outlist = DataList() - + for v in values: kwargs[key] = v out = f(**kwargs) outlist.append(out) - return outlist - - + return outlist + + @check_input_and_output(need_labels=True) def extract(self, label_ids=None, filt=None, cval:float=0) -> DataList[Self]: """ @@ -1441,23 +1441,23 @@ def extract(self, label_ids=None, filt=None, cval:float=0) -> DataList[Self]: If given, only regions `X` that satisfy filt(self, X) will extracted. cval : float, default is 0. Constant value to fill regions outside the extracted labeled regions. - + Returns ------- DataList of LabeledArray Extracted image(s) - """ + """ if filt is None: filt = lambda arr, lbl: True elif not callable(filt): raise TypeError("`filt` must be callable if given.") - + if np.isscalar(label_ids): label_ids = [label_ids] elif label_ids is None: # All the labels except for 0 (which means not labeled) label_ids = [i for i in np.unique(self.labels) if i != 0] - + slices = ndi.find_objects(self.labels) out = [] for i in label_ids: @@ -1468,9 +1468,9 @@ def extract(self, label_ids=None, filt=None, cval:float=0) -> DataList[Self]: obj.value[~subregion] = cval del obj.labels out.append(obj) - + out = DataList(out) - + return out def _iter_dict(d, nparam): @@ -1496,7 +1496,7 @@ def _iter_tile_yx(ymax, xmax, imgy, imgx): +--+--+--+ | 6| 7|..| +--+--+--+ - """ + """ for uy, ux in itertools.product(range(ymax), range(xmax)): sly = slice(uy*imgy, (uy+1)*imgy, None) slx = slice(ux*imgx, (ux+1)*imgx, None) @@ -1511,7 +1511,7 @@ def _iter_tile_xy(ymax, xmax, imgy, imgx): +--+--+--+ | 2| 5|..| +--+--+--+ - """ + """ for uy, ux in itertools.product(range(xmax), range(ymax)): sly = slice(uy*imgy, (uy+1)*imgy, None) slx = slice(ux*imgx, (ux+1)*imgx, None) @@ -1522,19 +1522,19 @@ class SegmentedLine: def __init__(self, nodes: np.ndarray): if nodes.shape[0] < 2: raise ValueError("More than one points must be given.") - + vec = np.diff(nodes, axis=0) dist = np.sqrt(np.sum(vec**2, axis=1)) dist_sum = np.sum(dist) npoints = int(dist_sum) interv = dist_sum / npoints - + self.length = dist_sum self.vec = vec self.dist = dist self.nodes = nodes self.interv = interv - + def sample_points(self) -> np.ndarray: res = 0 out = [self.nodes[0:1]] @@ -1546,10 +1546,10 @@ def sample_points(self) -> np.ndarray: xs = idx[:, np.newaxis] * v[np.newaxis]/d + p out.append(xs) npoints += xs.shape[0] - + if npoints <= int(self.length): out.append(self.nodes[-1:]) - + return np.concatenate(out, axis=0) def _count_list_depth(x) -> int: @@ -1587,7 +1587,7 @@ def _map_coordinates_with_nan(input, coordinates, order, mode, cval, prefilter): prefilter=prefilter, ) mapped_nans = xp.ndi.map_coordinates( - nans.astype(xp.float32), + nans.astype(np.float32), xp.asarray(coordinates), order=order, mode=mode,