MAISI VAE transforms (Project-MONAI#1755)
Fixes # .

### Description
MAISI VAE transforms

### Checks
- [x] Avoid including large-size files in the PR.
- [x] Clean up long text outputs from code cells in the notebook.
- [x] For security purposes, please check the contents and remove any
sensitive info such as user names and private key.
- [x] Ensure (1) hyperlinks and markdown anchors are working (2) use
relative paths for tutorial repo files (3) put figure and graphs in the
`./figure` folder
- [ ] Notebook runs automatically `./runner.sh -t <path to .ipynb file>`

---------

Signed-off-by: root <[email protected]>
Signed-off-by: Can-Zhao <[email protected]>
Co-authored-by: root <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
3 people authored Jul 11, 2024
1 parent 6967caf commit 12b2e55
Showing 2 changed files with 348 additions and 0 deletions.
24 changes: 24 additions & 0 deletions generative/maisi/configs/config_maisi_vae_train.json
@@ -0,0 +1,24 @@
{
"data_option":{
"random_aug": true,
"spacing_type": "rand_zoom",
"spacing": null,
"select_channel": 0
},
"autoencoder_train": {
"batch_size": 1,
"patch_size": [128,128,128],
"val_batch_size": 1,
"val_patch_size": null,
"val_sliding_window_patch_size": [192,192,128],
"lr": 1e-4,
"perceptual_weight": 0.3,
"kl_weight": 1e-7,
"adv_weight": 0.1,
"recon_loss": "l1",
"val_interval": 10,
"cache": 0.5,
"amp": true,
"n_epochs": 12000
}
}
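
The `data_option` block above parameterizes the transform pipeline, while `autoencoder_train` holds training hyperparameters. A hedged sketch of how a training script might wire the two together (the loading code is illustrative; the actual training script is not part of this diff):

```python
import json

from scripts.transforms import VAE_Transform

with open("generative/maisi/configs/config_maisi_vae_train.json") as f:
    config = json.load(f)

data_cfg = config["data_option"]
train_cfg = config["autoencoder_train"]

# "random_aug", "spacing_type", "spacing", and "select_channel" map directly onto
# VAE_Transform arguments; the patch sizes come from the training block.
train_transform = VAE_Transform(
    is_train=True,
    random_aug=data_cfg["random_aug"],
    patch_size=train_cfg["patch_size"],
    val_patch_size=train_cfg["val_patch_size"],
    spacing_type=data_cfg["spacing_type"],
    spacing=data_cfg["spacing"],
    select_channel=data_cfg["select_channel"],
)
```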
324 changes: 324 additions & 0 deletions generative/maisi/scripts/transforms.py
@@ -0,0 +1,324 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import List, Optional

import torch
from monai.transforms import (
Compose,
DivisiblePadd,
EnsureChannelFirstd,
EnsureTyped,
Lambdad,
LoadImaged,
Orientationd,
RandAdjustContrastd,
RandBiasFieldd,
RandFlipd,
RandGibbsNoised,
RandHistogramShiftd,
RandRotate90d,
RandRotated,
RandScaleIntensityd,
RandShiftIntensityd,
RandSpatialCropd,
RandZoomd,
ResizeWithPadOrCropd,
ScaleIntensityRanged,
ScaleIntensityRangePercentilesd,
SelectItemsd,
Spacingd,
SpatialPadd,
)

SUPPORT_MODALITIES = ["ct", "mri"]


def define_fixed_intensity_transform(modality: str, image_keys: List[str] = ["image"]) -> List:
"""
Define fixed intensity transform based on the modality.
Args:
modality (str): The imaging modality, either 'ct' or 'mri'.
image_keys (List[str], optional): List of image keys. Defaults to ["image"].
Returns:
List: A list of intensity transforms.
"""
    modality = modality.lower()  # normalize to lowercase before the support check
    if modality not in SUPPORT_MODALITIES:
        warnings.warn(
            f"Intensity transforms only support {SUPPORT_MODALITIES}. Got {modality}. "
            "No intensity transform will be applied; original intensities are kept."
        )

intensity_transforms = {
"mri": [
            ScaleIntensityRangePercentilesd(keys=image_keys, lower=0.0, upper=99.5, b_min=0.0, b_max=1.0, clip=False)
],
"ct": [ScaleIntensityRanged(keys=image_keys, a_min=-1000, a_max=1000, b_min=0.0, b_max=1.0, clip=True)],
}

if modality not in intensity_transforms:
return []

return intensity_transforms[modality]


def define_random_intensity_transform(modality: str, image_keys: List[str] = ["image"]) -> List:
"""
Define random intensity transform based on the modality.
Args:
modality (str): The imaging modality, either 'ct' or 'mri'.
image_keys (List[str], optional): List of image keys. Defaults to ["image"].
Returns:
List: A list of random intensity transforms.
"""
modality = modality.lower() # Normalize modality to lowercase
if modality not in SUPPORT_MODALITIES:
        warnings.warn(
            f"Intensity transforms only support {SUPPORT_MODALITIES}. Got {modality}. "
            "No intensity transform will be applied; original intensities are kept."
        )

if modality == "ct":
return [] # CT HU intensity is stable across different datasets
elif modality == "mri":
return [
RandBiasFieldd(keys=image_keys, prob=0.3, coeff_range=(0.0, 0.3)),
RandGibbsNoised(keys=image_keys, prob=0.3, alpha=(0.5, 1.0)),
RandAdjustContrastd(keys=image_keys, prob=0.3, gamma=(0.5, 2.0)),
RandHistogramShiftd(keys=image_keys, prob=0.05, num_control_points=10),
]
else:
return []


def define_vae_transform(
is_train: bool,
modality: str,
random_aug: bool,
k: int = 4,
patch_size: List[int] = [128, 128, 128],
val_patch_size: Optional[List[int]] = None,
output_dtype: torch.dtype = torch.float32,
spacing_type: str = "original",
spacing: Optional[List[float]] = None,
image_keys: List[str] = ["image"],
label_keys: List[str] = [],
additional_keys: List[str] = [],
select_channel: int = 0,
) -> Compose:
"""
Define the MAISI VAE transform pipeline for training or validation.
Args:
        is_train (bool): Whether the transform is for training. If True, random_aug is honored and a random
            patch of size patch_size is cropped. If False, random_aug is always treated as False and
            val_patch_size is used for center cropping.
        modality (str): The imaging modality, either 'ct' or 'mri'.
        random_aug (bool): Whether to apply random data augmentation.
        k (int, optional): At validation time, when val_patch_size is None, each spatial dimension is padded
            to be divisible by k. Defaults to 4.
        patch_size (List[int], optional): Size of the patch randomly cropped during training.
            Defaults to [128, 128, 128].
        val_patch_size (Optional[List[int]], optional): Size of validation patches. Defaults to None. If None,
            the whole (padded) volume is used for validation; if given, a patch of this size is center-cropped.
output_dtype (torch.dtype, optional): Output data type. Defaults to torch.float32.
spacing_type (str, optional): Type of spacing. Defaults to "original". Choose from ["original", "fixed", "rand_zoom"].
spacing (Optional[List[float]], optional): Spacing values. Defaults to None.
image_keys (List[str], optional): List of image keys. Defaults to ["image"].
label_keys (List[str], optional): List of label keys. Defaults to [].
additional_keys (List[str], optional): List of additional keys. Defaults to [].
select_channel (int, optional): Channel to select for multi-channel MRI. Defaults to 0.
    Returns:
        Compose: The composed training transform if is_train is True, otherwise the composed validation
            transform.
"""
modality = modality.lower() # Normalize modality to lowercase
if modality not in SUPPORT_MODALITIES:
        warnings.warn(
            f"Intensity transforms only support {SUPPORT_MODALITIES}. Got {modality}. "
            "No intensity transform will be applied; original intensities are kept."
        )

if spacing_type not in ["original", "fixed", "rand_zoom"]:
raise ValueError(f"spacing_type has to be chosen from ['original', 'fixed', 'rand_zoom']. Got {spacing_type}.")

keys = image_keys + label_keys + additional_keys
interp_mode = ["bilinear"] * len(image_keys) + ["nearest"] * len(label_keys)

common_transform = [
SelectItemsd(keys=keys, allow_missing_keys=True),
LoadImaged(keys=keys, allow_missing_keys=True),
EnsureChannelFirstd(keys=keys, allow_missing_keys=True),
Orientationd(keys=keys, axcodes="RAS", allow_missing_keys=True),
]

if modality == "mri":
common_transform.append(Lambdad(keys=image_keys, func=lambda x: x[select_channel : select_channel + 1, ...]))

common_transform.extend(define_fixed_intensity_transform(modality, image_keys=image_keys))

if spacing_type == "fixed":
common_transform.append(
Spacingd(keys=image_keys + label_keys, allow_missing_keys=True, pixdim=spacing, mode=interp_mode)
)

random_transform = []
if is_train and random_aug:
random_transform.extend(define_random_intensity_transform(modality, image_keys=image_keys))
random_transform.extend(
[RandFlipd(keys=keys, allow_missing_keys=True, prob=0.5, spatial_axis=axis) for axis in range(3)]
+ [
RandRotate90d(keys=keys, allow_missing_keys=True, prob=0.5, spatial_axes=axes)
for axes in [(0, 1), (1, 2), (0, 2)]
]
+ [
RandScaleIntensityd(keys=image_keys, allow_missing_keys=True, prob=0.3, factors=(0.9, 1.1)),
RandShiftIntensityd(keys=image_keys, allow_missing_keys=True, prob=0.3, offsets=0.05),
]
)

if spacing_type == "rand_zoom":
random_transform.extend(
[
RandZoomd(
keys=image_keys + label_keys,
allow_missing_keys=True,
prob=0.3,
min_zoom=0.5,
max_zoom=1.5,
keep_size=False,
mode=interp_mode,
),
RandRotated(
keys=image_keys + label_keys,
allow_missing_keys=True,
prob=0.3,
range_x=0.1,
range_y=0.1,
range_z=0.1,
keep_size=True,
mode=interp_mode,
),
]
)

if is_train:
train_crop = [
SpatialPadd(keys=keys, spatial_size=patch_size, allow_missing_keys=True),
RandSpatialCropd(
keys=keys, roi_size=patch_size, allow_missing_keys=True, random_size=False, random_center=True
),
]
else:
val_crop = (
[DivisiblePadd(keys=keys, allow_missing_keys=True, k=k)]
if val_patch_size is None
else [ResizeWithPadOrCropd(keys=keys, allow_missing_keys=True, spatial_size=val_patch_size)]
)

final_transform = [EnsureTyped(keys=keys, dtype=output_dtype, allow_missing_keys=True)]

if is_train:
train_transforms = Compose(
common_transform + random_transform + train_crop + final_transform
if random_aug
else common_transform + train_crop + final_transform
)
return train_transforms
else:
val_transforms = Compose(common_transform + val_crop + final_transform)
return val_transforms
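
# Illustrative direct usage of define_vae_transform (the file path is hypothetical):
#   train_transform = define_vae_transform(is_train=True, modality="ct", random_aug=True)
#   patch = train_transform({"image": "/path/to/ct_volume.nii.gz"})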


class VAE_Transform:
"""
A class to handle MAISI VAE transformations for different modalities.
"""

def __init__(
self,
is_train: bool,
random_aug: bool,
k: int = 4,
patch_size: List[int] = [128, 128, 128],
val_patch_size: Optional[List[int]] = None,
output_dtype: torch.dtype = torch.float32,
spacing_type: str = "original",
spacing: Optional[List[float]] = None,
image_keys: List[str] = ["image"],
label_keys: List[str] = [],
additional_keys: List[str] = [],
select_channel: int = 0,
):
"""
Initialize the VAE_Transform.
Args:
            is_train (bool): Whether the transform is for training. If True, random_aug is honored and a
                random patch of size patch_size is cropped. If False, random_aug is always treated as False
                and val_patch_size is used for center cropping.
            random_aug (bool): Whether to apply random data augmentation for training.
            k (int, optional): At validation time, when val_patch_size is None, each spatial dimension is
                padded to be divisible by k. Defaults to 4.
            patch_size (List[int], optional): Size of the patch randomly cropped during training.
                Defaults to [128, 128, 128].
            val_patch_size (Optional[List[int]], optional): Size of validation patches. Defaults to None.
                If None, the whole (padded) volume is used for validation; if given, a patch of this size
                is center-cropped.
output_dtype (torch.dtype, optional): Output data type. Defaults to torch.float32.
spacing_type (str, optional): Type of spacing. Defaults to "original". Choose from ["original", "fixed", "rand_zoom"].
spacing (Optional[List[float]], optional): Spacing values. Defaults to None.
image_keys (List[str], optional): List of image keys. Defaults to ["image"].
label_keys (List[str], optional): List of label keys. Defaults to [].
additional_keys (List[str], optional): List of additional keys. Defaults to [].
select_channel (int, optional): Channel to select for multi-channel MRI. Defaults to 0.
"""
if spacing_type not in ["original", "fixed", "rand_zoom"]:
raise ValueError(
f"spacing_type has to be chosen from ['original', 'fixed', 'rand_zoom']. Got {spacing_type}."
)

self.is_train = is_train
self.transform_dict = {}

for modality in ["ct", "mri"]:
self.transform_dict[modality] = define_vae_transform(
is_train=is_train,
modality=modality,
random_aug=random_aug,
k=k,
patch_size=patch_size,
val_patch_size=val_patch_size,
output_dtype=output_dtype,
spacing_type=spacing_type,
spacing=spacing,
image_keys=image_keys,
label_keys=label_keys,
additional_keys=additional_keys,
select_channel=select_channel,
)

def __call__(self, img: dict, fixed_modality: Optional[str] = None) -> dict:
"""
Apply the appropriate transform to the input image.
Args:
img (dict): Input image dictionary.
            fixed_modality (Optional[str], optional): Modality to use instead of img["class"]. Defaults to None.
        Returns:
            dict: The transformed data dictionary.
        Raises:
            ValueError: If the modality is not 'ct' or 'mri'.
"""
modality = fixed_modality or img["class"]
        modality = modality.lower()  # normalize modality to lowercase
        if modality not in SUPPORT_MODALITIES:
            # self.transform_dict only holds 'ct' and 'mri' pipelines, so fail fast as documented
            raise ValueError(f"VAE_Transform only supports {SUPPORT_MODALITIES}. Got {modality}.")

transform = self.transform_dict[modality]
return transform(img)
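
For completeness, a validation-side sketch of the wrapper (the file path and import location are assumptions; with `val_patch_size=None` the whole volume is kept and only padded so each spatial dimension is divisible by `k`):

```python
from scripts.transforms import VAE_Transform

val_transform = VAE_Transform(
    is_train=False,
    random_aug=False,      # ignored when is_train=False
    k=4,                   # whole volumes are padded to a multiple of 4 per spatial dim
    val_patch_size=None,   # None -> DivisiblePadd on the full volume, no cropping
)

sample = val_transform({"image": "/path/to/mri_volume.nii.gz", "class": "mri"})
```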
