Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/hanjinliu/impy
Browse files Browse the repository at this point in the history
  • Loading branch information
Hanjin Liu committed Apr 3, 2024
2 parents 5bda6e5 + d4874db commit 6e82f2a
Show file tree
Hide file tree
Showing 5 changed files with 179 additions and 63 deletions.
22 changes: 11 additions & 11 deletions impy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
__version__ = "2.3.3"
__version__ = "2.4.0"
__author__ = "Hanjin Liu"
__email__ = "[email protected]"

import logging

from ._const import Const, SetConst, use
from ._const import Const, SetConst, use # noqa

from .collections import DataList, DataDict
from .core import *
from .binder import bind
from .viewer import gui
from .correlation import *
from .arrays import ImgArray, LazyImgArray, BigImgArray, Label # for typing
from . import random, io, lazy
from .axes import slicer
from .collections import DataList, DataDict # noqa
from .core import * # noqa
from .binder import bind # noqa
from .viewer import gui # noqa
from .correlation import * # noqa
from .arrays import ImgArray, LazyImgArray, BigImgArray, Label # noqa
from . import random, io, lazy # noqa
from .axes import slicer # noqa

# Inheritance
# -----------
Expand All @@ -34,7 +34,7 @@
del logging

# dtypes
from numpy import (
from numpy import ( # noqa
uint8, uint16, uint32, uint64,
int8, int16, int32, int64,
float16, float32, float64,
Expand Down
120 changes: 106 additions & 14 deletions impy/io.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, NamedTuple, Callable, TypeVar, Union, Protocol

from typing import TYPE_CHECKING, Any, NamedTuple, Callable, Sequence, TypeVar, Union, Protocol
from pathlib import Path
import json
import re
import warnings
Expand Down Expand Up @@ -50,6 +52,7 @@ class ImageData(NamedTuple):
from .arrays.bases import MetaArray
from .arrays import LazyImgArray
from .axes import Axes
from dask.array.core import Array

ImpyArray = Union[MetaArray, LazyImgArray]
Reader = Callable[[str, bool], ImageData]
Expand Down Expand Up @@ -123,7 +126,7 @@ def _register(f: _R):

return _register

def imread(self, path: str, memmap: bool = False) -> ImageData:
def imread(self, path: str | Path, memmap: bool = False) -> ImageData:
"""
Read an image file.
Expand All @@ -141,11 +144,15 @@ def imread(self, path: str, memmap: bool = False) -> ImageData:
ImageData
Image data tuple.
"""
_, ext = os.path.splitext(path)
ext = Path(path).suffix
reader = self._reader.get(ext, self._default_reader)
return reader(path, memmap)
return reader(str(path), memmap)

def imread_dask(self, path: str, chunks: Any) -> ImageData:
def _imread_slice(self, path, sl: tuple[slice, ...]) -> np.memmap:
mmap = self.imread(path, memmap=True).image
return np.asarray(mmap[sl], dtype=mmap.dtype)

def imread_dask(self, path: str | Path, chunks: Any) -> ImageData:
"""
Read an image file as a dask array.
Expand All @@ -165,21 +172,39 @@ def imread_dask(self, path: str, chunks: Any) -> ImageData:
Image data tuple.
"""
from .array_api import xp

path = Path(path)
image_data = self.imread(path, memmap=True)
img = image_data.image

if img.dtype == ">u2":
img = img.astype(np.uint16)
from dask import array as da, delayed
from dask.array.core import normalize_chunks

from dask import array as da
if str(type(img)) == "<class 'zarr.core.Array'>":
if path.suffix == ".zarr":
if img.dtype == ">u2":
img = img.astype(np.uint16)
dask = da.from_zarr(img, chunks=chunks).map_blocks(
xp.asarray, dtype=img.dtype
)
else:
dask = da.from_array(img, chunks=chunks, meta=xp.array([])).map_blocks(
xp.asarray, dtype=img.dtype
chunks_: tuple[tuple[int, ...]] = normalize_chunks(
chunks,
shape=img.shape,
dtype=img.dtype,
)
chunk_slices = [_chunk_to_slice(c) for c in chunks_]
block_shape = tuple(len(c) for c in chunks_)
delayed_imread = delayed(self._imread_slice)
arr_blocks = np.empty(block_shape, dtype=object)
for ind, _ in np.ndenumerate(arr_blocks):
sl = tuple(sls[i] for i, sls in zip(ind, chunk_slices))
cur_shape = tuple(_sl.stop - _sl.start for _sl in sl)
arr_blocks[ind] = da.from_delayed(
delayed_imread(path, sl), shape=cur_shape, dtype=img.dtype,
meta=xp.array([]),
)
dask = da.block(arr_blocks.tolist())

return ImageData(
image=dask,
axes=image_data.axes,
Expand All @@ -199,6 +224,15 @@ def imsave(
writer = self._writer.get(ext, self._default_writer)
return writer(path, img, lazy)

def _chunk_to_slice(chunk: Sequence[int]) -> list[slice]:
# _chunk_to_slice([5, 15, 30]) --> [0:5, 5:20, 20:50]
start = 0
out: list[slice] = []
for c in chunk:
_next = start + c
out.append(slice(start, _next))
start = _next
return out

IO = ImageIO()

Expand Down Expand Up @@ -336,7 +370,9 @@ def _(path: str, img: ImpyArray, lazy: bool = False):
from dask import array as da
kwargs = _get_ijmeta_from_img(img, update_lut=False)
mmap = memmap(str(path), shape=img.shape, dtype=img.dtype, **kwargs)
da.store(img.value, mmap, compute=True)
img_dask = _rechunk_to_ones(img.value)
writer = _MemmapArrayWriter(path, mmap.offset, img.shape, img_dask.chunksize)
da.store(img_dask, writer)
mmap.flush()
return

Expand All @@ -359,7 +395,7 @@ def _(path: str, img: ImpyArray, lazy: bool = False):
img_new.set_scale(img)
img = img_new

warnings.warn("Image axes changed", UserWarning)
warnings.warn("Image axes changed", UserWarning, stacklevel=2)

img = img.sort_axes()
if img.dtype == "bool":
Expand Down Expand Up @@ -431,7 +467,10 @@ def _(path: str, img: ImpyArray, lazy: bool = False):
mode = _MRC_MODE[img.dtype]
mrc_mmap = mrcfile.new_mmap(path, img.shape, mrc_mode=mode, overwrite=True)
mrc_mmap.voxel_size = voxel_size
da.store(img.value, mrc_mmap.data)

img_dask = _rechunk_to_ones(img.value)
writer = _MemmapArrayWriter(path, mrc_mmap.data.offset, img.shape, img_dask.chunksize)
da.store(img_dask, writer)
mrc_mmap.flush()
return None

Expand Down Expand Up @@ -656,3 +695,56 @@ def _scalar(x: Any) -> np.ndarray:
ar = np.array(None, dtype=object)
ar[()] = x
return ar

def _rechunk_to_ones(arr: Array):
"""Rechunk the array to (1, 1, ..., 1, n, Ny, Nx)"""
size = np.prod(arr.chunksize)
shape = arr.shape
cur_prod = 1
max_i = arr.ndim
for i in reversed(range(arr.ndim)):
cur_prod *= shape[i]
if cur_prod > size:
break
max_i = i
nslices = max(int(size / np.prod(shape[max_i:])), 1)
if max_i == 0:
return arr
else:
return arr.rechunk((1,) * (max_i - 1) + (nslices,) + shape[max_i:])

class _MemmapArrayWriter:
def __init__(
self,
path: str,
offset: int,
shape: tuple[int, ...],
chunksize: tuple[int, ...],
):
self._path = path
self._offset = offset
self._shape = shape # original shape
self._chunksize = chunksize # chunk size
# shape = (33, 160, 1000, 1000)
# chunksize = (1, 16, 1000, 1000)
border = 0
for i, c in enumerate(chunksize):
if c != 1:
border = i
break
self._border = border

def __setitem__(self, sl: tuple[slice, ...], arr: np.ndarray):
# efficient: shape = (10, 100, 150) and sl = (3:5, 0:100, 0:150)
# sl = (0:1, 16:32, 0:1000, 0:1000)

offset = np.sum([sl[i].start * arr.strides[i] for i in range(self._border + 1)])
mmap = np.memmap(
self._path,
mode="r+",
offset=self._offset + offset,
dtype=arr.dtype,
)
arr_ravel = arr.ravel()
mmap[:arr_ravel.size] = arr_ravel
mmap.flush()
4 changes: 2 additions & 2 deletions impy/lazy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .core import *
from . import random
from .core import * # noqa
from . import random # noqa
Loading

0 comments on commit 6e82f2a

Please sign in to comment.