Skip to content

Commit

Permalink
don't use __getattr__
Browse files Browse the repository at this point in the history
  • Loading branch information
hanjinliu committed Aug 31, 2024
1 parent f52dca6 commit e7a7efd
Show file tree
Hide file tree
Showing 5 changed files with 234 additions and 241 deletions.
5 changes: 2 additions & 3 deletions impy/array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@ def __init__(self):
self._reset_namespace()
self.setNumpy()

def __getattr__(self, key: str):
return getattr(self._module, key)

def _reset_namespace(self):
self._signal = None
self._fft = None
Expand Down Expand Up @@ -111,6 +108,7 @@ def setNumpy(self) -> None:
self.argmin = np.argmin
self.pad = np.pad
self.isnan = np.isnan
self.eye = np.eye

self.state = "numpy"
from ._const import Const
Expand Down Expand Up @@ -181,6 +179,7 @@ def cp_asnumpy(arr, dtype=None):
self.argmin = cp.argmin
self.pad = cp.pad
self.isnan = cp.isnan
self.eye = cp.eye
self.state = "cupy"

from ._const import Const
Expand Down
17 changes: 4 additions & 13 deletions impy/arrays/_utils/_deconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,6 @@
import numpy as np
from impy.array_api import xp

try:
gradient = xp.gradient
except AttributeError:
# CUDA <= ver.8 does not have gradient
import numpy
def gradient(a, axis=None):
out = numpy.gradient(a.get(), axis=axis)
return xp.asarray(out)

__all__ = ["wiener", "richardson_lucy", "richardson_lucy_tv", "check_psf"]

def wiener(obs, psf_ft, psf_ft_conj, lmd):
Expand All @@ -30,7 +21,7 @@ def richardson_lucy(obs, psf_ft, psf_ft_conj, niter, eps):
obs = xp.asarray(obs)
fft = xp.fft.rfftn
ifft = partial(xp.fft.irfftn, s=obs.shape)
conv = factor = xp.empty(obs.shape, dtype=xp.float32) # placeholder
conv = factor = xp.empty(obs.shape, dtype=np.float32) # placeholder
estimated = xp.real(ifft(fft(obs) * psf_ft)) # initialization

for _ in range(niter):
Expand All @@ -45,16 +36,16 @@ def richardson_lucy_tv(obs, psf_ft, psf_ft_conj, max_iter, lmd, tol, eps):
fft = xp.fft.rfftn
ifft = partial(xp.fft.irfftn, s=obs.shape)
est_old = ifft(fft(obs) * psf_ft).real
est_new = xp.empty(obs.shape, dtype=xp.float32)
est_new = xp.empty(obs.shape, dtype=np.float32)
conv = factor = norm = gg = xp.empty(obs.shape, dtype=np.float32) # placeholder

for _ in range(max_iter):
conv[:] = ifft(fft(est_old) * psf_ft).real
factor[:] = ifft(fft(_safe_div(obs, conv, eps=eps)) * psf_ft_conj).real
est_new[:] = est_old * factor
grad = gradient(est_old)
grad = xp.gradient(est_old)
norm[:] = xp.sqrt(sum(g**2 for g in grad))
gg[:] = sum(gradient(_safe_div(grad[i], norm, eps=1e-8), axis=i)
gg[:] = sum(xp.gradient(_safe_div(grad[i], norm, eps=1e-8), axis=i)
for i in range(obs.ndim))
est_new /= (1 - lmd * gg)
gain = xp.sum(xp.abs(est_new - est_old))/xp.sum(xp.abs(est_old))
Expand Down
4 changes: 2 additions & 2 deletions impy/arrays/_utils/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,9 +343,9 @@ def get_fourier_filter(size: int, filter_name: str):
cosine_filter = xp.fft.fftshift(xp.sin(freq))
fourier_filter *= cosine_filter
elif filter_name == "hamming":
fourier_filter *= xp.fft.fftshift(xp.hamming(size))
fourier_filter *= xp.fft.fftshift(xp._module.hamming(size))
elif filter_name == "hann":
fourier_filter *= xp.fft.fftshift(xp.hanning(size))
fourier_filter *= xp.fft.fftshift(xp._module.hanning(size))
elif filter_name is None:
fourier_filter[:] = 1
else:
Expand Down
25 changes: 14 additions & 11 deletions impy/arrays/bases/metaarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,8 +329,8 @@ def _apply_dask(
new_axis: Iterable[int] = None,
dtype = np.float32,
out_chunks: tuple[int, ...] = None,
args: tuple[Any] = None,
kwargs: dict[str, Any] = None
args: tuple[Any] | None = None,
kwargs: dict[str, Any] | None = None
) -> Self:
"""
Convert array into dask array and run a batch process in parallel. In many cases batch process
Expand Down Expand Up @@ -403,15 +403,7 @@ def _apply_dask(
else:
_args.append(arg)

def _func(*args, **kwargs):
args = list(args)
for i in img_idx:
if args[i].ndim < len(slice_in):
continue
args[i] = args[i][slice_in]
out = func(*args, **kwargs)
return xp.asnumpy(out[slice_out])

_func = _make_func(func, img_idx, slice_in, slice_out, xp.asnumpy)
out = da.map_blocks(
_func,
*_args,
Expand Down Expand Up @@ -935,3 +927,14 @@ def __eq__(self, other):
return False

_NOTME = NotMe()

def _make_func(func, img_idx, slice_in, slice_out, as_numpy):
def _func(*args, **kwargs):
args = list(args)
for i in img_idx:
if args[i].ndim < len(slice_in):
continue
args[i] = args[i][slice_in]
out = func(*args, **kwargs)
return as_numpy(out[slice_out])
return _func
Loading

0 comments on commit e7a7efd

Please sign in to comment.