Skip to content

Commit

Permalink
Merge branch 'fix-maisi' of https://github.com/KumoLiu/tutorials into…
Browse files Browse the repository at this point in the history
… fix-maisi
  • Loading branch information
KumoLiu committed Sep 30, 2024
2 parents d019799 + 97b107b commit 7ab71d2
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 7 deletions.
8 changes: 4 additions & 4 deletions generation/maisi/scripts/diff_model_create_training_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,9 @@ def process_file(


@torch.inference_mode()
def diff_model_create_training_data(env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int) -> None:
def diff_model_create_training_data(
env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int
) -> None:
"""
Create training data for the diffusion model.
Expand Down Expand Up @@ -224,9 +226,7 @@ def diff_model_create_training_data(env_config_path: str, model_config_path: str
parser.add_argument(
"--model_def", type=str, default="./configs/config_maisi.json", help="Path to model definition file"
)
parser.add_argument(
"--num_gpus", type=int, default=1, help="Number of GPUs to use for distributed training"
)
parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use for distributed training")

args = parser.parse_args()
diff_model_create_training_data(args.env_config, args.model_config, args.model_def, args.num_gpus)
4 changes: 1 addition & 3 deletions generation/maisi/scripts/diff_model_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,9 +430,7 @@ def diff_model_train(env_config_path: str, model_config_path: str, model_def_pat
parser.add_argument(
"--model_def", type=str, default="./configs/config_maisi.json", help="Path to model definition file"
)
parser.add_argument(
"--num_gpus", type=int, default=1, help="Number of GPUs to use for training"
)
parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use for training")

args = parser.parse_args()
diff_model_train(args.env_config, args.model_config, args.model_def, args.num_gpus)

0 comments on commit 7ab71d2

Please sign in to comment.