diff --git a/docs/source/transforms/augmentation.rst b/docs/source/transforms/augmentation.rst index fcf238fc..206f9046 100644 --- a/docs/source/transforms/augmentation.rst +++ b/docs/source/transforms/augmentation.rst @@ -71,6 +71,13 @@ Spatial :show-inheritance: +:class:`RandomCombinedAffineElasticDeformation` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: RandomCombinedAffineElasticDeformation + :show-inheritance: + + :class:`RandomAnisotropy` ^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/src/torchio/transforms/__init__.py b/src/torchio/transforms/__init__.py index 087dbe5b..67c5cf00 100644 --- a/src/torchio/transforms/__init__.py +++ b/src/torchio/transforms/__init__.py @@ -17,6 +17,10 @@ from .augmentation.spatial import RandomAffine, Affine from .augmentation.spatial import RandomAnisotropy from .augmentation.spatial import RandomElasticDeformation, ElasticDeformation +from .augmentation.spatial import ( + RandomCombinedAffineElasticDeformation, + CombinedAffineElasticDeformation, +) from .augmentation.intensity import RandomSwap, Swap from .augmentation.intensity import RandomBlur, Blur @@ -67,6 +71,8 @@ 'RandomAnisotropy', 'RandomElasticDeformation', 'ElasticDeformation', + 'RandomCombinedAffineElasticDeformation', + 'CombinedAffineElasticDeformation', 'RandomSwap', 'Swap', 'RandomBlur', diff --git a/src/torchio/transforms/augmentation/random_transform.py b/src/torchio/transforms/augmentation/random_transform.py index 86a5df79..353577b3 100644 --- a/src/torchio/transforms/augmentation/random_transform.py +++ b/src/torchio/transforms/augmentation/random_transform.py @@ -49,8 +49,9 @@ def _get_random_seed() -> int: """ return int(torch.randint(0, 2**31, (1,)).item()) - def sample_uniform_sextet(self, params): + @staticmethod + def sample_uniform_sextet(params): results = [] for a, b in zip(params[::2], params[1::2]): - results.append(self.sample_uniform(a, b)) + results.append(RandomTransform.sample_uniform(a, b)) return torch.Tensor(results) diff --git a/src/torchio/transforms/augmentation/spatial/__init__.py b/src/torchio/transforms/augmentation/spatial/__init__.py index 119769cb..edfb1496 100644 --- a/src/torchio/transforms/augmentation/spatial/__init__.py +++ b/src/torchio/transforms/augmentation/spatial/__init__.py @@ -5,6 +5,8 @@ from .random_elastic_deformation import RandomElasticDeformation from .random_flip import Flip from .random_flip import RandomFlip +from .random_affine_elastic_deformation import CombinedAffineElasticDeformation +from .random_affine_elastic_deformation import RandomCombinedAffineElasticDeformation __all__ = [ @@ -15,4 +17,6 @@ 'RandomAnisotropy', 'RandomElasticDeformation', 'ElasticDeformation', + 'RandomCombinedAffineElasticDeformation', + 'CombinedAffineElasticDeformation', ] diff --git a/src/torchio/transforms/augmentation/spatial/random_affine.py b/src/torchio/transforms/augmentation/spatial/random_affine.py index cd2955e4..272b0998 100644 --- a/src/torchio/transforms/augmentation/spatial/random_affine.py +++ b/src/torchio/transforms/augmentation/spatial/random_affine.py @@ -141,18 +141,18 @@ def __init__( ) self.check_shape = check_shape + @staticmethod def get_params( - self, scales: TypeSextetFloat, degrees: TypeSextetFloat, translation: TypeSextetFloat, isotropic: bool, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - scaling_params = self.sample_uniform_sextet(scales) + scaling_params = RandomTransform.sample_uniform_sextet(scales) if isotropic: scaling_params.fill_(scaling_params[0]) - rotation_params = self.sample_uniform_sextet(degrees) - translation_params = self.sample_uniform_sextet(translation) + rotation_params = RandomTransform.sample_uniform_sextet(degrees) + translation_params = RandomTransform.sample_uniform_sextet(translation) return scaling_params, rotation_params, translation_params def apply_transform(self, subject: Subject) -> Subject: @@ -345,6 +345,27 @@ def get_affine_transform(self, image): return transform + def get_default_pad_value( + self, tensor: torch.Tensor, sitk_image: sitk.Image + ) -> float: + default_value: float + if self.default_pad_value == 'minimum': + default_value = tensor.min().item() + elif self.default_pad_value == 'mean': + default_value = get_borders_mean( + sitk_image, + filter_otsu=False, + ) + elif self.default_pad_value == 'otsu': + default_value = get_borders_mean( + sitk_image, + filter_otsu=True, + ) + else: + assert isinstance(self.default_pad_value, Number) + default_value = float(self.default_pad_value) + return default_value + def apply_transform(self, subject: Subject) -> Subject: if self.check_shape: subject.check_consistent_spatial_shape() @@ -363,21 +384,7 @@ def apply_transform(self, subject: Subject) -> Subject: default_value = 0 else: interpolation = self.image_interpolation - if self.default_pad_value == 'minimum': - default_value = tensor.min().item() - elif self.default_pad_value == 'mean': - default_value = get_borders_mean( - sitk_image, - filter_otsu=False, - ) - elif self.default_pad_value == 'otsu': - default_value = get_borders_mean( - sitk_image, - filter_otsu=True, - ) - else: - assert isinstance(self.default_pad_value, Number) - default_value = float(self.default_pad_value) + default_value = self.get_default_pad_value(tensor, sitk_image) transformed_tensor = self.apply_affine_transform( sitk_image, transform, diff --git a/src/torchio/transforms/augmentation/spatial/random_affine_elastic_deformation.py b/src/torchio/transforms/augmentation/spatial/random_affine_elastic_deformation.py new file mode 100644 index 00000000..545ae8f4 --- /dev/null +++ b/src/torchio/transforms/augmentation/spatial/random_affine_elastic_deformation.py @@ -0,0 +1,439 @@ +from typing import Union +from typing import Tuple + +import numpy as np +import SimpleITK as sitk +import torch + +from .random_affine import ( + Affine, + RandomAffine, + _parse_default_value, + _parse_scales_isotropic, +) +from .random_elastic_deformation import ( + ElasticDeformation, + RandomElasticDeformation, + _parse_max_displacement, + _parse_num_control_points, +) +from .. import RandomTransform +from ... import SpatialTransform +from ....constants import INTENSITY +from ....constants import TYPE +from ....data.io import nib_to_sitk +from ....data.subject import Subject +from ....typing import TypeRangeFloat +from ....typing import TypeSextetFloat +from ....typing import TypeTripletFloat +from ....utils import to_tuple + +TypeOneToSixFloat = Union[TypeRangeFloat, TypeTripletFloat, TypeSextetFloat] + + +class RandomCombinedAffineElasticDeformation(RandomTransform, SpatialTransform): + r"""Apply a RandomAffine and RandomElasticDeformation simultaneously. + + Optimization to use only a single SimpleITK resampling. For additional details on + the transformations, see :class:`~torchio.transforms.augmentation.RandomAffine` + and :class:`~torchio.transforms.augmentation.RandomElasticDeformation` + + Args: + affine_first: Apply affine before elastic deformation. + + scales: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the + scaling ranges. + The scaling values along each dimension are :math:`(s_1, s_2, s_3)`, + where :math:`s_i \sim \mathcal{U}(a_i, b_i)`. + If two values :math:`(a, b)` are provided, + then :math:`s_i \sim \mathcal{U}(a, b)`. + If only one value :math:`x` is provided, + then :math:`s_i \sim \mathcal{U}(1 - x, 1 + x)`. + If three values :math:`(x_1, x_2, x_3)` are provided, + then :math:`s_i \sim \mathcal{U}(1 - x_i, 1 + x_i)`. + For example, using ``scales=(0.5, 0.5)`` will zoom out the image, + making the objects inside look twice as small while preserving + the physical size and position of the image bounds. + degrees: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the + rotation ranges in degrees. + Rotation angles around each axis are + :math:`(\theta_1, \theta_2, \theta_3)`, + where :math:`\theta_i \sim \mathcal{U}(a_i, b_i)`. + If two values :math:`(a, b)` are provided, + then :math:`\theta_i \sim \mathcal{U}(a, b)`. + If only one value :math:`x` is provided, + then :math:`\theta_i \sim \mathcal{U}(-x, x)`. + If three values :math:`(x_1, x_2, x_3)` are provided, + then :math:`\theta_i \sim \mathcal{U}(-x_i, x_i)`. + translation: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the + translation ranges in mm. + Translation along each axis is :math:`(t_1, t_2, t_3)`, + where :math:`t_i \sim \mathcal{U}(a_i, b_i)`. + If two values :math:`(a, b)` are provided, + then :math:`t_i \sim \mathcal{U}(a, b)`. + If only one value :math:`x` is provided, + then :math:`t_i \sim \mathcal{U}(-x, x)`. + If three values :math:`(x_1, x_2, x_3)` are provided, + then :math:`t_i \sim \mathcal{U}(-x_i, x_i)`. + For example, if the image is in RAS+ orientation (e.g., after + applying :class:`~torchio.transforms.preprocessing.ToCanonical`) + and the translation is :math:`(10, 20, 30)`, the sample will move + 10 mm to the right, 20 mm to the front, and 30 mm upwards. + If the image was in, e.g., PIR+ orientation, the sample will move + 10 mm to the back, 20 mm downwards, and 30 mm to the right. + isotropic: If ``True``, the scaling factor along all dimensions is the + same, i.e. :math:`s_1 = s_2 = s_3`. + center: If ``'image'``, rotations and scaling will be performed around + the image center. If ``'origin'``, rotations and scaling will be + performed around the origin in world coordinates. + num_control_points: Number of control points along each dimension of + the coarse grid :math:`(n_x, n_y, n_z)`. + If a single value :math:`n` is passed, + then :math:`n_x = n_y = n_z = n`. + Smaller numbers generate smoother deformations. + The minimum number of control points is ``4`` as this transform + uses cubic B-splines to interpolate displacement. + max_displacement: Maximum displacement along each dimension at each + control point :math:`(D_x, D_y, D_z)`. + The displacement along dimension :math:`i` at each control point is + :math:`d_i \sim \mathcal{U}(0, D_i)`. + If a single value :math:`D` is passed, + then :math:`D_x = D_y = D_z = D`. + Note that the total maximum displacement would actually be + :math:`D_{max} = \sqrt{D_x^2 + D_y^2 + D_z^2}`. + locked_borders: If ``0``, all displacement vectors are kept. + If ``1``, displacement of control points at the + border of the coarse grid will be set to ``0``. + If ``2``, displacement of control points at the border of the image + (red dots in the image below) will also be set to ``0``. + default_pad_value: As the image is rotated, some values near the + borders will be undefined. + If ``'minimum'``, the fill value will be the image minimum. + If ``'mean'``, the fill value is the mean of the border values. + If ``'otsu'``, the fill value is the mean of the values at the + border that lie under an + `Otsu threshold `_. + If it is a number, that value will be used. + image_interpolation: See :ref:`Interpolation`. + label_interpolation: See :ref:`Interpolation`. + check_shape: If ``True`` an error will be raised if the images are in + different physical spaces. If ``False``, :attr:`center` should + probably not be ``'image'`` but ``'center'``. + **kwargs: See :class:`~torchio.transforms.Transform` for additional + keyword arguments. + + Example: + >>> import torchio as tio + >>> image = tio.datasets.Colin27().t1 + >>> transform = tio.RandomCombinedAffineElasticDeformation( + ... scales=(0.9, 1.2), + ... degrees=15, + ... max_displacement=(17, 12, 2) + ... ) + >>> transformed = transform(image) + + .. plot:: + + import torchio as tio + subject = tio.datasets.Slicer('CTChest') + ct = subject.CT_chest + transform = tio.RandomCombinedAffineElasticDeformation(max_displacement=(17, 12, 2)) + ct_transformed = transform(ct) + subject.add_image(ct_transformed, 'Transformed') + subject.plot() + """ + + def __init__( + self, + affine_first: bool = True, + scales: TypeOneToSixFloat = 0.1, + degrees: TypeOneToSixFloat = 10, + translation: TypeOneToSixFloat = 0, + isotropic: bool = False, + center: str = 'image', + num_control_points: Union[int, Tuple[int, int, int]] = 7, + max_displacement: Union[float, Tuple[float, float, float]] = 7.5, + locked_borders: int = 2, + default_pad_value: Union[str, float] = 'minimum', + image_interpolation: str = 'linear', + label_interpolation: str = 'nearest', + check_shape: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.affine_first = affine_first + + # affine args + self.isotropic = isotropic + _parse_scales_isotropic(scales, isotropic) + self.scales = self.parse_params(scales, 1, 'scales', min_constraint=0) + self.degrees = self.parse_params(degrees, 0, 'degrees') + self.translation = self.parse_params(translation, 0, 'translation') + if center not in ('image', 'origin'): + message = f'Center argument must be "image" or "origin", not "{center}"' + raise ValueError(message) + self.center = center + self._bspline_transformation = None + self.num_control_points = to_tuple(num_control_points, length=3) + _parse_num_control_points(self.num_control_points) # type: ignore[arg-type] # noqa: B950 + self.max_displacement = to_tuple(max_displacement, length=3) + _parse_max_displacement(self.max_displacement) # type: ignore[arg-type] # noqa: B950 + self.num_locked_borders = locked_borders + if locked_borders not in (0, 1, 2): + raise ValueError('locked_borders must be 0, 1, or 2') + if locked_borders == 2 and 4 in self.num_control_points: + message = ( + 'Setting locked_borders to 2 and using less than 5 control' + 'points results in an identity transform. Lock fewer borders' + ' or use more control points.' + ) + raise ValueError(message) + + self.default_pad_value = _parse_default_value(default_pad_value) + self.image_interpolation = self.parse_interpolation( + image_interpolation, + ) + self.label_interpolation = self.parse_interpolation( + label_interpolation, + ) + self.check_shape = check_shape + + def get_params(self): + affine_params = RandomAffine.get_params( + self.scales, self.degrees, self.translation, self.isotropic + ) + elastic_params = RandomElasticDeformation.get_params( + self.num_control_points, self.max_displacement, self.num_locked_borders + ) + return affine_params, elastic_params + + def apply_transform(self, subject: Subject): + if self.check_shape: + subject.check_consistent_spatial_shape() + (scaling_params, rotation_params, translation_params), control_points = ( + self.get_params() + ) + + arguments = { + 'affine_first': self.affine_first, + 'scales': scaling_params.tolist(), + 'degrees': rotation_params.tolist(), + 'translation': translation_params.tolist(), + 'center': self.center, + 'control_points': control_points, + 'max_displacement': self.max_displacement, + 'default_pad_value': self.default_pad_value, + 'image_interpolation': self.image_interpolation, + 'label_interpolation': self.label_interpolation, + 'check_shape': self.check_shape, + } + + transform = CombinedAffineElasticDeformation( + **self.add_include_exclude(arguments) + ) + transformed = transform(subject) + assert isinstance(transformed, Subject) + return transformed + + +class CombinedAffineElasticDeformation(SpatialTransform): + r"""Apply an Affine and ElasticDeformation simultaneously. + + Optimization to use only a single SimpleITK resampling. For additional details + on the transformations, see :class:`~torchio.transforms.augmentation.Affine` + and :class:`~torchio.transforms.augmentation.ElasticDeformation` + + Args: + affine_first: Apply affine before elastic deformation. + scales: Tuple :math:`(s_1, s_2, s_3)` defining the + scaling values along each dimension. + degrees: Tuple :math:`(\theta_1, \theta_2, \theta_3)` defining the + rotation around each axis. + translation: Tuple :math:`(t_1, t_2, t_3)` defining the + translation in mm along each axis. + control_points: + max_displacement: Maximum displacement along each dimension at each + control point :math:`(D_x, D_y, D_z)`. + The displacement along dimension :math:`i` at each control point is + :math:`d_i \sim \mathcal{U}(0, D_i)`. + Note that the total maximum displacement would actually be + :math:`D_{max} = \sqrt{D_x^2 + D_y^2 + D_z^2}`. + center: If ``'image'``, rotations and scaling will be performed around + the image center. If ``'origin'``, rotations and scaling will be + performed around the origin in world coordinates. + default_pad_value: As the image is rotated, some values near the + borders will be undefined. + If ``'minimum'``, the fill value will be the image minimum. + If ``'mean'``, the fill value is the mean of the border values. + If ``'otsu'``, the fill value is the mean of the values at the + border that lie under an + `Otsu threshold `_. + If it is a number, that value will be used. + image_interpolation: See :ref:`Interpolation`. + label_interpolation: See :ref:`Interpolation`. + check_shape: If ``True`` an error will be raised if the images are in + different physical spaces. If ``False``, :attr:`center` should + probably not be ``'image'`` but ``'center'``. + **kwargs: See :class:`~torchio.transforms.Transform` for additional + keyword arguments. + """ + + def __init__( + self, + affine_first: bool, + scales: TypeTripletFloat, + degrees: TypeTripletFloat, + translation: TypeTripletFloat, + control_points: np.ndarray, + max_displacement: TypeTripletFloat, + center: str = 'image', + default_pad_value: Union[str, float] = 'minimum', + image_interpolation: str = 'linear', + label_interpolation: str = 'nearest', + check_shape: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.affine_first = affine_first + + # affine args + self.scales = self.parse_params( + scales, + None, + 'scales', + make_ranges=False, + min_constraint=0, + ) + self.degrees = self.parse_params( + degrees, + None, + 'degrees', + make_ranges=False, + ) + self.translation = self.parse_params( + translation, + None, + 'translation', + make_ranges=False, + ) + if center not in ('image', 'origin'): + message = f'Center argument must be "image" or "origin", not "{center}"' + raise ValueError(message) + self.center = center + self.use_image_center = center == 'image' + + # elastic args + self._bspline_transformation = None + self.control_points = control_points + self.max_displacement = to_tuple(max_displacement, length=3) + + # common args + self.default_pad_value = _parse_default_value(default_pad_value) + self.image_interpolation = self.parse_interpolation( + image_interpolation, + ) + self.label_interpolation = self.parse_interpolation( + label_interpolation, + ) + self.invert_transform = False + self.check_shape = check_shape + + self._affine = Affine( + scales, + degrees, + translation, + center=center, + default_pad_value=default_pad_value, + image_interpolation=image_interpolation, + label_interpolation=label_interpolation, + check_shape=check_shape, + **kwargs, + ) + self._elastic = ElasticDeformation( + control_points, + max_displacement, + image_interpolation=image_interpolation, + label_interpolation=label_interpolation, + **kwargs, + ) + + self.args_names = [ + 'affine_first', + 'scales', + 'degrees', + 'translation', + 'control_points', + 'max_displacement', + 'center', + 'default_pad_value', + 'image_interpolation', + 'label_interpolation', + 'check_shape', + ] + + def apply_transform(self, subject: Subject) -> Subject: + if self.check_shape: + subject.check_consistent_spatial_shape() + default_value: float + + for image in self.get_images(subject): + affine_transform = self._affine.get_affine_transform(image) + transformed_tensors = [] + for tensor in image.data: + sitk_image = nib_to_sitk( + tensor[np.newaxis], + image.affine, + force_3d=True, + ) + if image[TYPE] != INTENSITY: + interpolation = self._affine.label_interpolation + default_value = 0 + else: + interpolation = self._affine.image_interpolation + default_value = self._affine.get_default_pad_value( + tensor, sitk_image + ) + + bspline_transform = self._elastic.get_bspline_transform(sitk_image) + self._elastic.parse_free_form_transform( + bspline_transform, self._elastic.max_displacement + ) + + # stack: LIFO + if self.affine_first: + combined_transforms = [affine_transform, bspline_transform] + else: + combined_transforms = [bspline_transform, affine_transform] + composite_transform = sitk.CompositeTransform(combined_transforms) + + transformed_tensor = self.apply_composite_transform( + sitk_image, + composite_transform, + interpolation, + default_value, + ) + transformed_tensors.append(transformed_tensor) + image.set_data(torch.stack(transformed_tensors)) + return subject + + def apply_composite_transform( + self, + sitk_image: sitk.Image, + transform: sitk.Transform, + interpolation: str, + default_value: float, + ) -> torch.Tensor: + floating = reference = sitk_image + + resampler = sitk.ResampleImageFilter() + resampler.SetInterpolator(self.get_sitk_interpolator(interpolation)) + resampler.SetReferenceImage(reference) + resampler.SetDefaultPixelValue(float(default_value)) + resampler.SetOutputPixelType(sitk.sitkFloat32) + resampler.SetTransform(transform) + resampled = resampler.Execute(floating) + + np_array = sitk.GetArrayFromImage(resampled) + np_array = np_array.transpose() # ITK to NumPy + tensor = torch.as_tensor(np_array) + return tensor diff --git a/tests/transforms/augmentation/test_random_affine_elastic_deformation.py b/tests/transforms/augmentation/test_random_affine_elastic_deformation.py new file mode 100644 index 00000000..df11f635 --- /dev/null +++ b/tests/transforms/augmentation/test_random_affine_elastic_deformation.py @@ -0,0 +1,289 @@ +import pytest +import torch +import torchio as tio + +from ...utils import TorchioTestCase + + +class TestRandomCombinedAffineElasticDeformation(TorchioTestCase): + """Tests for `RandomCombinedAffineElasticDeformation`.""" + + def setUp(self): + # Set image origin far from center + super().setUp() + affine = self.sample_subject.t1.affine + affine[:3, 3] = 1e5 + + def test_inputs_pta_gt_one(self): + with pytest.raises(ValueError): + tio.RandomCombinedAffineElasticDeformation(p=1.5) + + def test_inputs_pta_lt_zero(self): + with pytest.raises(ValueError): + tio.RandomCombinedAffineElasticDeformation(p=-1) + + def test_inputs_interpolation_int(self): + with pytest.raises(TypeError): + tio.RandomCombinedAffineElasticDeformation(image_interpolation=1) + + def test_inputs_interpolation(self): + with pytest.raises(TypeError): + tio.RandomCombinedAffineElasticDeformation(image_interpolation=0) + + def test_num_control_points_noint(self): + with pytest.raises(ValueError): + tio.RandomCombinedAffineElasticDeformation(num_control_points=2.5) + + def test_num_control_points_small(self): + with pytest.raises(ValueError): + tio.RandomCombinedAffineElasticDeformation(num_control_points=3) + + def test_max_displacement_no_num(self): + with pytest.raises(ValueError): + tio.RandomCombinedAffineElasticDeformation(max_displacement=None) + + def test_max_displacement_negative(self): + with pytest.raises(ValueError): + tio.RandomCombinedAffineElasticDeformation(max_displacement=-1) + + def test_wrong_locked_borders(self): + with pytest.raises(ValueError): + tio.RandomCombinedAffineElasticDeformation(locked_borders=-1) + + def test_coarse_grid_removed(self): + with pytest.raises(ValueError): + tio.RandomCombinedAffineElasticDeformation( + num_control_points=(4, 5, 6), + locked_borders=2, + ) + + def test_folding(self): + # Assume shape is (10, 20, 30) and spacing is (1, 1, 1) + # Then grid spacing is (10/(12-2), 20/(5-2), 30/(5-2)) + # or (1, 6.7, 10), and half is (0.5, 3.3, 5) + transform = tio.RandomCombinedAffineElasticDeformation( + num_control_points=(12, 5, 5), + max_displacement=6, + ) + with pytest.warns(RuntimeWarning): + transform(self.sample_subject) + + def test_num_control_points(self): + tio.RandomCombinedAffineElasticDeformation(num_control_points=5) + tio.RandomCombinedAffineElasticDeformation(num_control_points=(5, 6, 7)) + + def test_max_displacement(self): + tio.RandomCombinedAffineElasticDeformation(max_displacement=5) + tio.RandomCombinedAffineElasticDeformation(max_displacement=(5, 6, 7)) + + def test_no_displacement(self): + transform = tio.RandomCombinedAffineElasticDeformation( + max_displacement=0, scales=0, degrees=0, translation=0 + ) + transformed = transform(self.sample_subject) + self.assert_tensor_equal( + self.sample_subject.t1.data, + transformed.t1.data, + ) + self.assert_tensor_equal( + self.sample_subject.label.data, + transformed.label.data, + ) + + def test_rotation_image(self): + # Rotation around image center + transform = tio.RandomCombinedAffineElasticDeformation( + degrees=(90, 90), + default_pad_value=0, + center='image', + ) + transformed = transform(self.sample_subject) + total = transformed.t1.data.sum() + self.assertNotEqual(total, 0) + + def test_rotation_origin(self): + # Rotation around far away point, image should be empty + transform = tio.RandomCombinedAffineElasticDeformation( + degrees=(90, 90), + default_pad_value=0, + center='origin', + ) + transformed = transform(self.sample_subject) + total = transformed.t1.data.sum() + assert total == 0 + + def test_no_rotation(self): + transform = tio.RandomCombinedAffineElasticDeformation( + scales=(1, 1), + degrees=(0, 0), + default_pad_value=0, + max_displacement=0, + center='image', + ) + transformed = transform(self.sample_subject) + self.assert_tensor_almost_equal( + self.sample_subject.t1.data, + transformed.t1.data, + ) + + transform = tio.RandomCombinedAffineElasticDeformation( + scales=(1, 1), + degrees=(180, 180), + default_pad_value=0, + max_displacement=0, + center='image', + ) + transformed = transform(self.sample_subject) + transformed = transform(transformed) + self.assert_tensor_almost_equal( + self.sample_subject.t1.data, + transformed.t1.data, + ) + + def test_isotropic(self): + tio.RandomCombinedAffineElasticDeformation(isotropic=True)(self.sample_subject) + + def test_mean(self): + tio.RandomCombinedAffineElasticDeformation(default_pad_value='mean')( + self.sample_subject + ) + + def test_otsu(self): + tio.RandomCombinedAffineElasticDeformation(default_pad_value='otsu')( + self.sample_subject + ) + + def test_bad_center(self): + with pytest.raises(ValueError): + tio.RandomCombinedAffineElasticDeformation(center='bad') + + def test_negative_scales(self): + with pytest.raises(ValueError): + tio.RandomCombinedAffineElasticDeformation(scales=(-1, 1)) + + def test_scale_too_large(self): + with pytest.raises(ValueError): + tio.RandomCombinedAffineElasticDeformation(scales=1.5) + + def test_scales_range_with_negative_min(self): + with pytest.raises(ValueError): + tio.RandomCombinedAffineElasticDeformation(scales=(-1, 4)) + + def test_wrong_scales_type(self): + with pytest.raises(ValueError): + tio.RandomCombinedAffineElasticDeformation(scales='wrong') + + def test_wrong_degrees_type(self): + with pytest.raises(ValueError): + tio.RandomCombinedAffineElasticDeformation(degrees='wrong') + + def test_too_many_translation_values(self): + with pytest.raises(ValueError): + tio.RandomCombinedAffineElasticDeformation(translation=(-10, 4, 42)) + + def test_wrong_translation_type(self): + with pytest.raises(ValueError): + tio.RandomCombinedAffineElasticDeformation(translation='wrong') + + def test_wrong_center(self): + with pytest.raises(ValueError): + tio.RandomCombinedAffineElasticDeformation(center=0) + + def test_wrong_default_pad_value(self): + with pytest.raises(ValueError): + tio.RandomCombinedAffineElasticDeformation(default_pad_value='wrong') + + def test_wrong_image_interpolation_type(self): + with pytest.raises(TypeError): + tio.RandomCombinedAffineElasticDeformation(image_interpolation=0) + + def test_wrong_image_interpolation_value(self): + with pytest.raises(ValueError): + tio.RandomCombinedAffineElasticDeformation(image_interpolation='wrong') + + def test_incompatible_args_isotropic(self): + with pytest.raises(ValueError): + tio.RandomCombinedAffineElasticDeformation( + scales=(0.8, 0.5, 0.1), isotropic=True + ) + + def test_parse_scales(self): + def do_assert(transform): + assert transform.scales == 3 * (0.9, 1.1) + + do_assert(tio.RandomCombinedAffineElasticDeformation(scales=0.1)) + do_assert(tio.RandomCombinedAffineElasticDeformation(scales=(0.9, 1.1))) + do_assert(tio.RandomCombinedAffineElasticDeformation(scales=3 * (0.1,))) + do_assert(tio.RandomCombinedAffineElasticDeformation(scales=3 * [0.9, 1.1])) + + def test_parse_degrees(self): + def do_assert(transform): + assert transform.degrees == 3 * (-10, 10) + + do_assert(tio.RandomCombinedAffineElasticDeformation(degrees=10)) + do_assert(tio.RandomCombinedAffineElasticDeformation(degrees=(-10, 10))) + do_assert(tio.RandomCombinedAffineElasticDeformation(degrees=3 * (10,))) + do_assert(tio.RandomCombinedAffineElasticDeformation(degrees=3 * [-10, 10])) + + def test_parse_translation(self): + def do_assert(transform): + assert transform.translation == 3 * (-10, 10) + + do_assert(tio.RandomCombinedAffineElasticDeformation(translation=10)) + do_assert(tio.RandomCombinedAffineElasticDeformation(translation=(-10, 10))) + do_assert(tio.RandomCombinedAffineElasticDeformation(translation=3 * (10,))) + do_assert(tio.RandomCombinedAffineElasticDeformation(translation=3 * [-10, 10])) + + def test_default_value_label_map(self): + # From https://github.com/fepegar/torchio/issues/626 + a = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]]).reshape(1, 3, 3, 1) + image = tio.LabelMap(tensor=a) + aff = tio.RandomCombinedAffineElasticDeformation( + translation=(0, 1, 1), default_pad_value='otsu' + ) + transformed = aff(image) + assert all(n in (0, 1) for n in transformed.data.flatten()) + + def test_no_inverse(self): + tensor = torch.zeros((1, 2, 2, 2)) + tensor[0, 1, 1, 1] = 1 # most RAS voxel + expected = torch.zeros((1, 2, 2, 2)) + expected[0, 0, 1, 1] = 1 + scales = 1, 1, 1 + degrees = 0, 0, 90 # anterior should go left + translation = 0, 0, 0 + apply_affine = tio.Affine( + scales, + degrees, + translation, + ) + transformed = apply_affine(tensor) + self.assert_tensor_almost_equal(transformed, expected) + + def test_different_spaces(self): + t1 = self.sample_subject.t1 + label = tio.Resample(2)(self.sample_subject.label) + new_subject = tio.Subject(t1=t1, label=label) + with pytest.raises(RuntimeError): + tio.RandomCombinedAffineElasticDeformation()(new_subject) + tio.RandomCombinedAffineElasticDeformation(check_shape=False)(new_subject) + + def test_transform_order(self): + src_transform = tio.RandomCombinedAffineElasticDeformation( + scales=0, degrees=0, translation=1, num_control_points=5, max_displacement=1 + ) + + (scales, degrees, translation), control_points = src_transform.get_params() + + max_displacement = src_transform.max_displacement + + transform1 = tio.CombinedAffineElasticDeformation( + True, scales, degrees, translation, control_points, max_displacement + ) + transform2 = tio.CombinedAffineElasticDeformation( + False, scales, degrees, translation, control_points, max_displacement + ) + + transformed1 = transform1(self.sample_subject) + transformed2 = transform2(self.sample_subject) + self.assert_tensor_not_equal(transformed1.t1.data, transformed2.t1.data) diff --git a/tests/transforms/test_transforms.py b/tests/transforms/test_transforms.py index e07b654d..59111be5 100644 --- a/tests/transforms/test_transforms.py +++ b/tests/transforms/test_transforms.py @@ -17,6 +17,9 @@ def get_transform(self, channels, is_3d=True, labels=True): landmarks_dict = {channel: np.linspace(0, 100, 13) for channel in channels} disp = 1 if is_3d else (1, 1, 0.01) elastic = tio.RandomElasticDeformation(max_displacement=disp) + affine_elastic = tio.RandomCombinedAffineElasticDeformation( + max_displacement=disp + ) cp_args = (9, 21, 30) if is_3d else (21, 30, 1) resize_args = (10, 20, 30) if is_3d else (10, 20, 1) flip_axes = axes_downsample = (0, 1, 2) if is_3d else (0, 1) @@ -46,6 +49,7 @@ def get_transform(self, channels, is_3d=True, labels=True): tio.HistogramStandardization(landmarks_dict), elastic, tio.RandomAffine(), + affine_elastic, tio.OneOf( { tio.RandomAffine(): 3,