diff --git a/generative/maisi/configs/config_maisi_vae_train.json b/generative/maisi/configs/config_maisi_vae_train.json new file mode 100644 index 0000000000..2e3b18d461 --- /dev/null +++ b/generative/maisi/configs/config_maisi_vae_train.json @@ -0,0 +1,24 @@ +{ + "data_option":{ + "random_aug": true, + "spacing_type": "rand_zoom", + "spacing": null, + "select_channel": 0 + }, + "autoencoder_train": { + "batch_size": 1, + "patch_size": [128,128,128], + "val_batch_size": 1, + "val_patch_size": null, + "val_sliding_window_patch_size": [192,192,128], + "lr": 1e-4, + "perceptual_weight": 0.3, + "kl_weight": 1e-7, + "adv_weight": 0.1, + "recon_loss": "l1", + "val_interval": 10, + "cache": 0.5, + "amp": true, + "n_epochs": 12000 + } +} diff --git a/generative/maisi/scripts/transforms.py b/generative/maisi/scripts/transforms.py new file mode 100644 index 0000000000..fb784f1871 --- /dev/null +++ b/generative/maisi/scripts/transforms.py @@ -0,0 +1,324 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import warnings
from typing import List, Optional

import torch
from monai.transforms import (
    Compose,
    DivisiblePadd,
    EnsureChannelFirstd,
    EnsureTyped,
    Lambdad,
    LoadImaged,
    Orientationd,
    RandAdjustContrastd,
    RandBiasFieldd,
    RandFlipd,
    RandGibbsNoised,
    RandHistogramShiftd,
    RandRotate90d,
    RandRotated,
    RandScaleIntensityd,
    RandShiftIntensityd,
    RandSpatialCropd,
    RandZoomd,
    ResizeWithPadOrCropd,
    ScaleIntensityRanged,
    ScaleIntensityRangePercentilesd,
    SelectItemsd,
    Spacingd,
    SpatialPadd,
)

SUPPORT_MODALITIES = ["ct", "mri"]


def _warn_unsupported_modality(modality: str) -> None:
    """Emit the standard warning used when `modality` is outside SUPPORT_MODALITIES."""
    warnings.warn(
        f"Intensity transform only support {SUPPORT_MODALITIES}. Got {modality}. "
        "Will not do any intensity transform and will use original intensities."
    )


def define_fixed_intensity_transform(modality: str, image_keys: List[str] = ["image"]) -> List:
    """
    Define fixed intensity transform based on the modality.

    Args:
        modality (str): The imaging modality, either 'ct' or 'mri' (case-insensitive).
        image_keys (List[str], optional): List of image keys. Defaults to ["image"].

    Returns:
        List: A list of intensity transforms; empty for unsupported modalities.
    """
    # Normalize to lowercase BEFORE the membership check, so "CT"/"MRI" do not
    # trigger a spurious "unsupported modality" warning.
    modality = modality.lower()
    if modality not in SUPPORT_MODALITIES:
        _warn_unsupported_modality(modality)
        return []

    intensity_transforms = {
        # MRI intensities vary across scanners/protocols: scale per-volume percentiles.
        "mri": [
            ScaleIntensityRangePercentilesd(keys=image_keys, lower=0.0, upper=99.5, b_min=0.0, b_max=1, clip=False)
        ],
        # CT HU values are physically calibrated: a fixed [-1000, 1000] window is stable.
        "ct": [ScaleIntensityRanged(keys=image_keys, a_min=-1000, a_max=1000, b_min=0.0, b_max=1.0, clip=True)],
    }
    return intensity_transforms[modality]


def define_random_intensity_transform(modality: str, image_keys: List[str] = ["image"]) -> List:
    """
    Define random intensity transform based on the modality.

    Args:
        modality (str): The imaging modality, either 'ct' or 'mri' (case-insensitive).
        image_keys (List[str], optional): List of image keys. Defaults to ["image"].

    Returns:
        List: A list of random intensity transforms; empty for 'ct' (HU intensity is
        stable across datasets) and for unsupported modalities.
    """
    modality = modality.lower()
    if modality not in SUPPORT_MODALITIES:
        _warn_unsupported_modality(modality)
        return []

    if modality == "ct":
        return []  # CT HU intensity is stable across different datasets
    # MRI: simulate scanner artifacts (bias field, Gibbs ringing) and contrast variation.
    return [
        RandBiasFieldd(keys=image_keys, prob=0.3, coeff_range=(0.0, 0.3)),
        RandGibbsNoised(keys=image_keys, prob=0.3, alpha=(0.5, 1.0)),
        RandAdjustContrastd(keys=image_keys, prob=0.3, gamma=(0.5, 2.0)),
        RandHistogramShiftd(keys=image_keys, prob=0.05, num_control_points=10),
    ]


