-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
fb422e8
commit e360936
Showing
3 changed files
with
278 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
"""Grids for 2D/3D image manipulations in PyTorch""" | ||
|
||
from importlib.metadata import PackageNotFoundError, version | ||
|
||
try: | ||
__version__ = version("torch-grid-utils") | ||
except PackageNotFoundError: | ||
__version__ = "uninstalled" | ||
__author__ = "Alister Burt" | ||
__email__ = "[email protected]" | ||
|
||
from .fftfreq_grid import fftfreq_grid | ||
from .coordinate_grid import coordinate_grid |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
from typing import Sequence | ||
|
||
import einops | ||
import numpy as np | ||
import torch | ||
|
||
|
||
def coordinate_grid( | ||
image_shape: Sequence[int], | ||
center: torch.Tensor | tuple[float, ...] | None = False, | ||
norm: bool = False, | ||
device: torch.device | None = None, | ||
) -> torch.FloatTensor: | ||
"""Get a dense grid of array coordinates from grid dimensions. | ||
For input `image_shape` of `(d, h, w)`, this function produces a | ||
`(d, h, w, 3)` grid of coordinates. Coordinate order matches the order of | ||
dimensions in `image_shape`. | ||
Parameters | ||
---------- | ||
image_shape: Sequence[int] | ||
Shape of the image for which coordinates should be returned. | ||
center: torch.Tensor | tuple[float, ...] | None | ||
Array of center points relative to which coordinates will be calculated. | ||
If `None`, default to the array origin `[0, ...]` of zero in all dimensions. | ||
norm: bool | ||
Whether to compute the Euclidean norm of the coordinate grid. | ||
device: torch.device | ||
PyTorch device on which to put the coordinate grid. | ||
Returns | ||
------- | ||
grid: torch.LongTensor | ||
`(*image_shape, image_ndim)` array of coordinates if `norm` is `False` | ||
else `(*image_shape, )`. | ||
""" | ||
grid = torch.tensor( | ||
np.indices(image_shape), | ||
device=device, | ||
dtype=torch.float32 | ||
) # (coordinates, *image_shape) | ||
grid = einops.rearrange(grid, 'coords ... -> ... coords') | ||
ndim = len(image_shape) | ||
if center is not None: | ||
center = torch.as_tensor(center, dtype=grid.dtype, device=grid.device) | ||
center = torch.atleast_1d(center) | ||
center, ps = einops.pack([center], pattern='* coords') | ||
ones = ' '.join('1' * ndim) | ||
axis_ids = ' '.join(_unique_characters(ndim)) | ||
center = einops.rearrange(center, f"b coords -> b {ones} coords") | ||
grid = grid - center | ||
[grid] = einops.unpack(grid, packed_shapes=ps, pattern=f'* {axis_ids} coords') | ||
if norm is True: | ||
grid = einops.reduce(grid ** 2, '... coords -> ...', reduction='sum') ** 0.5 | ||
return grid | ||
|
||
|
||
def _unique_characters(n: int) -> str: | ||
chars = "abcdefghijklmnopqrstuvwxyz" | ||
return chars[:n] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,204 @@ | ||
import functools | ||
from typing import Sequence | ||
|
||
import einops | ||
import torch | ||
|
||
|
||
@functools.lru_cache(maxsize=1) | ||
def fftfreq_grid( | ||
image_shape: tuple[int, int] | tuple[int, int, int], | ||
rfft: bool, | ||
fftshift: bool = False, | ||
spacing: float | tuple[float, float] | tuple[float, float, float] = 1, | ||
norm: bool = False, | ||
device: torch.device | None = None, | ||
) -> torch.Tensor: | ||
"""Construct a 2D or 3D grid of DFT sample frequencies. | ||
For a 2D image with shape `(h, w)` and `rfft=False` this function will produce | ||
a `(h, w, 2)` array of DFT sample frequencies in the `h` and `w` dimensions. | ||
If `norm` is True the Euclidean norm will be calculated over the last dimension | ||
leaving a `(h, w)` grid. | ||
Parameters | ||
---------- | ||
image_shape: tuple[int, int] | tuple[int, int, int] | ||
Shape of the 2D or 3D image before computing the DFT. | ||
rfft: bool | ||
Whether the output should contain frequencies for a real-valued DFT. | ||
fftshift: bool | ||
Whether to fftshift the output grid. | ||
spacing: float | tuple[float, float] | tuple[float, float, float] | ||
Spacing between samples in each dimension. Sampling is considered to be | ||
isotropic if a single value is passed. | ||
norm: bool | ||
Whether to compute the Euclidean norm over the last dimension. | ||
device: torch.device | None | ||
PyTorch device on which the returned grid will be stored. | ||
Returns | ||
------- | ||
frequency_grid: torch.Tensor | ||
`(*image_shape, ndim)` array of DFT sample frequencies in each | ||
image dimension if `norm` is `False` else `(*image_shape, )`. | ||
""" | ||
if len(image_shape) == 2: | ||
frequency_grid = _construct_fftfreq_grid_2d( | ||
image_shape=image_shape, | ||
rfft=rfft, | ||
spacing=spacing, | ||
device=device, | ||
) | ||
if fftshift is True: | ||
frequency_grid = einops.rearrange(frequency_grid, '... freq -> freq ...') | ||
frequency_grid = fftshift_2d(frequency_grid, rfft=rfft) | ||
frequency_grid = einops.rearrange(frequency_grid, 'freq ... -> ... freq') | ||
elif len(image_shape) == 3: | ||
frequency_grid = _construct_fftfreq_grid_3d( | ||
image_shape=image_shape, | ||
rfft=rfft, | ||
spacing=spacing, | ||
device=device, | ||
) | ||
if fftshift is True: | ||
frequency_grid = einops.rearrange(frequency_grid, '... freq -> freq ...') | ||
frequency_grid = fftshift_3d(frequency_grid, rfft=rfft) | ||
frequency_grid = einops.rearrange(frequency_grid, 'freq ... -> ... freq') | ||
else: | ||
raise NotImplementedError( | ||
"Construction of fftfreq grids is currently only supported for " | ||
"2D and 3D images." | ||
) | ||
if norm is True: | ||
frequency_grid = einops.reduce( | ||
frequency_grid ** 2, '... squared_freqs -> ...', reduction='sum' | ||
) ** 0.5 | ||
return frequency_grid | ||
|
||
|
||
def _construct_fftfreq_grid_2d( | ||
image_shape: tuple[int, int], | ||
rfft: bool, | ||
spacing: float | tuple[float, float] = 1, | ||
device: torch.device = None | ||
) -> torch.Tensor: | ||
"""Construct a grid of DFT sample freqs for a 2D image. | ||
Parameters | ||
---------- | ||
image_shape: Sequence[int] | ||
A 2D shape `(h, w)` of the input image for which a grid of DFT sample freqs | ||
should be calculated. | ||
rfft: bool | ||
Whether the frequency grid is for a real fft (rfft). | ||
spacing: float | tuple[float, float] | ||
Sample spacing in `h` and `w` dimensions of the grid. | ||
device: torch.device | ||
Torch device for the resulting grid. | ||
Returns | ||
------- | ||
frequency_grid: torch.Tensor | ||
`(h, w, 2)` array of DFT sample freqs. | ||
Order of freqs in the last dimension corresponds to the order of | ||
the two dimensions of the grid. | ||
""" | ||
dh, dw = spacing if isinstance(spacing, Sequence) else [spacing] * 2 | ||
last_axis_frequency_func = torch.fft.rfftfreq if rfft is True else torch.fft.fftfreq | ||
h, w = image_shape | ||
freq_y = torch.fft.fftfreq(h, d=dh, device=device) | ||
freq_x = last_axis_frequency_func(w, d=dw, device=device) | ||
h, w = rfft_shape(image_shape) if rfft is True else image_shape | ||
freq_yy = einops.repeat(freq_y, 'h -> h w', w=w) | ||
freq_xx = einops.repeat(freq_x, 'w -> h w', h=h) | ||
return einops.rearrange([freq_yy, freq_xx], 'freq h w -> h w freq') | ||
|
||
|
||
def _construct_fftfreq_grid_3d( | ||
image_shape: Sequence[int], | ||
rfft: bool, | ||
spacing: float | tuple[float, float, float] = 1, | ||
device: torch.device = None | ||
) -> torch.Tensor: | ||
"""Construct a grid of DFT sample freqs for a 3D image. | ||
Parameters | ||
---------- | ||
image_shape: Sequence[int] | ||
A 3D shape `(d, h, w)` of the input image for which a grid of DFT sample freqs | ||
should be calculated. | ||
rfft: bool | ||
Controls Whether the frequency grid is for a real fft (rfft). | ||
spacing: float | tuple[float, float, float] | ||
Sample spacing in `d`, `h` and `w` dimensions of the grid. | ||
device: torch.device | ||
Torch device for the resulting grid. | ||
Returns | ||
------- | ||
frequency_grid: torch.Tensor | ||
`(h, w, 3)` array of DFT sample freqs. | ||
Order of freqs in the last dimension corresponds to the order of dimensions | ||
of the grid. | ||
""" | ||
dd, dh, dw = spacing if isinstance(spacing, Sequence) else [spacing] * 3 | ||
last_axis_frequency_func = torch.fft.rfftfreq if rfft is True else torch.fft.fftfreq | ||
d, h, w = image_shape | ||
freq_z = torch.fft.fftfreq(d, d=dd, device=device) | ||
freq_y = torch.fft.fftfreq(h, d=dh, device=device) | ||
freq_x = last_axis_frequency_func(w, d=dw, device=device) | ||
d, h, w = rfft_shape(image_shape) if rfft is True else image_shape | ||
freq_zz = einops.repeat(freq_z, 'd -> d h w', h=h, w=w) | ||
freq_yy = einops.repeat(freq_y, 'h -> d h w', d=d, w=w) | ||
freq_xx = einops.repeat(freq_x, 'w -> d h w', d=d, h=h) | ||
return einops.rearrange([freq_zz, freq_yy, freq_xx], 'freq ... -> ... freq') | ||
|
||
|
||
def rfft_shape(input_shape: Sequence[int]) -> tuple[int]: | ||
"""Get the output shape of an rfft on an input with input_shape.""" | ||
rfft_shape = list(input_shape) | ||
rfft_shape[-1] = int((rfft_shape[-1] / 2) + 1) | ||
return tuple(rfft_shape) | ||
|
||
|
||
def dft_center( | ||
image_shape: tuple[int, ...], | ||
rfft: bool, | ||
fftshifted: bool, | ||
device: torch.device | None = None, | ||
) -> torch.LongTensor: | ||
"""Return the position of the DFT center for a given input shape.""" | ||
fft_center = torch.zeros(size=(len(image_shape),), device=device) | ||
image_shape = torch.as_tensor(image_shape).float() | ||
if rfft is True: | ||
image_shape = torch.tensor(rfft_shape(image_shape)) | ||
if fftshifted is True: | ||
fft_center = torch.divide(image_shape, 2, rounding_mode='floor') | ||
if rfft is True: | ||
fft_center[-1] = 0 | ||
return fft_center.long() | ||
|
||
|
||
def fftshift_2d(input: torch.Tensor, rfft: bool): | ||
if rfft is False: | ||
output = torch.fft.fftshift(input, dim=(-2, -1)) | ||
else: | ||
output = torch.fft.fftshift(input, dim=(-2,)) | ||
return output | ||
|
||
|
||
def ifftshift_2d(input: torch.Tensor, rfft: bool): | ||
if rfft is False: | ||
output = torch.fft.ifftshift(input, dim=(-2, -1)) | ||
else: | ||
output = torch.fft.ifftshift(input, dim=(-2,)) | ||
return output | ||
|
||
|
||
def fftshift_3d(input: torch.Tensor, rfft: bool): | ||
if rfft is False: | ||
output = torch.fft.fftshift(input, dim=(-3, -2, -1)) | ||
else: | ||
output = torch.fft.fftshift(input, dim=(-3, -2,)) | ||
return output |