Enhance readme for ddp cases in ldm tutorials (Project-MONAI#1857)
Enhance readme for ddp cases in ldm tutorials
Add an `amp` argument in the MAISI diffusion training script (used by the tutorial notebook) as a workaround
for Project-MONAI#1858.

### Checks
- [x] Avoid including large-size files in the PR.
- [ ] Clean up long text outputs from code cells in the notebook.
- [ ] For security purposes, please check the contents and remove any
sensitive info such as user names and private keys.
- [ ] Ensure that (1) hyperlinks and markdown anchors are working, (2) relative
paths are used for tutorial repo files, and (3) figures and graphs are placed in
the `./figure` folder
- [ ] Notebook runs automatically `./runner.sh -t <path to .ipynb file>`

---------

Signed-off-by: YunLiu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
KumoLiu and pre-commit-ci[bot] authored Oct 9, 2024
1 parent 2dac69c commit fa73d25
Showing 4 changed files with 35 additions and 18 deletions.
3 changes: 3 additions & 0 deletions generation/2d_ldm/README.md
@@ -57,6 +57,7 @@ torchrun \
--master_addr=localhost --master_port=1234 \
train_autoencoder.py -c ./config/config_train_32g.json -e ./config/environment.json -g ${NUM_GPUS_PER_NODE}
```
Please note that multi-GPU training may require additional GPU memory. Users might need to reduce the `batch_size` according to their available resources to ensure smooth training.
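
If you do hit out-of-memory errors, one option is to derive a smaller-batch copy of the provided config before launching `torchrun`. Below is a rough Python sketch; the exact location of `batch_size` inside `config_train_32g.json` is an assumption, so check the file for its actual layout.

```python
import json

# Sketch: halve every nested "batch_size" entry in the training config (key layout assumed)
# and write the result to a new file that can be passed to the training scripts via -c.
with open("./config/config_train_32g.json") as f:
    cfg = json.load(f)

for section in cfg.values():
    if isinstance(section, dict) and "batch_size" in section:
        section["batch_size"] = max(1, section["batch_size"] // 2)

with open("./config/config_train_custom.json", "w") as f:
    json.dump(cfg, f, indent=4)
```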

<p align="center">
<img src="./figs/train_recon.png" alt="autoencoder train curve" width="45%" >
@@ -88,6 +89,8 @@ torchrun \
--master_addr=localhost --master_port=1234 \
train_diffusion.py -c ./config/config_train_32g.json -e ./config/environment.json -g ${NUM_GPUS_PER_NODE}
```
Please note that multi-GPU training may require additional GPU memory. Users might need to reduce the `batch_size` according to their available resources to ensure smooth training.

<p align="center">
<img src="./figs/train_diffusion.png" alt="latent diffusion train curve" width="45%" >
&nbsp; &nbsp; &nbsp; &nbsp;
2 changes: 2 additions & 0 deletions generation/3d_ldm/README.md
@@ -61,6 +61,7 @@ torchrun \
--master_addr=localhost --master_port=1234 \
train_autoencoder.py -c ./config/config_train_32g.json -e ./config/environment.json -g ${NUM_GPUS_PER_NODE}
```
Please note that multi-GPU training may require additional GPU memory. Users might need to reduce the `batch_size` according to their available resources to ensure smooth training.

<p align="center">
<img src="./figs/train_recon.png" alt="autoencoder train curve" width="45%" >
@@ -87,6 +88,7 @@ torchrun \
--master_addr=localhost --master_port=1234 \
train_diffusion.py -c ./config/config_train_32g.json -e ./config/environment.json -g ${NUM_GPUS_PER_NODE}
```
Please note that multi-GPU training may require additional GPU memory. Users might need to reduce the `batch_size` according to their available resources to ensure smooth training.
<p align="center">
<img src="./figs/train_diffusion.png" alt="latent diffusion train curve" width="45%" >
&nbsp; &nbsp; &nbsp; &nbsp;
4 changes: 3 additions & 1 deletion generation/maisi/maisi_diff_unet_training_tutorial.ipynb
@@ -429,7 +429,9 @@
"\n",
"After all latent features have been created, we will initiate the multi-GPU script to train the latent diffusion model.\n",
"\n",
"The image generation process utilizes the [DDPM scheduler](https://arxiv.org/pdf/2006.11239) with 1,000 iterative steps. The diffusion model is optimized using L1 loss and a decayed learning rate scheduler. The batch size for this process is set to 1."
"The image generation process utilizes the [DDPM scheduler](https://arxiv.org/pdf/2006.11239) with 1,000 iterative steps. The diffusion model is optimized using L1 loss and a decayed learning rate scheduler. The batch size for this process is set to 1.\n",
"\n",
"Please be aware that using the H100 GPU may occasionally result in random segmentation faults. To avoid this issue, you can disable AMP by setting the `--no_amp` flag."
]
},
{
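For orientation, below is a minimal sketch of the noise-prediction objective this notebook cell describes, written in plain PyTorch with hypothetical names (the tutorial's actual loop lives in `diff_model_train.py`, shown next): draw a random timestep from the 1,000-step DDPM schedule, noise the latent, predict the noise with the UNet, and compare the prediction to the true noise with an L1 loss.

```python
import torch
import torch.nn.functional as F

# Hypothetical sketch, not the tutorial's actual code. `alphas_cumprod` is the cumulative
# product of the scheduler's alphas over the 1,000 training timesteps; `noise_pred_fn`
# stands in for the diffusion UNet call.
def ddpm_l1_step(noise_pred_fn, latents, alphas_cumprod, num_train_timesteps=1000):
    t = torch.randint(0, num_train_timesteps, (latents.shape[0],), device=latents.device)
    noise = torch.randn_like(latents)
    a_bar = alphas_cumprod[t].view(-1, *([1] * (latents.dim() - 1)))  # broadcast over C, D, H, W
    noisy_latents = a_bar.sqrt() * latents + (1.0 - a_bar).sqrt() * noise
    noise_pred = noise_pred_fn(noisy_latents, t)  # UNet predicts the added noise
    return F.l1_loss(noise_pred.float(), noise.float())
```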
44 changes: 27 additions & 17 deletions generation/maisi/scripts/diff_model_train.py
@@ -24,7 +24,7 @@
from torch.nn.parallel import DistributedDataParallel

import monai
from monai.data import ThreadDataLoader, partition_dataset
from monai.data import DataLoader, partition_dataset
from monai.transforms import Compose
from monai.utils import first

@@ -50,7 +50,7 @@ def load_filenames(data_list_path: str) -> list:

def prepare_data(
train_files: list, device: torch.device, cache_rate: float, num_workers: int = 2, batch_size: int = 1
) -> ThreadDataLoader:
) -> DataLoader:
"""
Prepare training data.
@@ -62,7 +62,7 @@ def prepare_data(
batch_size (int): Mini-batch size.
Returns:
ThreadDataLoader: Data loader for training.
DataLoader: Data loader for training.
"""

def _load_data_from_file(file_path, key):
@@ -90,7 +90,7 @@ def _load_data_from_file(file_path, key):
data=train_files, transform=train_transforms, cache_rate=cache_rate, num_workers=num_workers
)

return ThreadDataLoader(train_ds, num_workers=6, batch_size=batch_size, shuffle=True)
return DataLoader(train_ds, num_workers=6, batch_size=batch_size, shuffle=True)


def load_unet(args: argparse.Namespace, device: torch.device, logger: logging.Logger) -> torch.nn.Module:
@@ -124,14 +124,12 @@ def load_unet(args: argparse.Namespace, device: torch.device, logger: logging.Lo
return unet


def calculate_scale_factor(
train_loader: ThreadDataLoader, device: torch.device, logger: logging.Logger
) -> torch.Tensor:
def calculate_scale_factor(train_loader: DataLoader, device: torch.device, logger: logging.Logger) -> torch.Tensor:
"""
Calculate the scaling factor for the dataset.
Args:
train_loader (ThreadDataLoader): Data loader for training.
train_loader (DataLoader): Data loader for training.
device (torch.device): Device to use for calculation.
logger (logging.Logger): Logger for logging information.
@@ -181,7 +179,7 @@ def create_lr_scheduler(optimizer: torch.optim.Optimizer, total_steps: int) -> t
def train_one_epoch(
epoch: int,
unet: torch.nn.Module,
train_loader: ThreadDataLoader,
train_loader: DataLoader,
optimizer: torch.optim.Optimizer,
lr_scheduler: torch.optim.lr_scheduler.PolynomialLR,
loss_pt: torch.nn.L1Loss,
@@ -193,14 +191,15 @@
device: torch.device,
logger: logging.Logger,
local_rank: int,
amp: bool = True,
) -> torch.Tensor:
"""
Train the model for one epoch.
Args:
epoch (int): Current epoch number.
unet (torch.nn.Module): UNet model.
train_loader (ThreadDataLoader): Data loader for training.
train_loader (DataLoader): Data loader for training.
optimizer (torch.optim.Optimizer): Optimizer.
lr_scheduler (torch.optim.lr_scheduler.PolynomialLR): Learning rate scheduler.
loss_pt (torch.nn.L1Loss): Loss function.
@@ -212,6 +211,7 @@
device (torch.device): Device to use for training.
logger (logging.Logger): Logger for logging information.
local_rank (int): Local rank for distributed training.
amp (bool): Use automatic mixed precision training.
Returns:
torch.Tensor: Training loss for the epoch.
@@ -237,7 +237,7 @@

optimizer.zero_grad(set_to_none=True)

with autocast("cuda", enabled=True):
with autocast("cuda", enabled=amp):
noise = torch.randn(
(num_images_per_batch, 4, images.size(-3), images.size(-2), images.size(-1)), device=device
)
@@ -256,9 +256,13 @@

loss = loss_pt(noise_pred.float(), noise.float())

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
if amp:
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
else:
loss.backward()
optimizer.step()

lr_scheduler.step()
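
As a consolidated view of the change above, here is a minimal sketch (hypothetical helper, not the tutorial's exact code) of the AMP-optional update: `autocast` and the `GradScaler` path are active only when `amp=True`; otherwise a plain fp32 backward and step are taken, which is the behaviour `--no_amp` selects as a workaround for Project-MONAI#1858.

```python
import torch
from torch.amp import GradScaler, autocast

# Hypothetical sketch of the AMP-optional training step.
def training_step(model, inputs, targets, loss_fn, optimizer, scaler, amp=True):
    optimizer.zero_grad(set_to_none=True)
    with autocast("cuda", enabled=amp):
        loss = loss_fn(model(inputs), targets)
    if amp:
        scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
        scaler.step(optimizer)         # unscales gradients, then steps
        scaler.update()
    else:
        loss.backward()                # plain fp32 path when AMP is disabled
        optimizer.step()
    return loss.detach()

# scaler = GradScaler("cuda")  # only required on the amp=True path
```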

@@ -312,14 +316,18 @@ def save_checkpoint(
)


def diff_model_train(env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int) -> None:
def diff_model_train(
env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int, amp: bool = True
) -> None:
"""
Main function to train a diffusion model.
Args:
env_config_path (str): Path to the environment configuration file.
model_config_path (str): Path to the model configuration file.
model_def_path (str): Path to the model definition file.
num_gpus (int): Number of GPUs to use for training.
amp (bool): Use automatic mixed precision training.
"""
args = load_config(env_config_path, model_config_path, model_def_path)
local_rank, world_size, device = initialize_distributed(num_gpus)
@@ -357,7 +365,7 @@ def diff_model_train(env_config_path: str, model_config_path: str, model_def_pat
)[local_rank]

train_loader = prepare_data(
train_files, device, args.diffusion_unet_train["cache_rate"], args.diffusion_unet_train["batch_size"]
train_files, device, args.diffusion_unet_train["cache_rate"], batch_size=args.diffusion_unet_train["batch_size"]
)

unet = load_unet(args, device, logger)
@@ -392,6 +400,7 @@ def diff_model_train(env_config_path: str, model_config_path: str, model_def_pat
device,
logger,
local_rank,
amp=amp,
)

loss_torch = loss_torch.tolist()
@@ -431,6 +440,7 @@ def diff_model_train(env_config_path: str, model_config_path: str, model_def_pat
"--model_def", type=str, default="./configs/config_maisi.json", help="Path to model definition file"
)
parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use for training")
parser.add_argument("--no_amp", dest="amp", action="store_false", help="Disable automatic mixed precision training")

args = parser.parse_args()
diff_model_train(args.env_config, args.model_config, args.model_def, args.num_gpus)
diff_model_train(args.env_config, args.model_config, args.model_def, args.num_gpus, args.amp)
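
For completeness, a hedged usage sketch of the new switch (the import and config paths below are placeholders, not verified defaults): calling `diff_model_train(..., amp=False)` from Python is equivalent to passing `--no_amp` on the command line.

```python
from scripts.diff_model_train import diff_model_train  # import path assumed

# Disable AMP programmatically; equivalent to appending `--no_amp` to the CLI call.
diff_model_train(
    env_config_path="./configs/environment_maisi_diff_model.json",  # placeholder
    model_config_path="./configs/config_maisi_diff_model.json",     # placeholder
    model_def_path="./configs/config_maisi.json",
    num_gpus=1,
    amp=False,
)
```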