def define_vae_transform(
    is_train: bool,
    modality: str,
    random_aug: bool,
    k: int = 4,
    patch_size: List[int] = [128, 128, 128],
    val_patch_size: Optional[List[int]] = None,
    output_dtype: torch.dtype = torch.float32,
    spacing_type: str = "original",
    spacing: Optional[List[float]] = None,
    image_keys: List[str] = ["image"],
    label_keys: List[str] = [],
    additional_keys: List[str] = [],
    select_channel: int = 0,
) -> Compose:
    """
    Define the MAISI VAE transform pipeline for training or validation.

    Args:
        is_train (bool): Whether it's for training or not. If True, the output transform will
            consider random_aug and the cropping will use "patch_size" for random cropping.
            If False, the output transform will always treat "random_aug" as False and will use
            "val_patch_size" for central cropping.
        modality (str): The imaging modality, either 'ct' or 'mri' (case-insensitive).
        random_aug (bool): Whether to apply random data augmentation.
        k (int, optional): Patches should be divisible by k. Defaults to 4.
        patch_size (List[int], optional): Size of the patches. Defaults to [128, 128, 128].
            Will random crop patch for training.
        val_patch_size (Optional[List[int]], optional): Size of validation patches. Defaults to
            None. If None, will use the whole volume for validation. If given, will central crop
            a patch for validation.
        output_dtype (torch.dtype, optional): Output data type. Defaults to torch.float32.
        spacing_type (str, optional): Type of spacing. Defaults to "original". Choose from
            ["original", "fixed", "rand_zoom"].
        spacing (Optional[List[float]], optional): Spacing values (used when spacing_type is
            "fixed"). Defaults to None.
        image_keys (List[str], optional): List of image keys. Defaults to ["image"].
        label_keys (List[str], optional): List of label keys. Defaults to [].
        additional_keys (List[str], optional): List of additional keys. Defaults to [].
        select_channel (int, optional): Channel to select for multi-channel MRI. Defaults to 0.

    Returns:
        Compose: The composed train transform (if `is_train`) or validation transform.

    Raises:
        ValueError: If `spacing_type` is not one of ["original", "fixed", "rand_zoom"].
    """
    modality = modality.lower()
    if modality not in SUPPORT_MODALITIES:
        _warn_unsupported_modality(modality)

    if spacing_type not in ["original", "fixed", "rand_zoom"]:
        raise ValueError(f"spacing_type has to be chosen from ['original', 'fixed', 'rand_zoom']. Got {spacing_type}.")

    keys = image_keys + label_keys + additional_keys
    # Images are resampled bilinearly; label maps must stay integral, so nearest-neighbor.
    interp_mode = ["bilinear"] * len(image_keys) + ["nearest"] * len(label_keys)

    # Deterministic preprocessing shared by train and validation.
    common_transform = [
        SelectItemsd(keys=keys, allow_missing_keys=True),
        LoadImaged(keys=keys, allow_missing_keys=True),
        EnsureChannelFirstd(keys=keys, allow_missing_keys=True),
        Orientationd(keys=keys, axcodes="RAS", allow_missing_keys=True),
    ]

    if modality == "mri":
        # Multi-channel MRI: keep a single channel (slice keeps the channel axis).
        common_transform.append(Lambdad(keys=image_keys, func=lambda x: x[select_channel : select_channel + 1, ...]))

    common_transform.extend(define_fixed_intensity_transform(modality, image_keys=image_keys))

    if spacing_type == "fixed":
        common_transform.append(
            Spacingd(keys=image_keys + label_keys, allow_missing_keys=True, pixdim=spacing, mode=interp_mode)
        )

    random_transform = []
    if is_train and random_aug:
        random_transform.extend(define_random_intensity_transform(modality, image_keys=image_keys))
        random_transform.extend(
            [RandFlipd(keys=keys, allow_missing_keys=True, prob=0.5, spatial_axis=axis) for axis in range(3)]
            + [
                RandRotate90d(keys=keys, allow_missing_keys=True, prob=0.5, spatial_axes=axes)
                for axes in [(0, 1), (1, 2), (0, 2)]
            ]
            + [
                RandScaleIntensityd(keys=image_keys, allow_missing_keys=True, prob=0.3, factors=(0.9, 1.1)),
                RandShiftIntensityd(keys=image_keys, allow_missing_keys=True, prob=0.3, offsets=0.05),
            ]
        )

    # NOTE(review): these spacing augmentations are appended to random_transform, which is
    # only composed into the pipeline when is_train and random_aug are both True — with
    # random_aug=False, spacing_type="rand_zoom" silently falls back to original spacing.
    if spacing_type == "rand_zoom":
        random_transform.extend(
            [
                RandZoomd(
                    keys=image_keys + label_keys,
                    allow_missing_keys=True,
                    prob=0.3,
                    min_zoom=0.5,
                    max_zoom=1.5,
                    keep_size=False,
                    mode=interp_mode,
                ),
                RandRotated(
                    keys=image_keys + label_keys,
                    allow_missing_keys=True,
                    prob=0.3,
                    range_x=0.1,
                    range_y=0.1,
                    range_z=0.1,
                    keep_size=True,
                    mode=interp_mode,
                ),
            ]
        )

    final_transform = [EnsureTyped(keys=keys, dtype=output_dtype, allow_missing_keys=True)]

    if is_train:
        # Pad first so volumes smaller than patch_size can still yield a full patch.
        train_crop = [
            SpatialPadd(keys=keys, spatial_size=patch_size, allow_missing_keys=True),
            RandSpatialCropd(
                keys=keys, roi_size=patch_size, allow_missing_keys=True, random_size=False, random_center=True
            ),
        ]
        return Compose(
            common_transform + random_transform + train_crop + final_transform
            if random_aug
            else common_transform + train_crop + final_transform
        )

    # Validation: either whole-volume (padded to be divisible by k) or a central crop.
    val_crop = (
        [DivisiblePadd(keys=keys, allow_missing_keys=True, k=k)]
        if val_patch_size is None
        else [ResizeWithPadOrCropd(keys=keys, allow_missing_keys=True, spatial_size=val_patch_size)]
    )
    return Compose(common_transform + val_crop + final_transform)


