Skip to content

Commit

Permalink
Add combined affine and elastic deformation augmentation
Browse files Browse the repository at this point in the history
SimpleITK's ResampleImageFilter is an expensive operation,
especially when its called sequentially for both affine and
elastic transformations. By combining the SimpleITK transforms
for both augmentations, the processing time for a sample can be
reduced.

Currently the augmentation uses a combined probability for applying the
transform, rather than independently applying them.

Resolves: #1052
  • Loading branch information
mrdkucher committed Apr 20, 2023
1 parent 640d6e1 commit b3559b7
Show file tree
Hide file tree
Showing 8 changed files with 778 additions and 21 deletions.
7 changes: 7 additions & 0 deletions docs/source/transforms/augmentation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,13 @@ Spatial
:show-inheritance:


:class:`RandomCombinedAffineElasticDeformation`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: RandomCombinedAffineElasticDeformation
:show-inheritance:


:class:`RandomAnisotropy`
^^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down
6 changes: 6 additions & 0 deletions src/torchio/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
from .augmentation.spatial import RandomAffine, Affine
from .augmentation.spatial import RandomAnisotropy
from .augmentation.spatial import RandomElasticDeformation, ElasticDeformation
from .augmentation.spatial import (
RandomCombinedAffineElasticDeformation,
CombinedAffineElasticDeformation,
)

from .augmentation.intensity import RandomSwap, Swap
from .augmentation.intensity import RandomBlur, Blur
Expand Down Expand Up @@ -67,6 +71,8 @@
'RandomAnisotropy',
'RandomElasticDeformation',
'ElasticDeformation',
'RandomCombinedAffineElasticDeformation',
'CombinedAffineElasticDeformation',
'RandomSwap',
'Swap',
'RandomBlur',
Expand Down
5 changes: 3 additions & 2 deletions src/torchio/transforms/augmentation/random_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ def _get_random_seed() -> int:
"""
return int(torch.randint(0, 2**31, (1,)).item())

def sample_uniform_sextet(self, params):
@staticmethod
def sample_uniform_sextet(params):
results = []
for a, b in zip(params[::2], params[1::2]):
results.append(self.sample_uniform(a, b))
results.append(RandomTransform.sample_uniform(a, b))
return torch.Tensor(results)
4 changes: 4 additions & 0 deletions src/torchio/transforms/augmentation/spatial/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from .random_elastic_deformation import RandomElasticDeformation
from .random_flip import Flip
from .random_flip import RandomFlip
from .random_affine_elastic_deformation import CombinedAffineElasticDeformation
from .random_affine_elastic_deformation import RandomCombinedAffineElasticDeformation


__all__ = [
Expand All @@ -15,4 +17,6 @@
'RandomAnisotropy',
'RandomElasticDeformation',
'ElasticDeformation',
'RandomCombinedAffineElasticDeformation',
'CombinedAffineElasticDeformation',
]
45 changes: 26 additions & 19 deletions src/torchio/transforms/augmentation/spatial/random_affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,18 +141,18 @@ def __init__(
)
self.check_shape = check_shape

@staticmethod
def get_params(
self,
scales: TypeSextetFloat,
degrees: TypeSextetFloat,
translation: TypeSextetFloat,
isotropic: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
scaling_params = self.sample_uniform_sextet(scales)
scaling_params = RandomTransform.sample_uniform_sextet(scales)
if isotropic:
scaling_params.fill_(scaling_params[0])
rotation_params = self.sample_uniform_sextet(degrees)
translation_params = self.sample_uniform_sextet(translation)
rotation_params = RandomTransform.sample_uniform_sextet(degrees)
translation_params = RandomTransform.sample_uniform_sextet(translation)
return scaling_params, rotation_params, translation_params

def apply_transform(self, subject: Subject) -> Subject:
Expand Down Expand Up @@ -345,6 +345,27 @@ def get_affine_transform(self, image):

return transform

def get_default_pad_value(
self, tensor: torch.Tensor, sitk_image: sitk.Image
) -> float:
default_value: float
if self.default_pad_value == 'minimum':
default_value = tensor.min().item()
elif self.default_pad_value == 'mean':
default_value = get_borders_mean(
sitk_image,
filter_otsu=False,
)
elif self.default_pad_value == 'otsu':
default_value = get_borders_mean(
sitk_image,
filter_otsu=True,
)
else:
assert isinstance(self.default_pad_value, Number)
default_value = float(self.default_pad_value)
return default_value

def apply_transform(self, subject: Subject) -> Subject:
if self.check_shape:
subject.check_consistent_spatial_shape()
Expand All @@ -363,21 +384,7 @@ def apply_transform(self, subject: Subject) -> Subject:
default_value = 0
else:
interpolation = self.image_interpolation
if self.default_pad_value == 'minimum':
default_value = tensor.min().item()
elif self.default_pad_value == 'mean':
default_value = get_borders_mean(
sitk_image,
filter_otsu=False,
)
elif self.default_pad_value == 'otsu':
default_value = get_borders_mean(
sitk_image,
filter_otsu=True,
)
else:
assert isinstance(self.default_pad_value, Number)
default_value = float(self.default_pad_value)
default_value = self.get_default_pad_value(tensor, sitk_image)
transformed_tensor = self.apply_affine_transform(
sitk_image,
transform,
Expand Down
Loading

0 comments on commit b3559b7

Please sign in to comment.