add code
alisterburt committed May 26, 2024
1 parent fb422e8 commit e360936
Showing 3 changed files with 278 additions and 0 deletions.
13 changes: 13 additions & 0 deletions src/torch_grid_utils/__init__.py
@@ -0,0 +1,13 @@
"""Grids for 2D/3D image manipulations in PyTorch"""

from importlib.metadata import PackageNotFoundError, version

try:
    __version__ = version("torch-grid-utils")
except PackageNotFoundError:
    __version__ = "uninstalled"
__author__ = "Alister Burt"
__email__ = "[email protected]"

from .fftfreq_grid import fftfreq_grid
from .coordinate_grid import coordinate_grid
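For orientation, a minimal usage sketch of the two functions exported here. This snippet is not part of the committed files; the shapes noted in the comments follow the docstrings in the modules below.

from torch_grid_utils import coordinate_grid, fftfreq_grid

# (28, 28, 2) grid of (y, x) array coordinates for a 28x28 image
coords = coordinate_grid(image_shape=(28, 28))

# (28, 15) grid of DFT frequency magnitudes for the rfft of the same image
freqs = fftfreq_grid(image_shape=(28, 28), rfft=True, norm=True)
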
61 changes: 61 additions & 0 deletions src/torch_grid_utils/coordinate_grid.py
@@ -0,0 +1,61 @@
from typing import Sequence

import einops
import numpy as np
import torch


def coordinate_grid(
image_shape: Sequence[int],
center: torch.Tensor | tuple[float, ...] | None = None,
norm: bool = False,
device: torch.device | None = None,
) -> torch.FloatTensor:
"""Get a dense grid of array coordinates from grid dimensions.
For input `image_shape` of `(d, h, w)`, this function produces a
`(d, h, w, 3)` grid of coordinates. Coordinate order matches the order of
dimensions in `image_shape`.
Parameters
----------
image_shape: Sequence[int]
Shape of the image for which coordinates should be returned.
center: torch.Tensor | tuple[float, ...] | None
Array of center points relative to which coordinates will be calculated.
If `None`, coordinates are calculated relative to the array origin `[0, ...]`,
i.e. zero in all dimensions.
norm: bool
If `True`, return the Euclidean norm of the coordinates (the distance of each
position from `center`) rather than the coordinates themselves.
device: torch.device
PyTorch device on which to put the coordinate grid.
Returns
-------
grid: torch.FloatTensor
`(*image_shape, image_ndim)` array of coordinates if `norm` is `False`,
otherwise a `(*image_shape)` array of Euclidean distances from `center`.
"""
grid = torch.tensor(
np.indices(image_shape),
device=device,
dtype=torch.float32
) # (coordinates, *image_shape)
grid = einops.rearrange(grid, 'coords ... -> ... coords')
ndim = len(image_shape)
if center is not None:
    center = torch.as_tensor(center, dtype=grid.dtype, device=grid.device)
    center = torch.atleast_1d(center)
    center, ps = einops.pack([center], pattern='* coords')
    ones = ' '.join('1' * ndim)
    axis_ids = ' '.join(_unique_characters(ndim))
    center = einops.rearrange(center, f"b coords -> b {ones} coords")
    grid = grid - center
    [grid] = einops.unpack(grid, packed_shapes=ps, pattern=f'* {axis_ids} coords')
if norm is True:
    grid = einops.reduce(grid ** 2, '... coords -> ...', reduction='sum') ** 0.5
return grid


def _unique_characters(n: int) -> str:
    """Return the first `n` lowercase letters, used as einops axis identifiers."""
    chars = "abcdefghijklmnopqrstuvwxyz"
    return chars[:n]
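A short, hedged sketch of how `coordinate_grid` might be used with `center` and `norm`; the 64x64 image size, the centre point and the circular-mask step are illustrative choices, not part of the commit.

from torch_grid_utils import coordinate_grid

# Euclidean distance of every pixel from (32, 32) in a 64x64 image
distances = coordinate_grid(
    image_shape=(64, 64),
    center=(32, 32),
    norm=True,
)  # shape (64, 64)

# illustrative: a boolean circular mask of radius 20 pixels
circle_mask = distances <= 20
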
204 changes: 204 additions & 0 deletions src/torch_grid_utils/fftfreq_grid.py
@@ -0,0 +1,204 @@
import functools
from typing import Sequence

import einops
import torch


@functools.lru_cache(maxsize=1)
def fftfreq_grid(
image_shape: tuple[int, int] | tuple[int, int, int],
rfft: bool,
fftshift: bool = False,
spacing: float | tuple[float, float] | tuple[float, float, float] = 1,
norm: bool = False,
device: torch.device | None = None,
) -> torch.Tensor:
"""Construct a 2D or 3D grid of DFT sample frequencies.
For a 2D image with shape `(h, w)` and `rfft=False` this function will produce
a `(h, w, 2)` array of DFT sample frequencies in the `h` and `w` dimensions.
If `norm` is True the Euclidean norm will be calculated over the last dimension
leaving a `(h, w)` grid.
Parameters
----------
image_shape: tuple[int, int] | tuple[int, int, int]
Shape of the 2D or 3D image before computing the DFT.
rfft: bool
Whether the output should contain frequencies for a real-valued DFT.
fftshift: bool
Whether to fftshift the output grid.
spacing: float | tuple[float, float] | tuple[float, float, float]
Spacing between samples in each dimension. Sampling is considered to be
isotropic if a single value is passed.
norm: bool
Whether to compute the Euclidean norm over the last dimension.
device: torch.device | None
PyTorch device on which the returned grid will be stored.
Returns
-------
frequency_grid: torch.Tensor
`(*image_shape, ndim)` array of DFT sample frequencies in each
image dimension if `norm` is `False` else `(*image_shape, )`.
"""
if len(image_shape) == 2:
    frequency_grid = _construct_fftfreq_grid_2d(
        image_shape=image_shape,
        rfft=rfft,
        spacing=spacing,
        device=device,
    )
    if fftshift is True:
        frequency_grid = einops.rearrange(frequency_grid, '... freq -> freq ...')
        frequency_grid = fftshift_2d(frequency_grid, rfft=rfft)
        frequency_grid = einops.rearrange(frequency_grid, 'freq ... -> ... freq')
elif len(image_shape) == 3:
    frequency_grid = _construct_fftfreq_grid_3d(
        image_shape=image_shape,
        rfft=rfft,
        spacing=spacing,
        device=device,
    )
    if fftshift is True:
        frequency_grid = einops.rearrange(frequency_grid, '... freq -> freq ...')
        frequency_grid = fftshift_3d(frequency_grid, rfft=rfft)
        frequency_grid = einops.rearrange(frequency_grid, 'freq ... -> ... freq')
else:
    raise NotImplementedError(
        "Construction of fftfreq grids is currently only supported for "
        "2D and 3D images."
    )
if norm is True:
    frequency_grid = einops.reduce(
        frequency_grid ** 2, '... squared_freqs -> ...', reduction='sum'
    ) ** 0.5
return frequency_grid


def _construct_fftfreq_grid_2d(
image_shape: tuple[int, int],
rfft: bool,
spacing: float | tuple[float, float] = 1,
device: torch.device | None = None
) -> torch.Tensor:
"""Construct a grid of DFT sample freqs for a 2D image.
Parameters
----------
image_shape: Sequence[int]
A 2D shape `(h, w)` of the input image for which a grid of DFT sample freqs
should be calculated.
rfft: bool
Whether the frequency grid is for a real fft (rfft).
spacing: float | tuple[float, float]
Sample spacing in `h` and `w` dimensions of the grid.
device: torch.device
Torch device for the resulting grid.
Returns
-------
frequency_grid: torch.Tensor
`(h, w, 2)` array of DFT sample freqs.
Order of freqs in the last dimension corresponds to the order of
the two dimensions of the grid.
"""
dh, dw = spacing if isinstance(spacing, Sequence) else [spacing] * 2
last_axis_frequency_func = torch.fft.rfftfreq if rfft is True else torch.fft.fftfreq
h, w = image_shape
freq_y = torch.fft.fftfreq(h, d=dh, device=device)
freq_x = last_axis_frequency_func(w, d=dw, device=device)
h, w = rfft_shape(image_shape) if rfft is True else image_shape
freq_yy = einops.repeat(freq_y, 'h -> h w', w=w)
freq_xx = einops.repeat(freq_x, 'w -> h w', h=h)
return einops.rearrange([freq_yy, freq_xx], 'freq h w -> h w freq')


def _construct_fftfreq_grid_3d(
image_shape: Sequence[int],
rfft: bool,
spacing: float | tuple[float, float, float] = 1,
device: torch.device | None = None
) -> torch.Tensor:
"""Construct a grid of DFT sample freqs for a 3D image.
Parameters
----------
image_shape: Sequence[int]
A 3D shape `(d, h, w)` of the input image for which a grid of DFT sample freqs
should be calculated.
rfft: bool
Whether the frequency grid is for a real FFT (rfft).
spacing: float | tuple[float, float, float]
Sample spacing in `d`, `h` and `w` dimensions of the grid.
device: torch.device
Torch device for the resulting grid.
Returns
-------
frequency_grid: torch.Tensor
`(d, h, w, 3)` array of DFT sample freqs.
Order of freqs in the last dimension corresponds to the order of dimensions
of the grid.
"""
dd, dh, dw = spacing if isinstance(spacing, Sequence) else [spacing] * 3
last_axis_frequency_func = torch.fft.rfftfreq if rfft is True else torch.fft.fftfreq
d, h, w = image_shape
freq_z = torch.fft.fftfreq(d, d=dd, device=device)
freq_y = torch.fft.fftfreq(h, d=dh, device=device)
freq_x = last_axis_frequency_func(w, d=dw, device=device)
d, h, w = rfft_shape(image_shape) if rfft is True else image_shape
freq_zz = einops.repeat(freq_z, 'd -> d h w', h=h, w=w)
freq_yy = einops.repeat(freq_y, 'h -> d h w', d=d, w=w)
freq_xx = einops.repeat(freq_x, 'w -> d h w', d=d, h=h)
return einops.rearrange([freq_zz, freq_yy, freq_xx], 'freq ... -> ... freq')


def rfft_shape(input_shape: Sequence[int]) -> tuple[int, ...]:
    """Get the output shape of an rfft on an input with input_shape."""
    rfft_shape = list(input_shape)
    rfft_shape[-1] = rfft_shape[-1] // 2 + 1
    return tuple(rfft_shape)


def dft_center(
image_shape: tuple[int, ...],
rfft: bool,
fftshifted: bool,
device: torch.device | None = None,
) -> torch.LongTensor:
"""Return the position of the DFT center for a given input shape."""
fft_center = torch.zeros(size=(len(image_shape),), device=device)
if rfft is True:
    image_shape = rfft_shape(image_shape)
image_shape = torch.as_tensor(image_shape, device=device).float()
if fftshifted is True:
    fft_center = torch.divide(image_shape, 2, rounding_mode='floor')
if rfft is True:
    fft_center[-1] = 0
return fft_center.long()


def fftshift_2d(input: torch.Tensor, rfft: bool) -> torch.Tensor:
    """fftshift a 2D frequency grid; the last dim is left unshifted when `rfft=True`."""
    if rfft is False:
        output = torch.fft.fftshift(input, dim=(-2, -1))
    else:
        output = torch.fft.fftshift(input, dim=(-2,))
    return output


def ifftshift_2d(input: torch.Tensor, rfft: bool) -> torch.Tensor:
    """ifftshift a 2D frequency grid; the last dim is left unshifted when `rfft=True`."""
    if rfft is False:
        output = torch.fft.ifftshift(input, dim=(-2, -1))
    else:
        output = torch.fft.ifftshift(input, dim=(-2,))
    return output


def fftshift_3d(input: torch.Tensor, rfft: bool) -> torch.Tensor:
    """fftshift a 3D frequency grid; the last dim is left unshifted when `rfft=True`."""
    if rfft is False:
        output = torch.fft.fftshift(input, dim=(-3, -2, -1))
    else:
        output = torch.fft.fftshift(input, dim=(-3, -2))
    return output
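
A hedged usage sketch for `fftfreq_grid`: build frequency magnitudes matching the layout of `torch.fft.rfftn` output and use them as a low-pass mask. The 128x128 image size and the 0.25 cycles/sample cutoff are arbitrary example values, not part of the commit.

import torch

from torch_grid_utils import fftfreq_grid

image = torch.rand(128, 128)

# (128, 65) grid of frequency magnitudes matching torch.fft.rfftn(image)
freq_magnitudes = fftfreq_grid(
    image_shape=(128, 128),
    rfft=True,
    fftshift=False,
    norm=True,
)

# illustrative low-pass filter: keep frequencies below 0.25 cycles/sample
dft = torch.fft.rfftn(image, dim=(-2, -1))
low_passed = torch.fft.irfftn(dft * (freq_magnitudes <= 0.25), dim=(-2, -1))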
