-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ENH: Add Scipy and Cupy as fft interfaces #8
base: main
Are you sure you want to change the base?
Changes from 4 commits
f09a64f
41b77e4
97e7c5c
93618fb
b19a817
8525abd
839214e
7d1ca55
4f01683
63b0c4d
5054745
5319043
35975f3
23d367f
b812c07
2361769
32cf2ce
59f67ec
2468569
7d9f1ae
d2e2f61
e427cc5
d7d84f5
2da1582
57947e6
1ac98ff
fe33bf0
568d511
7d88bb9
d7457e1
528a038
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,3 @@ | ||
sphinx==4.3.0 | ||
sphinxcontrib.bibtex>=2.0 | ||
sphinx_rtd_theme==1.0 | ||
|
||
sphinx | ||
sphinxcontrib.bibtex | ||
sphinx_rtd_theme |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
"""Fourier Transform interfaces available | ||
|
||
This example visualizes the different backends and packages available to the | ||
user for performing Fourier transforms. | ||
|
||
- PyFFTW is initially slow, but over many FFTs is very quick. | ||
- CuPy using CUDA can be very fast, but is currently limited because we are | ||
transferring one image at a time to the GPU. | ||
|
||
""" | ||
import time | ||
import matplotlib.pylab as plt | ||
import numpy as np | ||
import qpretrieve | ||
|
||
# load the experimental data | ||
edata = np.load("./data/hologram_cell.npz") | ||
|
||
# get the available fft interfaces | ||
interfaces_available = qpretrieve.fourier.get_available_interfaces() | ||
|
||
n_transforms = 100 | ||
|
||
# one transform | ||
results_1 = {} | ||
for fft_interface in interfaces_available: | ||
t0 = time.time() | ||
holo = qpretrieve.OffAxisHologram(data=edata["data"], | ||
fft_interface=fft_interface) | ||
holo.run_pipeline(filter_name="disk", filter_size=1 / 2) | ||
bg = qpretrieve.OffAxisHologram(data=edata["bg_data"]) | ||
bg.process_like(holo) | ||
t1 = time.time() | ||
results_1[fft_interface.__name__] = t1 - t0 | ||
num_interfaces = len(results_1) | ||
|
||
# multiple transforms (should see speed increase for PyFFTW) | ||
results = {} | ||
for fft_interface in interfaces_available: | ||
t0 = time.time() | ||
for _ in range(n_transforms): | ||
holo = qpretrieve.OffAxisHologram(data=edata["data"], | ||
fft_interface=fft_interface) | ||
holo.run_pipeline(filter_name="disk", filter_size=1 / 2) | ||
bg = qpretrieve.OffAxisHologram(data=edata["bg_data"]) | ||
bg.process_like(holo) | ||
t1 = time.time() | ||
results[fft_interface.__name__] = t1 - t0 | ||
num_interfaces = len(results) | ||
|
||
fft_interfaces = list(results.keys()) | ||
speed_1 = list(results_1.values()) | ||
speed = list(results.values()) | ||
|
||
fig, axes = plt.subplots(1, 2, figsize=(8, 5)) | ||
ax1, ax2 = axes | ||
labels = [fftstr[9:] for fftstr in fft_interfaces] | ||
|
||
ax1.bar(range(num_interfaces), height=speed_1, color='lightseagreen') | ||
ax1.set_xticks(range(num_interfaces), labels=labels, | ||
rotation=45) | ||
ax1.set_ylabel("Speed (s)") | ||
ax1.set_title("1 Transform") | ||
|
||
ax2.bar(range(num_interfaces), height=speed, color='lightseagreen') | ||
ax2.set_xticks(range(num_interfaces), labels=labels, | ||
rotation=45) | ||
ax2.set_ylabel("Speed (s)") | ||
ax2.set_title(f"{n_transforms} Transforms") | ||
|
||
plt.suptitle("Speed of FFT Interfaces") | ||
plt.tight_layout() | ||
plt.show() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
matplotlib |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -38,7 +38,7 @@ def get_filter_array(filter_name, filter_size, freq_pos, fft_shape): | |
and must be between 0 and `max(fft_shape)/2` | ||
freq_pos: tuple of floats | ||
The position of the filter in frequency coordinates as | ||
returned by :func:`nunpy.fft.fftfreq`. | ||
returned by :func:`numpy.fft.fftfreq`. | ||
fft_shape: tuple of int | ||
The shape of the Fourier transformed image for which the | ||
filter will be applied. The shape must be squared (two | ||
|
@@ -104,8 +104,10 @@ def get_filter_array(filter_name, filter_size, freq_pos, fft_shape): | |
# TODO: avoid the np.roll, instead use the indices directly | ||
alpha = 0.1 | ||
rsize = int(min(fx.size, fy.size) * filter_size) * 2 | ||
tukey_window_x = signal.tukey(rsize, alpha=alpha).reshape(-1, 1) | ||
tukey_window_y = signal.tukey(rsize, alpha=alpha).reshape(1, -1) | ||
tukey_window_x = signal.windows.tukey( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. scipy's new versions imports tukey from |
||
rsize, alpha=alpha).reshape(-1, 1) | ||
tukey_window_y = signal.windows.tukey( | ||
rsize, alpha=alpha).reshape(1, -1) | ||
tukey = tukey_window_x * tukey_window_y | ||
base = np.zeros(fft_shape) | ||
s1 = (np.array(fft_shape) - rsize) // 2 | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,15 +2,36 @@ | |
import warnings | ||
|
||
from .ff_numpy import FFTFilterNumpy | ||
from .ff_scipy import FFTFilterScipy | ||
|
||
try: | ||
from .ff_pyfftw import FFTFilterPyFFTW | ||
except ImportError: | ||
FFTFilterPyFFTW = None | ||
|
||
try: | ||
from .ff_cupy import FFTFilterCupy | ||
except ImportError: | ||
FFTFilterCupy = None | ||
|
||
PREFERRED_INTERFACE = None | ||
|
||
|
||
def get_available_interfaces(): | ||
"""Return a list of available FFT algorithms""" | ||
interfaces = [ | ||
FFTFilterPyFFTW, | ||
FFTFilterNumpy, | ||
FFTFilterScipy, | ||
FFTFilterCupy, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This isn't necessarily the "perferred order" we want, but I'd like to keep it as is due to old default pipelines. |
||
] | ||
interfaces_available = [] | ||
for interface in interfaces: | ||
if interface is not None and interface.is_available: | ||
interfaces_available.append(interface) | ||
return interfaces_available | ||
|
||
|
||
def get_best_interface(): | ||
"""Return the fastest refocusing interface available | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -70,8 +70,11 @@ def __init__(self, data, subtract_mean=True, padding=2, copy=True): | |
else: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here we check and allow complex data. But in the docstring above (originally), we specify real-valued inputs. In what situation would we take complex data? |
||
# convert integer-arrays to floating point arrays | ||
dtype = float | ||
if not copy: | ||
# numpy v2.x behaviour requires asarray with copy=False | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. numpy has made copying a bit awkward. When an array can't be copied (with either |
||
copy = None | ||
data_ed = np.array(data, dtype=dtype, copy=copy) | ||
#: original data (with subtracted mean) | ||
#: original data (with subtracted mean) | ||
self.origin = data_ed | ||
#: whether padding is enabled | ||
self.padding = padding | ||
|
@@ -175,7 +178,7 @@ def filter(self, filter_name: str, filter_size: float, | |
and must be between 0 and `max(fft_shape)/2` | ||
freq_pos: tuple of floats | ||
The position of the filter in frequency coordinates as | ||
returned by :func:`nunpy.fft.fftfreq`. | ||
returned by :func:`numpy.fft.fftfreq`. | ||
scale_to_filter: bool or float | ||
Crop the image in Fourier space after applying the filter, | ||
effectively removing surplus (zero-padding) data and | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import scipy as sp | ||
import cupy as cp | ||
import cupyx.scipy.fft as cufft | ||
|
||
from .base import FFTFilter | ||
|
||
|
||
class FFTFilterCupy(FFTFilter): | ||
"""Wraps the cupy Fourier transform and uses it via the scipy backend | ||
""" | ||
is_available = True | ||
# sp.fft.set_backend(cufft) | ||
|
||
def _init_fft(self, data): | ||
"""Perform initial Fourier transform of the input data | ||
|
||
Parameters | ||
---------- | ||
data: 2d real-valued np.ndarray | ||
Input field to be refocused | ||
|
||
Returns | ||
------- | ||
fft_fdata: 2d complex-valued ndarray | ||
Fourier transform `data` | ||
""" | ||
data_gpu = cp.asarray(data) | ||
# likely an inefficiency here, could use `set_global_backend` | ||
with sp.fft.set_backend(cufft): | ||
fft_gpu = sp.fft.fft2(data_gpu) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not ideal, this is something to work on. |
||
fft_cpu = fft_gpu.get() | ||
return fft_cpu | ||
|
||
def _ifft(self, data): | ||
"""Perform inverse Fourier transform""" | ||
data_gpu = cp.asarray(data) | ||
with sp.fft.set_backend(cufft): | ||
ifft_gpu = sp.fft.ifft2(data_gpu) | ||
ifft_cpu = ifft_gpu.get() | ||
return ifft_cpu |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import scipy as sp | ||
|
||
|
||
from .base import FFTFilter | ||
|
||
|
||
class FFTFilterScipy(FFTFilter): | ||
"""Wraps the scipy Fourier transform | ||
""" | ||
# always available, because scipy is a dependency | ||
is_available = True | ||
|
||
def _init_fft(self, data): | ||
"""Perform initial Fourier transform of the input data | ||
|
||
Parameters | ||
---------- | ||
data: 2d real-valued np.ndarray | ||
Input field to be refocused | ||
|
||
Returns | ||
------- | ||
fft_fdata: 2d complex-valued ndarray | ||
Fourier transform `data` | ||
""" | ||
return sp.fft.fft2(data) | ||
|
||
def _ifft(self, data): | ||
"""Perform inverse Fourier transform""" | ||
return sp.fft.ifft2(data) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,6 @@ | |
from setuptools import setup, find_packages | ||
import sys | ||
|
||
|
||
author = u"Paul Müller" | ||
authors = [author] | ||
description = 'library for phase retrieval from holograms' | ||
|
@@ -27,8 +26,13 @@ | |
"numpy>=1.9.0", | ||
"scikit-image>=0.11.0", | ||
"scipy>=0.18.0", | ||
], | ||
extras_require={"FFTW": "pyfftw>=0.12.0"}, | ||
], | ||
extras_require={ | ||
"FFTW": "pyfftw>=0.12.0", | ||
# manually install 'cupy-cuda11x' if you have older CUDA. | ||
# See https://cupy.dev/ | ||
"CUPY": "cupy-cuda12x", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will have to check this version. |
||
}, | ||
python_requires='>=3.10, <4', | ||
keywords=["digital holographic microscopy", | ||
"optics", | ||
|
@@ -41,6 +45,6 @@ | |
'Operating System :: OS Independent', | ||
'Programming Language :: Python :: 3', | ||
'Intended Audience :: Science/Research' | ||
], | ||
], | ||
platforms=['ALL'], | ||
) | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I need to clean this script up a bit to make it simpler