class VAE_Transform:
    """
    A class to handle MAISI VAE transformations for different modalities.

    Builds one transform pipeline per supported modality at construction time and
    dispatches on the sample's modality at call time.
    """

    def __init__(
        self,
        is_train: bool,
        random_aug: bool,
        k: int = 4,
        patch_size: List[int] = [128, 128, 128],
        val_patch_size: Optional[List[int]] = None,
        output_dtype: torch.dtype = torch.float32,
        spacing_type: str = "original",
        spacing: Optional[List[float]] = None,
        image_keys: List[str] = ["image"],
        label_keys: List[str] = [],
        additional_keys: List[str] = [],
        select_channel: int = 0,
    ):
        """
        Initialize the VAE_Transform.

        Args:
            is_train (bool): Whether it's for training or not. If True, the output transform
                will consider random_aug and the cropping will use "patch_size" for random
                cropping. If False, the output transform will always treat "random_aug" as
                False and will use "val_patch_size" for central cropping.
            random_aug (bool): Whether to apply random data augmentation for training.
            k (int, optional): Patches should be divisible by k. Defaults to 4.
            patch_size (List[int], optional): Size of the patches. Defaults to [128, 128, 128].
                Will random crop patch for training.
            val_patch_size (Optional[List[int]], optional): Size of validation patches.
                Defaults to None. If None, will use the whole volume for validation. If given,
                will central crop a patch for validation.
            output_dtype (torch.dtype, optional): Output data type. Defaults to torch.float32.
            spacing_type (str, optional): Type of spacing. Defaults to "original". Choose from
                ["original", "fixed", "rand_zoom"].
            spacing (Optional[List[float]], optional): Spacing values. Defaults to None.
            image_keys (List[str], optional): List of image keys. Defaults to ["image"].
            label_keys (List[str], optional): List of label keys. Defaults to [].
            additional_keys (List[str], optional): List of additional keys. Defaults to [].
            select_channel (int, optional): Channel to select for multi-channel MRI.
                Defaults to 0.

        Raises:
            ValueError: If `spacing_type` is not one of ["original", "fixed", "rand_zoom"].
        """
        if spacing_type not in ["original", "fixed", "rand_zoom"]:
            raise ValueError(
                f"spacing_type has to be chosen from ['original', 'fixed', 'rand_zoom']. Got {spacing_type}."
            )

        self.is_train = is_train
        # Pre-build one pipeline per modality so __call__ is a cheap dict lookup.
        self.transform_dict = {
            modality: define_vae_transform(
                is_train=is_train,
                modality=modality,
                random_aug=random_aug,
                k=k,
                patch_size=patch_size,
                val_patch_size=val_patch_size,
                output_dtype=output_dtype,
                spacing_type=spacing_type,
                spacing=spacing,
                image_keys=image_keys,
                label_keys=label_keys,
                additional_keys=additional_keys,
                select_channel=select_channel,
            )
            for modality in SUPPORT_MODALITIES
        }

    def __call__(self, img: dict, fixed_modality: Optional[str] = None) -> dict:
        """
        Apply the appropriate transform to the input image.

        Args:
            img (dict): Input image dictionary. Must contain a "class" key naming the
                modality unless `fixed_modality` is given.
            fixed_modality (Optional[str], optional): Fixed modality to use. Defaults to None.

        Returns:
            dict: The transformed sample.

        Raises:
            ValueError: If the modality is not 'ct' or 'mri'.
        """
        modality = (fixed_modality or img["class"]).lower()
        if modality not in SUPPORT_MODALITIES:
            # Previously this only warned and then crashed with an opaque KeyError on the
            # dict lookup below; raise the documented ValueError with a clear message instead.
            raise ValueError(f"Supported modalities are {SUPPORT_MODALITIES}. Got {modality}.")

        return self.transform_dict[modality](img)