forked from Project-MONAI/tutorials
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
MAISI VAE transforms (Project-MONAI#1755)
Fixes # . ### Description MAISI VAE transforms ### Checks <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Avoid including large-size files in the PR. - [x] Clean up long text outputs from code cells in the notebook. - [x] For security purposes, please check the contents and remove any sensitive info such as user names and private key. - [x] Ensure (1) hyperlinks and markdown anchors are working (2) use relative paths for tutorial repo files (3) put figure and graphs in the `./figure` folder - [ ] Notebook runs automatically `./runner.sh -t <path to .ipynb file>` --------- Signed-off-by: root <[email protected]> Signed-off-by: Can-Zhao <[email protected]> Co-authored-by: root <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
6967caf
commit 12b2e55
Showing
2 changed files
with
348 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
{ | ||
"data_option":{ | ||
"random_aug": true, | ||
"spacing_type": "rand_zoom", | ||
"spacing": null, | ||
"select_channel": 0 | ||
}, | ||
"autoencoder_train": { | ||
"batch_size": 1, | ||
"patch_size": [128,128,128], | ||
"val_batch_size": 1, | ||
"val_patch_size": null, | ||
"val_sliding_window_patch_size": [192,192,128], | ||
"lr": 1e-4, | ||
"perceptual_weight": 0.3, | ||
"kl_weight": 1e-7, | ||
"adv_weight": 0.1, | ||
"recon_loss": "l1", | ||
"val_interval": 10, | ||
"cache": 0.5, | ||
"amp": true, | ||
"n_epochs": 12000 | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,324 @@ | ||
# Copyright (c) MONAI Consortium | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import warnings | ||
from typing import List, Optional | ||
|
||
import torch | ||
from monai.transforms import ( | ||
Compose, | ||
DivisiblePadd, | ||
EnsureChannelFirstd, | ||
EnsureTyped, | ||
Lambdad, | ||
LoadImaged, | ||
Orientationd, | ||
RandAdjustContrastd, | ||
RandBiasFieldd, | ||
RandFlipd, | ||
RandGibbsNoised, | ||
RandHistogramShiftd, | ||
RandRotate90d, | ||
RandRotated, | ||
RandScaleIntensityd, | ||
RandShiftIntensityd, | ||
RandSpatialCropd, | ||
RandZoomd, | ||
ResizeWithPadOrCropd, | ||
ScaleIntensityRanged, | ||
ScaleIntensityRangePercentilesd, | ||
SelectItemsd, | ||
Spacingd, | ||
SpatialPadd, | ||
) | ||
|
||
SUPPORT_MODALITIES = ["ct", "mri"] | ||
|
||
|
||
def define_fixed_intensity_transform(modality: str, image_keys: List[str] = ["image"]) -> List: | ||
""" | ||
Define fixed intensity transform based on the modality. | ||
Args: | ||
modality (str): The imaging modality, either 'ct' or 'mri'. | ||
image_keys (List[str], optional): List of image keys. Defaults to ["image"]. | ||
Returns: | ||
List: A list of intensity transforms. | ||
""" | ||
if modality not in SUPPORT_MODALITIES: | ||
warnings.warn( | ||
f"Intensity transform only support {SUPPORT_MODALITIES}. Got {modality}. Will not do any intensity transform and will use original intensities." | ||
) | ||
|
||
modality = modality.lower() # Normalize modality to lowercase | ||
|
||
intensity_transforms = { | ||
"mri": [ | ||
ScaleIntensityRangePercentilesd(keys=image_keys, lower=0.0, upper=99.5, b_min=0.0, b_max=1, clip=False) | ||
], | ||
"ct": [ScaleIntensityRanged(keys=image_keys, a_min=-1000, a_max=1000, b_min=0.0, b_max=1.0, clip=True)], | ||
} | ||
|
||
if modality not in intensity_transforms: | ||
return [] | ||
|
||
return intensity_transforms[modality] | ||
|
||
|
||
def define_random_intensity_transform(modality: str, image_keys: List[str] = ["image"]) -> List: | ||
""" | ||
Define random intensity transform based on the modality. | ||
Args: | ||
modality (str): The imaging modality, either 'ct' or 'mri'. | ||
image_keys (List[str], optional): List of image keys. Defaults to ["image"]. | ||
Returns: | ||
List: A list of random intensity transforms. | ||
""" | ||
modality = modality.lower() # Normalize modality to lowercase | ||
if modality not in SUPPORT_MODALITIES: | ||
warnings.warn( | ||
f"Intensity transform only support {SUPPORT_MODALITIES}. Got {modality}. Will not do any intensity transform and will use original intensities." | ||
) | ||
|
||
if modality == "ct": | ||
return [] # CT HU intensity is stable across different datasets | ||
elif modality == "mri": | ||
return [ | ||
RandBiasFieldd(keys=image_keys, prob=0.3, coeff_range=(0.0, 0.3)), | ||
RandGibbsNoised(keys=image_keys, prob=0.3, alpha=(0.5, 1.0)), | ||
RandAdjustContrastd(keys=image_keys, prob=0.3, gamma=(0.5, 2.0)), | ||
RandHistogramShiftd(keys=image_keys, prob=0.05, num_control_points=10), | ||
] | ||
else: | ||
return [] | ||
|
||
|
||
def define_vae_transform( | ||
is_train: bool, | ||
modality: str, | ||
random_aug: bool, | ||
k: int = 4, | ||
patch_size: List[int] = [128, 128, 128], | ||
val_patch_size: Optional[List[int]] = None, | ||
output_dtype: torch.dtype = torch.float32, | ||
spacing_type: str = "original", | ||
spacing: Optional[List[float]] = None, | ||
image_keys: List[str] = ["image"], | ||
label_keys: List[str] = [], | ||
additional_keys: List[str] = [], | ||
select_channel: int = 0, | ||
) -> tuple: | ||
""" | ||
Define the MAISI VAE transform pipeline for training or validation. | ||
Args: | ||
is_train (bool): Whether it's for training or not. If True, the output transform will consider random_aug, the cropping will use "patch_size" for random cropping. If False, the output transform will alwasy treat "random_aug" as False, will use "val_patch_size" for central cropping. | ||
modality (str): The imaging modality, either 'ct' or 'mri'. | ||
random_aug (bool): Whether to apply random data augmentation. | ||
k (int, optional): Patches should be divisible by k. Defaults to 4. | ||
patch_size (List[int], optional): Size of the patches. Defaults to [128, 128, 128]. Will random crop patch for training. | ||
val_patch_size (Optional[List[int]], optional): Size of validation patches. Defaults to None. If None, will use the whole volume for validation. If given, will central crop a patch for validation. | ||
output_dtype (torch.dtype, optional): Output data type. Defaults to torch.float32. | ||
spacing_type (str, optional): Type of spacing. Defaults to "original". Choose from ["original", "fixed", "rand_zoom"]. | ||
spacing (Optional[List[float]], optional): Spacing values. Defaults to None. | ||
image_keys (List[str], optional): List of image keys. Defaults to ["image"]. | ||
label_keys (List[str], optional): List of label keys. Defaults to []. | ||
additional_keys (List[str], optional): List of additional keys. Defaults to []. | ||
select_channel (int, optional): Channel to select for multi-channel MRI. Defaults to 0. | ||
Returns: | ||
tuple: A tuple containing Composed Transform train_transforms or val_transforms depending on 'is_train'. | ||
""" | ||
modality = modality.lower() # Normalize modality to lowercase | ||
if modality not in SUPPORT_MODALITIES: | ||
warnings.warn( | ||
f"Intensity transform only support {SUPPORT_MODALITIES}. Got {modality}. Will not do any intensity transform and will use original intensities." | ||
) | ||
|
||
if spacing_type not in ["original", "fixed", "rand_zoom"]: | ||
raise ValueError(f"spacing_type has to be chosen from ['original', 'fixed', 'rand_zoom']. Got {spacing_type}.") | ||
|
||
keys = image_keys + label_keys + additional_keys | ||
interp_mode = ["bilinear"] * len(image_keys) + ["nearest"] * len(label_keys) | ||
|
||
common_transform = [ | ||
SelectItemsd(keys=keys, allow_missing_keys=True), | ||
LoadImaged(keys=keys, allow_missing_keys=True), | ||
EnsureChannelFirstd(keys=keys, allow_missing_keys=True), | ||
Orientationd(keys=keys, axcodes="RAS", allow_missing_keys=True), | ||
] | ||
|
||
if modality == "mri": | ||
common_transform.append(Lambdad(keys=image_keys, func=lambda x: x[select_channel : select_channel + 1, ...])) | ||
|
||
common_transform.extend(define_fixed_intensity_transform(modality, image_keys=image_keys)) | ||
|
||
if spacing_type == "fixed": | ||
common_transform.append( | ||
Spacingd(keys=image_keys + label_keys, allow_missing_keys=True, pixdim=spacing, mode=interp_mode) | ||
) | ||
|
||
random_transform = [] | ||
if is_train and random_aug: | ||
random_transform.extend(define_random_intensity_transform(modality, image_keys=image_keys)) | ||
random_transform.extend( | ||
[RandFlipd(keys=keys, allow_missing_keys=True, prob=0.5, spatial_axis=axis) for axis in range(3)] | ||
+ [ | ||
RandRotate90d(keys=keys, allow_missing_keys=True, prob=0.5, spatial_axes=axes) | ||
for axes in [(0, 1), (1, 2), (0, 2)] | ||
] | ||
+ [ | ||
RandScaleIntensityd(keys=image_keys, allow_missing_keys=True, prob=0.3, factors=(0.9, 1.1)), | ||
RandShiftIntensityd(keys=image_keys, allow_missing_keys=True, prob=0.3, offsets=0.05), | ||
] | ||
) | ||
|
||
if spacing_type == "rand_zoom": | ||
random_transform.extend( | ||
[ | ||
RandZoomd( | ||
keys=image_keys + label_keys, | ||
allow_missing_keys=True, | ||
prob=0.3, | ||
min_zoom=0.5, | ||
max_zoom=1.5, | ||
keep_size=False, | ||
mode=interp_mode, | ||
), | ||
RandRotated( | ||
keys=image_keys + label_keys, | ||
allow_missing_keys=True, | ||
prob=0.3, | ||
range_x=0.1, | ||
range_y=0.1, | ||
range_z=0.1, | ||
keep_size=True, | ||
mode=interp_mode, | ||
), | ||
] | ||
) | ||
|
||
if is_train: | ||
train_crop = [ | ||
SpatialPadd(keys=keys, spatial_size=patch_size, allow_missing_keys=True), | ||
RandSpatialCropd( | ||
keys=keys, roi_size=patch_size, allow_missing_keys=True, random_size=False, random_center=True | ||
), | ||
] | ||
else: | ||
val_crop = ( | ||
[DivisiblePadd(keys=keys, allow_missing_keys=True, k=k)] | ||
if val_patch_size is None | ||
else [ResizeWithPadOrCropd(keys=keys, allow_missing_keys=True, spatial_size=val_patch_size)] | ||
) | ||
|
||
final_transform = [EnsureTyped(keys=keys, dtype=output_dtype, allow_missing_keys=True)] | ||
|
||
if is_train: | ||
train_transforms = Compose( | ||
common_transform + random_transform + train_crop + final_transform | ||
if random_aug | ||
else common_transform + train_crop + final_transform | ||
) | ||
return train_transforms | ||
else: | ||
val_transforms = Compose(common_transform + val_crop + final_transform) | ||
return val_transforms | ||
|
||
|
||
class VAE_Transform: | ||
""" | ||
A class to handle MAISI VAE transformations for different modalities. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
is_train: bool, | ||
random_aug: bool, | ||
k: int = 4, | ||
patch_size: List[int] = [128, 128, 128], | ||
val_patch_size: Optional[List[int]] = None, | ||
output_dtype: torch.dtype = torch.float32, | ||
spacing_type: str = "original", | ||
spacing: Optional[List[float]] = None, | ||
image_keys: List[str] = ["image"], | ||
label_keys: List[str] = [], | ||
additional_keys: List[str] = [], | ||
select_channel: int = 0, | ||
): | ||
""" | ||
Initialize the VAE_Transform. | ||
Args: | ||
is_train (bool): Whether it's for training or not. If True, the output transform will consider random_aug, the cropping will use "patch_size" for random cropping. If False, the output transform will alwasy treat "random_aug" as False, will use "val_patch_size" for central cropping. | ||
random_aug (bool): Whether to apply random data augmentation for training. | ||
k (int, optional): Patches should be divisible by k. Defaults to 4. | ||
patch_size (List[int], optional): Size of the patches. Defaults to [128, 128, 128]. Will random crop patch for training. | ||
val_patch_size (Optional[List[int]], optional): Size of validation patches. Defaults to None. If None, will use the whole volume for validation. If given, will central crop a patch for validation. | ||
output_dtype (torch.dtype, optional): Output data type. Defaults to torch.float32. | ||
spacing_type (str, optional): Type of spacing. Defaults to "original". Choose from ["original", "fixed", "rand_zoom"]. | ||
spacing (Optional[List[float]], optional): Spacing values. Defaults to None. | ||
image_keys (List[str], optional): List of image keys. Defaults to ["image"]. | ||
label_keys (List[str], optional): List of label keys. Defaults to []. | ||
additional_keys (List[str], optional): List of additional keys. Defaults to []. | ||
select_channel (int, optional): Channel to select for multi-channel MRI. Defaults to 0. | ||
""" | ||
if spacing_type not in ["original", "fixed", "rand_zoom"]: | ||
raise ValueError( | ||
f"spacing_type has to be chosen from ['original', 'fixed', 'rand_zoom']. Got {spacing_type}." | ||
) | ||
|
||
self.is_train = is_train | ||
self.transform_dict = {} | ||
|
||
for modality in ["ct", "mri"]: | ||
self.transform_dict[modality] = define_vae_transform( | ||
is_train=is_train, | ||
modality=modality, | ||
random_aug=random_aug, | ||
k=k, | ||
patch_size=patch_size, | ||
val_patch_size=val_patch_size, | ||
output_dtype=output_dtype, | ||
spacing_type=spacing_type, | ||
spacing=spacing, | ||
image_keys=image_keys, | ||
label_keys=label_keys, | ||
additional_keys=additional_keys, | ||
select_channel=select_channel, | ||
) | ||
|
||
def __call__(self, img: dict, fixed_modality: Optional[str] = None) -> dict: | ||
""" | ||
Apply the appropriate transform to the input image. | ||
Args: | ||
img (dict): Input image dictionary. | ||
fixed_modality (Optional[str], optional): Fixed modality to use. Defaults to None. | ||
Returns: | ||
Composed Transform | ||
Raises: | ||
ValueError: If the modality is not 'ct' or 'mri'. | ||
""" | ||
modality = fixed_modality or img["class"] | ||
modality = modality.lower() # Normalize modality to lowercase | ||
if modality not in ["ct", "mri"]: | ||
warnings.warn( | ||
f"Intensity transform only support {SUPPORT_MODALITIES}. Got {modality}. Will not do any intensity transform and will use original intensities." | ||
) | ||
|
||
transform = self.transform_dict[modality] | ||
return transform(img) |