diff --git a/generative/maisi/configs/config_maisi_controlnet_train.json b/generative/maisi/configs/config_maisi_controlnet_train.json index 69fb5855e6..50adf9a478 100644 --- a/generative/maisi/configs/config_maisi_controlnet_train.json +++ b/generative/maisi/configs/config_maisi_controlnet_train.json @@ -1,72 +1,4 @@ { - "random_seed": null, - "spatial_dims": 3, - "image_channels": 1, - "latent_channels": 4, - "diffusion_unet_def": { - "_target_": "monai.apps.generation.maisi.networks.diffusion_model_unet_maisi.DiffusionModelUNetMaisi", - "spatial_dims": "@spatial_dims", - "in_channels": "@latent_channels", - "out_channels": "@latent_channels", - "num_channels": [ - 64, - 128, - 256, - 512 - ], - "attention_levels": [ - false, - false, - true, - true - ], - "num_head_channels": [ - 0, - 0, - 32, - 32 - ], - "num_res_blocks": 2, - "use_flash_attention": true, - "include_top_region_index_input": true, - "include_bottom_region_index_input": true, - "include_spacing_input": true - }, - "controlnet_def": { - "_target_": "monai.apps.generation.maisi.networks.controlnet_maisi.ControlNetMaisi", - "spatial_dims": "@spatial_dims", - "in_channels": "@latent_channels", - "num_channels": [ - 64, - 128, - 256, - 512 - ], - "attention_levels": [ - false, - false, - true, - true - ], - "num_head_channels": [ - 0, - 0, - 32, - 32 - ], - "num_res_blocks": 2, - "use_flash_attention": true, - "conditioning_embedding_in_channels": 8, - "conditioning_embedding_num_channels": [8, 32, 64] - }, - "noise_scheduler": { - "_target_": "generative.networks.schedulers.DDPMScheduler", - "num_train_timesteps": 1000, - "beta_start": 0.0015, - "beta_end": 0.0195, - "schedule": "scaled_linear_beta", - "clip_sample": false - }, "controlnet_train": { "batch_size": 1, "cache_rate": 0.0, diff --git a/generative/maisi/scripts/train_controlnet.py b/generative/maisi/scripts/train_controlnet.py index 40f4ea3a1f..f059fe205e 100644 --- a/generative/maisi/scripts/train_controlnet.py +++ b/generative/maisi/scripts/train_controlnet.py @@ -40,8 +40,14 @@ def main(): parser.add_argument( "-c", "--config-file", + default="./configs/config_maisi.json", + help="config json file that stores network hyper-parameters", + ) + parser.add_argument( + "-t", + "--training-config", default="./configs/config_maisi_controlnet_train.json", - help="config json file that stores hyper-parameters", + help="config json file that stores training hyper-parameters", ) parser.add_argument("-g", "--gpus", default=1, type=int, help="number of gpus per node") args = parser.parse_args() @@ -66,11 +72,14 @@ def main(): env_dict = json.load(open(args.environment_file, "r")) config_dict = json.load(open(args.config_file, "r")) + training_config_dict = json.load(open(args.training_config, "r")) for k, v in env_dict.items(): setattr(args, k, v) for k, v in config_dict.items(): setattr(args, k, v) + for k, v in training_config_dict.items(): + setattr(args, k, v) # initialize tensorboard writer if rank == 0: