Skip to content

Commit

Permalink
Merge branch 'main' into ddp
Browse files Browse the repository at this point in the history
  • Loading branch information
KumoLiu authored Mar 21, 2024
2 parents db05955 + 5d6f600 commit cafdf7b
Show file tree
Hide file tree
Showing 9 changed files with 12 additions and 16 deletions.
6 changes: 3 additions & 3 deletions detection/generate_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def generate_detection_train_transform(

train_transforms = Compose(
[
LoadImaged(keys=[image_key], meta_key_postfix="meta_dict"),
LoadImaged(keys=[image_key], image_only=False, meta_key_postfix="meta_dict"),
EnsureChannelFirstd(keys=[image_key]),
EnsureTyped(keys=[image_key, box_key], dtype=torch.float32),
EnsureTyped(keys=[label_key], dtype=torch.long),
Expand Down Expand Up @@ -224,7 +224,7 @@ def generate_detection_val_transform(

val_transforms = Compose(
[
LoadImaged(keys=[image_key], meta_key_postfix="meta_dict"),
LoadImaged(keys=[image_key], image_only=False, meta_key_postfix="meta_dict"),
EnsureChannelFirstd(keys=[image_key]),
EnsureTyped(keys=[image_key, box_key], dtype=torch.float32),
EnsureTyped(keys=[label_key], dtype=torch.long),
Expand Down Expand Up @@ -280,7 +280,7 @@ def generate_detection_inference_transform(

test_transforms = Compose(
[
LoadImaged(keys=[image_key], meta_key_postfix="meta_dict"),
LoadImaged(keys=[image_key], image_only=False, meta_key_postfix="meta_dict"),
EnsureChannelFirstd(keys=[image_key]),
EnsureTyped(keys=[image_key], dtype=torch.float32),
Orientationd(keys=[image_key], axcodes="RAS"),
Expand Down
1 change: 1 addition & 0 deletions detection/luna16_prepare_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def main():
[
LoadImaged(
keys=["image"],
image_only=False,
meta_key_postfix="meta_dict",
reader="itkreader",
affine_lps_to_ras=True,
Expand Down
1 change: 1 addition & 0 deletions detection/luna16_prepare_images_dicom.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def main():
[
LoadImaged(
keys=["image"],
image_only=False,
meta_key_postfix="meta_dict",
reader="itkreader",
affine_lps_to_ras=True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@
"\n",
"test_transforms = Compose(\n",
" [\n",
" LoadImaged(keys=[\"kspace\"], reader=FastMRIReader, dtype=np.complex64),\n",
" LoadImaged(keys=[\"kspace\"], reader=FastMRIReader, image_only=False, dtype=np.complex64),\n",
" # user can also add other random transforms\n",
" ExtractDataKeyFromMetaKeyd(keys=[\"reconstruction_rss\", \"mask\"], meta_key=\"kspace_meta_dict\"),\n",
" MaskTransform,\n",
Expand Down
2 changes: 1 addition & 1 deletion reconstruction/MRI_reconstruction/unet_demo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def trainer(args):

train_transforms = Compose(
[
LoadImaged(keys=["kspace"], reader=FastMRIReader, dtype=np.complex64),
LoadImaged(keys=["kspace"], reader=FastMRIReader, image_only=False, dtype=np.complex64),
# user can also add other random transforms
ExtractDataKeyFromMetaKeyd(keys=["reconstruction_rss", "mask"], meta_key="kspace_meta_dict"),
MaskTransform,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@
"\n",
"test_transforms = Compose(\n",
" [\n",
" LoadImaged(keys=[\"kspace\"], reader=FastMRIReader, dtype=np.complex64),\n",
" LoadImaged(keys=[\"kspace\"], reader=FastMRIReader, image_only=False, dtype=np.complex64),\n",
" # user can also add other random transforms\n",
" ExtractDataKeyFromMetaKeyd(keys=[\"reconstruction_rss\", \"mask\"], meta_key=\"kspace_meta_dict\"),\n",
" MaskTransform,\n",
Expand Down
2 changes: 1 addition & 1 deletion reconstruction/MRI_reconstruction/varnet_demo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def trainer(args):

train_transforms = Compose(
[
LoadImaged(keys=["kspace"], reader=FastMRIReader, dtype=np.complex64),
LoadImaged(keys=["kspace"], reader=FastMRIReader, image_only=False, dtype=np.complex64),
# user can also add other random transforms but remember to disable randomness for val_transforms
ExtractDataKeyFromMetaKeyd(keys=["reconstruction_rss", "mask"], meta_key="kspace_meta_dict"),
MaskTransform,
Expand Down
2 changes: 1 addition & 1 deletion self_supervised_pretraining/vit_unetr_ssl/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ At the time of creation of this tutorial, the below additional dependencies are
To begin training with 2 GPUs, see the example command below for running the SSL multi-GPU training
script:

`CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 mgpu_ssl_train.py --batch_size=8 --epochs=500 --base_lr=2e-4 --logdir_path=/to/be/defined --output=/to/be/defined --data_root=/to/be/defined --json_path=/to/be/defined`
`CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 mgpu_ssl_train.py --batch_size=8 --epochs=500 --base_lr=2e-4 --logdir_path=/to/be/defined --output=/to/be/defined --data_root=/to/be/defined --json_path=/to/be/defined`

It can be configured to launch on more GPUs by adding the relevant CUDA device IDs to `CUDA_VISIBLE_DEVICES`
and increasing the number of GPUs passed to `--nproc_per_node`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,6 @@ def parse_option():
metavar="PATH",
help="root of output folder, the full path is <output>/<model_name>/<tag> (default: output)",
)
# Distributed Training
parser.add_argument("--local_rank", type=int, help="local rank for DistributedDataParallel")

# DL Training Hyper-parameters
parser.add_argument("--epochs", default=100, type=int, help="number of epochs")
Expand Down Expand Up @@ -139,10 +137,6 @@ def main(args):
data_list_file_path=json_path, is_segmentation=False, data_list_key="validation", base_dir=data_root
)

# TODO Delete the below print statements
print("List of training samples: {}".format(train_list))
print("List of validation samples: {}".format(val_list))

print("Total training data are {} and validation data are {}".format(len(train_list), len(val_list)))

train_dataset = CacheDataset(data=train_list, transform=train_transforms, cache_rate=1.0, num_workers=4)
Expand Down Expand Up @@ -191,7 +185,7 @@ def main(args):
optimizer = torch.optim.Adam(model.parameters(), lr=args.base_lr)

model = torch.nn.parallel.DistributedDataParallel(
model, device_ids=[args.local_rank], broadcast_buffers=False, find_unused_parameters=True
model, device_ids=[int(os.environ["LOCAL_RANK"])], broadcast_buffers=False, find_unused_parameters=True
)
model_without_ddp = model.module

Expand Down Expand Up @@ -340,7 +334,7 @@ def validate(data_loader, model, loss_functions):
else:
rank = -1
world_size = -1
torch.cuda.set_device(args.local_rank)
torch.cuda.set_device(rank)
torch.distributed.init_process_group(backend="nccl", init_method="env://", world_size=world_size, rank=rank)
torch.distributed.barrier()

Expand Down

0 comments on commit cafdf7b

Please sign in to comment.