From d44c125e7bdd075b04cfcc198627010699946ab7 Mon Sep 17 00:00:00 2001 From: Dong Yang Date: Mon, 15 Jul 2024 01:51:30 -0600 Subject: [PATCH] Add tutorial to demonstrate how to train a generative model that can generate large image volumes (#1745) Fixes #1744. ### Description Add tutorial to demonstrate how to train a generative model that can generate large image volumes. ### Checks - [ ] Avoid including large-size files in the PR. - [ ] Clean up long text outputs from code cells in the notebook. - [ ] For security purposes, please check the contents and remove any sensitive info such as user names and private key. - [ ] Ensure (1) hyperlinks and markdown anchors are working (2) use relative paths for tutorial repo files (3) put figure and graphs in the `./figure` folder - [ ] Notebook runs automatically `./runner.sh -t ` --------- Signed-off-by: dongyang0122 Signed-off-by: Dong Yang Co-authored-by: Dong Yang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../configs/config_maisi_diff_model.json | 42 ++ .../configs/environment_maisi_diff_model.json | 11 + .../maisi_diff_unet_training_tutorial.ipynb | 620 ++++++++++++++++++ generative/maisi/scripts/__init__.py | 10 + .../diff_model_create_training_data.py | 225 +++++++ generative/maisi/scripts/diff_model_infer.py | 296 +++++++++ .../maisi/scripts/diff_model_setting.py | 88 +++ generative/maisi/scripts/diff_model_train.py | 426 ++++++++++++ generative/maisi/scripts/sample.py | 4 +- 9 files changed, 1719 insertions(+), 3 deletions(-) create mode 100644 generative/maisi/configs/config_maisi_diff_model.json create mode 100644 generative/maisi/configs/environment_maisi_diff_model.json create mode 100644 generative/maisi/maisi_diff_unet_training_tutorial.ipynb create mode 100644 generative/maisi/scripts/__init__.py create mode 100644 generative/maisi/scripts/diff_model_create_training_data.py create mode 100644 generative/maisi/scripts/diff_model_infer.py create mode 100644 generative/maisi/scripts/diff_model_setting.py create mode 100644 generative/maisi/scripts/diff_model_train.py diff --git a/generative/maisi/configs/config_maisi_diff_model.json b/generative/maisi/configs/config_maisi_diff_model.json new file mode 100644 index 0000000000..4849ddf35e --- /dev/null +++ b/generative/maisi/configs/config_maisi_diff_model.json @@ -0,0 +1,42 @@ +{ + "noise_scheduler": { + "_target_": "generative.networks.schedulers.DDPMScheduler", + "num_train_timesteps": 1000, + "beta_start": 0.0015, + "beta_end": 0.0195, + "schedule": "scaled_linear_beta", + "clip_sample": false + }, + "diffusion_unet_train": { + "batch_size": 1, + "cache_rate": 0, + "lr": 0.0001, + "n_epochs": 1000 + }, + "diffusion_unet_inference": { + "dim": [ + 128, + 128, + 128 + ], + "spacing": [ + 1.0, + 1.25, + 0.75 + ], + "top_region_index": [ + 0, + 1, + 0, + 0 + ], + "bottom_region_index": [ + 0, + 0, + 1, + 0 + ], + "random_seed": 0, + "num_inference_steps": 10 + } +} diff --git a/generative/maisi/configs/environment_maisi_diff_model.json b/generative/maisi/configs/environment_maisi_diff_model.json new file mode 100644 index 0000000000..71d12ee4bc --- /dev/null +++ b/generative/maisi/configs/environment_maisi_diff_model.json @@ -0,0 +1,11 @@ +{ + "data_base_dir": "./data", + "embedding_base_dir": "./embeddings", + "json_data_list": "./dataset.json", + "model_dir": "./models", + "model_filename": "diff_unet_ckpt.pt", + "output_dir": "./predictions", + "output_prefix": "unet_3d", + "trained_autoencoder_path": 
"./models/autoencoder_epoch273.pt", + "existing_ckpt_filepath": null +} diff --git a/generative/maisi/maisi_diff_unet_training_tutorial.ipynb b/generative/maisi/maisi_diff_unet_training_tutorial.ipynb new file mode 100644 index 0000000000..cf45070605 --- /dev/null +++ b/generative/maisi/maisi_diff_unet_training_tutorial.ipynb @@ -0,0 +1,620 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "05fc7b5c", + "metadata": {}, + "source": [ + "Copyright (c) MONAI Consortium \n", + "Licensed under the Apache License, Version 2.0 (the \"License\"); \n", + "you may not use this file except in compliance with the License. \n", + "You may obtain a copy of the License at \n", + "    http://www.apache.org/licenses/LICENSE-2.0 \n", + "Unless required by applicable law or agreed to in writing, software \n", + "distributed under the License is distributed on an \"AS IS\" BASIS, \n", + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. \n", + "See the License for the specific language governing permissions and \n", + "limitations under the License." + ] + }, + { + "cell_type": "markdown", + "id": "777b7dcb", + "metadata": {}, + "source": [ + "# Training a 3D Diffusion Model for Generating 3D Images with Various Sizes and Spacings" + ] + }, + { + "cell_type": "markdown", + "id": "c9ecfb90", + "metadata": {}, + "source": [ + "## Setup environment" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "58cbde9b", + "metadata": {}, + "outputs": [], + "source": [ + "!python -c \"import monai\" || pip install -q \"monai-weekly[pillow, tqdm]\"\n", + "!python -c \"import xformers\" || pip install -q xformers --index-url https://download.pytorch.org/whl/cu121" + ] + }, + { + "cell_type": "markdown", + "id": "d655b95c", + "metadata": {}, + "source": [ + "## Setup imports" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "e3bf0346", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MONAI version: 1.3.1+27.g8cfbcbab\n", + "Numpy version: 1.26.4\n", + "Pytorch version: 2.3.1+cu121\n", + "MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False\n", + "MONAI rev id: 8cfbcbabd1529ef4090fb6f7ffbeef47d6b70cc2\n", + "MONAI __file__: /localhome//miniconda3/envs/monai-dev/lib/python3.11/site-packages/monai/__init__.py\n", + "\n", + "Optional dependencies:\n", + "Pytorch Ignite version: 0.4.11\n", + "ITK version: 5.4.0\n", + "Nibabel version: 5.2.1\n", + "scikit-image version: 0.24.0\n", + "scipy version: 1.13.1\n", + "Pillow version: 10.3.0\n", + "Tensorboard version: 2.17.0\n", + "gdown version: 5.2.0\n", + "TorchVision version: 0.18.1+cu121\n", + "tqdm version: 4.66.4\n", + "lmdb version: 1.4.1\n", + "psutil version: 6.0.0\n", + "pandas version: 2.2.2\n", + "einops version: 0.8.0\n", + "transformers version: NOT INSTALLED or UNKNOWN VERSION.\n", + "mlflow version: 2.14.1\n", + "pynrrd version: 1.0.0\n", + "clearml version: 1.16.2\n", + "\n", + "For details about installing the optional dependencies, please visit:\n", + " https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The torch.distributed is either unavailable and uninitiated when RankFilter is instantiated.\n", + "If torch.distributed is used, please ensure that the RankFilter() is called\n", + "after torch.distributed.init_process_group() in the script.\n", + "\n" + ] + } + ], + "source": [ + "from 
scripts.diff_model_setting import setup_logging\n", + "import copy\n", + "import os\n", + "import json\n", + "import numpy as np\n", + "import nibabel as nib\n", + "import subprocess\n", + "\n", + "from monai.data import create_test_image_3d\n", + "from monai.config import print_config\n", + "\n", + "print_config()\n", + "\n", + "logger = setup_logging(\"notebook\")" + ] + }, + { + "cell_type": "markdown", + "id": "d8e29c23", + "metadata": {}, + "source": [ + "## Simulate a special dataset\n", + "\n", + "It is well known that AI takes time to train. We will simulate a small dataset and run training only for multiple epochs. Due to the nature of AI, the performance shouldn't be highly expected, but the entire pipeline will be completed within minutes!\n", + "\n", + "`sim_datalist` provides the information of the simulated datasets. It lists 2 training images. The size of the dimension is defined by the `sim_dim`." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "fc32a7fe", + "metadata": {}, + "outputs": [], + "source": [ + "sim_datalist = {\"training\": [{\"image\": \"tr_image_001.nii.gz\"}, {\"image\": \"tr_image_002.nii.gz\"}]}\n", + "\n", + "sim_dim = (128, 160, 96)" + ] + }, + { + "cell_type": "markdown", + "id": "b9ac7677", + "metadata": {}, + "source": [ + "## Generate images\n", + "\n", + "Now we can use MONAI `create_test_image_3d` and `nib.Nifti1Image` functions to generate the 3D simulated images under the work_dir" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "1b199078", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2024-07-14 13:39:39.050][ INFO](notebook) - Generated simulated images.\n" + ] + } + ], + "source": [ + "work_dir = \"./temp_work_dir\"\n", + "if not os.path.isdir(work_dir):\n", + " os.makedirs(work_dir)\n", + "\n", + "dataroot_dir = os.path.join(work_dir, \"sim_dataroot\")\n", + "if not os.path.isdir(dataroot_dir):\n", + " os.makedirs(dataroot_dir)\n", + "\n", + "datalist_file = os.path.join(work_dir, \"sim_datalist.json\")\n", + "with open(datalist_file, \"w\") as f:\n", + " json.dump(sim_datalist, f)\n", + "\n", + "for d in sim_datalist[\"training\"]:\n", + " im, _ = create_test_image_3d(\n", + " sim_dim[0], sim_dim[1], sim_dim[2], rad_max=10, num_seg_classes=1, random_state=np.random.RandomState(42)\n", + " )\n", + " image_fpath = os.path.join(dataroot_dir, d[\"image\"])\n", + " nib.save(nib.Nifti1Image(im, affine=np.eye(4)), image_fpath)\n", + "\n", + "logger.info(\"Generated simulated images.\")" + ] + }, + { + "cell_type": "markdown", + "id": "c2389853", + "metadata": {}, + "source": [ + "## Set up directories and configurations" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "6c7b434c", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2024-07-14 13:39:39.061][ INFO](notebook) - files and folders under work_dir: ['models', 'sim_datalist.json', 'embeddings', 'config_maisi.json', 'predictions', 'environment_maisi_diff_model.json', 'sim_dataroot', 'config_maisi_diff_model.json'].\n", + "[2024-07-14 13:39:39.062][ INFO](notebook) - number of GPUs: 1.\n" + ] + } + ], + "source": [ + "env_config_path = \"./configs/environment_maisi_diff_model.json\"\n", + "model_config_path = \"./configs/config_maisi_diff_model.json\"\n", + "model_def_path = \"./configs/config_maisi.json\"\n", + "\n", + "# Load environment configuration, model configuration and model definition\n", + "with open(env_config_path, 
\"r\") as f:\n", + " env_config = json.load(f)\n", + "\n", + "with open(model_config_path, \"r\") as f:\n", + " model_config = json.load(f)\n", + "\n", + "with open(model_def_path, \"r\") as f:\n", + " model_def = json.load(f)\n", + "\n", + "env_config_out = copy.deepcopy(env_config)\n", + "model_config_out = copy.deepcopy(model_config)\n", + "model_def_out = copy.deepcopy(model_def)\n", + "\n", + "# Set up directories based on configurations\n", + "env_config_out[\"data_base_dir\"] = dataroot_dir\n", + "env_config_out[\"embedding_base_dir\"] = os.path.join(work_dir, env_config_out[\"embedding_base_dir\"])\n", + "env_config_out[\"json_data_list\"] = datalist_file\n", + "env_config_out[\"model_dir\"] = os.path.join(work_dir, env_config_out[\"model_dir\"])\n", + "env_config_out[\"output_dir\"] = os.path.join(work_dir, env_config_out[\"output_dir\"])\n", + "env_config_out[\"trained_autoencoder_path\"] = None\n", + "\n", + "# Create necessary directories\n", + "os.makedirs(env_config_out[\"embedding_base_dir\"], exist_ok=True)\n", + "os.makedirs(env_config_out[\"model_dir\"], exist_ok=True)\n", + "os.makedirs(env_config_out[\"output_dir\"], exist_ok=True)\n", + "\n", + "env_config_filepath = os.path.join(work_dir, \"environment_maisi_diff_model.json\")\n", + "with open(env_config_filepath, \"w\") as f:\n", + " json.dump(env_config_out, f, sort_keys=True, indent=4)\n", + "\n", + "# Update model configuration for demo\n", + "max_epochs = 2\n", + "model_config_out[\"diffusion_unet_train\"][\"n_epochs\"] = max_epochs\n", + "\n", + "model_config_filepath = os.path.join(work_dir, \"config_maisi_diff_model.json\")\n", + "with open(model_config_filepath, \"w\") as f:\n", + " json.dump(model_config_out, f, sort_keys=True, indent=4)\n", + "\n", + "# Update model definition for demo\n", + "model_def_out[\"autoencoder_def\"][\"num_splits\"] = 4\n", + "model_def_filepath = os.path.join(work_dir, \"config_maisi.json\")\n", + "with open(model_def_filepath, \"w\") as f:\n", + " json.dump(model_def_out, f, sort_keys=True, indent=4)\n", + "\n", + "# Print files and folders under work_dir\n", + "logger.info(f\"files and folders under work_dir: {os.listdir(work_dir)}.\")\n", + "\n", + "# Adjust based on the number of GPUs you want to use\n", + "num_gpus = 1\n", + "logger.info(f\"number of GPUs: {num_gpus}.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "95ea6972", + "metadata": {}, + "outputs": [], + "source": [ + "def run_torchrun(module, module_args, num_gpus=1):\n", + " # Define the arguments for torchrun\n", + " num_nodes = 1\n", + "\n", + " # Build the torchrun command\n", + " torchrun_command = [\n", + " \"torchrun\",\n", + " \"--nproc_per_node\",\n", + " str(num_gpus),\n", + " \"--nnodes\",\n", + " str(num_nodes),\n", + " \"-m\",\n", + " module,\n", + " ] + module_args\n", + "\n", + " # Set the OMP_NUM_THREADS environment variable\n", + " env = os.environ.copy()\n", + " env[\"OMP_NUM_THREADS\"] = \"1\"\n", + "\n", + " # Execute the command\n", + " process = subprocess.Popen(torchrun_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, env=env)\n", + "\n", + " # Print the output in real-time\n", + " try:\n", + " while True:\n", + " output = process.stdout.readline()\n", + " if output == \"\" and process.poll() is not None:\n", + " break\n", + " if output:\n", + " print(output.strip())\n", + " except Exception as e:\n", + " print(f\"An error occurred: {e}\")\n", + " finally:\n", + " # Capture and print any remaining output\n", + " stdout, stderr = 
process.communicate()\n", + " print(stdout)\n", + " if stderr:\n", + " print(stderr)\n", + " return" + ] + }, + { + "cell_type": "markdown", + "id": "1c904f52", + "metadata": {}, + "source": [ + "## Step 1: Create Training Data" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "f45ea863", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2024-07-14 13:39:39.072][ INFO](notebook) - Creating training data...\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "[2024-07-14 13:39:48.124][ INFO](creating training data) - Using device cuda:0\n", + "[2024-07-14 13:39:48.587][ERROR](creating training data) - The trained_autoencoder_path does not exist!\n", + "[2024-07-14 13:39:48.587][ INFO](creating training data) - filenames_raw: ['tr_image_001.nii.gz', 'tr_image_002.nii.gz']\n", + "[2024-07-14 13:39:48.705][ INFO](creating training data) - old dim: [128, 160, 96], old spacing: [1.0, 1.0, 1.0]\n", + "[2024-07-14 13:39:48.765][ INFO](creating training data) - new dim: (128, 128, 128), new affine: [[ 1. 0. 0. 0. ]\n", + " [ 0. 1.25 0. 0.125]\n", + " [ 0. 0. 0.75 -0.125]\n", + " [ 0. 0. 0. 1. ]]\n", + "[2024-07-14 13:39:48.765][ INFO](creating training data) - out_filename: ./temp_work_dir/./embeddings/tr_image_001_emb.nii.gz\n", + "[2024-07-14 13:39:49.430][ INFO](creating training data) - z: torch.Size([1, 4, 32, 32, 32]), torch.float32\n", + "[2024-07-14 13:39:49.546][ INFO](creating training data) - old dim: [128, 160, 96], old spacing: [1.0, 1.0, 1.0]\n", + "[2024-07-14 13:39:49.603][ INFO](creating training data) - new dim: (128, 128, 128), new affine: [[ 1. 0. 0. 0. ]\n", + " [ 0. 1.25 0. 0.125]\n", + " [ 0. 0. 0.75 -0.125]\n", + " [ 0. 0. 0. 1. ]]\n", + "[2024-07-14 13:39:49.604][ INFO](creating training data) - out_filename: ./temp_work_dir/./embeddings/tr_image_002_emb.nii.gz\n", + "[2024-07-14 13:39:52.715][ INFO](creating training data) - z: torch.Size([1, 4, 32, 32, 32]), torch.float32\n", + "\n" + ] + } + ], + "source": [ + "logger.info(\"Creating training data...\")\n", + "\n", + "# Define the arguments for torchrun\n", + "module = \"scripts.diff_model_create_training_data\"\n", + "module_args = [\n", + " \"--env_config\",\n", + " env_config_filepath,\n", + " \"--model_config\",\n", + " model_config_filepath,\n", + " \"--model_def\",\n", + " model_def_filepath,\n", + "]\n", + "\n", + "run_torchrun(module, module_args, num_gpus=num_gpus)" + ] + }, + { + "cell_type": "markdown", + "id": "ec5c0c4a", + "metadata": {}, + "source": [ + "## Create .json files for embedding files" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "0221a658", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2024-07-14 13:39:55.677][ INFO](notebook) - data: {'dim': (32, 32, 32), 'spacing': [1.0, 1.25, 0.75], 'top_region_index': [0, 1, 0, 0], 'bottom_region_index': [0, 0, 1, 0]}.\n", + "[2024-07-14 13:39:55.679][ INFO](notebook) - data: {'dim': (32, 32, 32), 'spacing': [1.0, 1.25, 0.75], 'top_region_index': [0, 1, 0, 0], 'bottom_region_index': [0, 0, 1, 0]}.\n", + "[2024-07-14 13:39:55.679][ INFO](notebook) - Completed creating .json files for all embedding files.\n" + ] + } + ], + "source": [ + "def list_gz_files(folder_path):\n", + " \"\"\"List all .gz files in the folder and its subfolders.\"\"\"\n", + " gz_files = []\n", + " for root, _, files in os.walk(folder_path):\n", + " for file in files:\n", + " if file.endswith(\".gz\"):\n", + " 
gz_files.append(os.path.join(root, file))\n", + " return gz_files\n", + "\n", + "\n", + "def create_json_files(gz_files):\n", + " \"\"\"Create .json files for each .gz file with the specified keys and values.\"\"\"\n", + " for gz_file in gz_files:\n", + " # Load the NIfTI image\n", + " img = nib.load(gz_file)\n", + "\n", + " # Get the dimensions and spacing\n", + " dimensions = img.shape\n", + " dimensions = dimensions[:3]\n", + " spacing = img.header.get_zooms()[:3]\n", + " spacing = spacing[:3]\n", + " spacing = [float(_item) for _item in spacing]\n", + "\n", + " # Create the dictionary with the specified keys and values\n", + " # The region can be selected from one of four regions from top to bottom.\n", + " # [1,0,0,0] is the head and neck, [0,1,0,0] is the chest region, [0,0,1,0]\n", + " # is the abdomen region, and [0,0,0,1] is the lower body region.\n", + " data = {\n", + " \"dim\": dimensions,\n", + " \"spacing\": spacing,\n", + " \"top_region_index\": [0, 1, 0, 0], # chest region\n", + " \"bottom_region_index\": [0, 0, 1, 0], # abdomen region\n", + " }\n", + " logger.info(f\"data: {data}.\")\n", + "\n", + " # Create the .json filename\n", + " json_filename = gz_file + \".json\"\n", + "\n", + " # Write the dictionary to the .json file\n", + " with open(json_filename, \"w\") as json_file:\n", + " json.dump(data, json_file, indent=4)\n", + "\n", + "\n", + "folder_path = env_config_out[\"embedding_base_dir\"]\n", + "gz_files = list_gz_files(folder_path)\n", + "create_json_files(gz_files)\n", + "\n", + "logger.info(\"Completed creating .json files for all embedding files.\")" + ] + }, + { + "cell_type": "markdown", + "id": "e81a9e48", + "metadata": {}, + "source": [ + "## Step 2: Train the Model" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "ade6389d", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2024-07-14 13:39:55.683][ INFO](notebook) - Training the model...\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "[2024-07-14 13:40:03.098][ INFO](training) - Using cuda:0 of 1\n", + "[2024-07-14 13:40:03.098][ INFO](training) - [config] ckpt_folder -> ./temp_work_dir/./models.\n", + "[2024-07-14 13:40:03.098][ INFO](training) - [config] data_root -> ./temp_work_dir/./embeddings.\n", + "[2024-07-14 13:40:03.098][ INFO](training) - [config] data_list -> ./temp_work_dir/sim_datalist.json.\n", + "[2024-07-14 13:40:03.098][ INFO](training) - [config] lr -> 0.0001.\n", + "[2024-07-14 13:40:03.098][ INFO](training) - [config] num_epochs -> 2.\n", + "[2024-07-14 13:40:03.098][ INFO](training) - [config] num_train_timesteps -> 1000.\n", + "[2024-07-14 13:40:03.098][ INFO](training) - num_files_train: 2\n", + "[2024-07-14 13:40:07.396][ INFO](training) - Training from scratch.\n", + "[2024-07-14 13:40:07.721][ INFO](training) - Scaling factor set to 0.8950040340423584.\n", + "[2024-07-14 13:40:07.722][ INFO](training) - scale_factor -> 0.8950040340423584.\n", + "[2024-07-14 13:40:07.726][ INFO](training) - torch.set_float32_matmul_precision -> highest.\n", + "[2024-07-14 13:40:07.726][ INFO](training) - Epoch 1, lr 0.0001.\n", + "[2024-07-14 13:40:08.760][ INFO](training) - [2024-07-14 13:40:08] epoch 1, iter 1/2, loss: 0.7985, lr: 0.000100000000.\n", + "[2024-07-14 13:40:08.875][ INFO](training) - [2024-07-14 13:40:08] epoch 1, iter 2/2, loss: 0.7936, lr: 0.000056250000.\n", + "[2024-07-14 13:40:08.877][ INFO](training) - epoch 1 average loss: 0.7961.\n", + "[2024-07-14 13:40:09.694][ 
INFO](training) - Epoch 2, lr 2.5e-05.\n", + "[2024-07-14 13:40:10.685][ INFO](training) - [2024-07-14 13:40:10] epoch 2, iter 1/2, loss: 0.7902, lr: 0.000025000000.\n", + "[2024-07-14 13:40:10.799][ INFO](training) - [2024-07-14 13:40:10] epoch 2, iter 2/2, loss: 0.7883, lr: 0.000006250000.\n", + "[2024-07-14 13:40:10.802][ INFO](training) - epoch 2 average loss: 0.7893.\n", + "\n" + ] + } + ], + "source": [ + "logger.info(\"Training the model...\")\n", + "\n", + "# Define the arguments for torchrun\n", + "module = \"scripts.diff_model_train\"\n", + "module_args = [\n", + " \"--env_config\",\n", + " env_config_filepath,\n", + " \"--model_config\",\n", + " model_config_filepath,\n", + " \"--model_def\",\n", + " model_def_filepath,\n", + "]\n", + "\n", + "run_torchrun(module, module_args, num_gpus=num_gpus)" + ] + }, + { + "cell_type": "markdown", + "id": "4bdf7b17", + "metadata": {}, + "source": [ + "## Step 3: Infer using the Trained Model" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "1626526d", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2024-07-14 13:40:17.361][ INFO](notebook) - Running inference...\n", + "[2024-07-14 13:40:43.937][ INFO](notebook) - Completed all steps.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "[2024-07-14 13:40:27.140][ INFO](inference) - Using cuda:0 of 1 with random seed: 47698\n", + "[2024-07-14 13:40:27.140][ INFO](inference) - [config] ckpt_filepath -> ./temp_work_dir/./models/diff_unet_ckpt.pt.\n", + "[2024-07-14 13:40:27.140][ INFO](inference) - [config] random_seed -> 47698.\n", + "[2024-07-14 13:40:27.140][ INFO](inference) - [config] output_prefix -> unet_3d.\n", + "[2024-07-14 13:40:27.140][ INFO](inference) - [config] output_size -> (128, 128, 128).\n", + "[2024-07-14 13:40:27.140][ INFO](inference) - [config] out_spacing -> (1.0, 1.25, 0.75).\n", + "[2024-07-14 13:40:27.614][ERROR](inference) - The trained_autoencoder_path does not exist!\n", + "[2024-07-14 13:40:31.800][ INFO](inference) - checkpoints ./temp_work_dir/./models/diff_unet_ckpt.pt loaded.\n", + "[2024-07-14 13:40:31.801][ INFO](inference) - scale_factor -> 0.8950040340423584.\n", + "[2024-07-14 13:40:31.802][ INFO](inference) - num_downsample_level -> 4, divisor -> 4.\n", + "[2024-07-14 13:40:31.803][ INFO](inference) - noise: cuda:0, torch.float32, \n", + "\n", + " 0%| | 0/10 [00:00 Compose: + """ + Create a set of MONAI transforms for preprocessing. + + Args: + dim (tuple, optional): New dimensions for resizing. Defaults to None. + + Returns: + Compose: Composed MONAI transforms. + """ + if dim: + return Compose( + [ + monai.transforms.LoadImaged(keys="image"), + monai.transforms.EnsureChannelFirstd(keys="image"), + monai.transforms.Orientationd(keys="image", axcodes="RAS"), + monai.transforms.EnsureTyped(keys="image", dtype=torch.float32), + monai.transforms.ScaleIntensityRanged( + keys="image", a_min=-1000, a_max=1000, b_min=0, b_max=1, clip=True + ), + monai.transforms.Resized(keys="image", spatial_size=dim, mode="trilinear"), + ] + ) + else: + return Compose( + [ + monai.transforms.LoadImaged(keys="image"), + monai.transforms.EnsureChannelFirstd(keys="image"), + monai.transforms.Orientationd(keys="image", axcodes="RAS"), + ] + ) + + +def round_number(number: int, base_number: int = 128) -> int: + """ + Round the number to the nearest multiple of the base number, with a minimum value of the base number. + + Args: + number (int): Number to be rounded. 
+ base_number (int): Number to be common divisor. + + Returns: + int: Rounded number. + """ + new_number = max(round(float(number) / float(base_number)), 1.0) * float(base_number) + return int(new_number) + + +def load_filenames(data_list_path: str) -> list: + """ + Load filenames from the JSON data list. + + Args: + data_list_path (str): Path to the JSON data list file. + + Returns: + list: List of filenames. + """ + with open(data_list_path, "r") as file: + json_data = json.load(file) + filenames_raw = json_data["training"] + return [_item["image"] for _item in filenames_raw] + + +def process_file( + filepath: str, + args: argparse.Namespace, + autoencoder: torch.nn.Module, + device: torch.device, + plain_transforms: Compose, + new_transforms: Compose, + logger: logging.Logger, +) -> None: + """ + Process a single file to create training data. + + Args: + filepath (str): Path to the file to be processed. + args (argparse.Namespace): Configuration arguments. + autoencoder (torch.nn.Module): Autoencoder model. + device (torch.device): Device to process the file on. + plain_transforms (Compose): Plain transforms. + new_transforms (Compose): New transforms. + logger (logging.Logger): Logger for logging information. + """ + out_filename_base = filepath.replace(".gz", "").replace(".nii", "") + out_filename_base = os.path.join(args.embedding_base_dir, out_filename_base) + out_filename = out_filename_base + "_emb.nii.gz" + + if os.path.isfile(out_filename): + return + + test_data = {"image": os.path.join(args.data_base_dir, filepath)} + transformed_data = plain_transforms(test_data) + nda = transformed_data["image"] + + dim = [int(nda.meta["dim"][_i]) for _i in range(1, 4)] + spacing = [float(nda.meta["pixdim"][_i]) for _i in range(1, 4)] + + logger.info(f"old dim: {dim}, old spacing: {spacing}") + + new_data = new_transforms(test_data) + nda_image = new_data["image"] + + new_affine = nda_image.meta["affine"].numpy() + nda_image = nda_image.numpy().squeeze() + + logger.info(f"new dim: {nda_image.shape}, new affine: {new_affine}") + + try: + out_path = Path(out_filename) + out_path.parent.mkdir(parents=True, exist_ok=True) + logger.info(f"out_filename: {out_filename}") + + with torch.cuda.amp.autocast(): + pt_nda = torch.from_numpy(nda_image).float().to(device).unsqueeze(0).unsqueeze(0) + z = autoencoder.encode_stage_2_inputs(pt_nda) + logger.info(f"z: {z.size()}, {z.dtype}") + + out_nda = z.squeeze().cpu().detach().numpy().transpose(1, 2, 3, 0) + out_img = nib.Nifti1Image(np.float32(out_nda), affine=new_affine) + nib.save(out_img, out_filename) + except Exception as e: + logger.error(f"Error processing {filepath}: {e}") + + +@torch.inference_mode() +def diff_model_create_training_data(env_config_path: str, model_config_path: str, model_def_path: str) -> None: + """ + Create training data for the diffusion model. + + Args: + env_config_path (str): Path to the environment configuration file. + model_config_path (str): Path to the model configuration file. + model_def_path (str): Path to the model definition file. 
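+
+    Example:
+        This module is typically launched with ``torchrun`` (as in the tutorial
+        notebook); the config paths below are illustrative placeholders::
+
+            torchrun --nproc_per_node=1 --nnodes=1 -m scripts.diff_model_create_training_data \
+                --env_config ./configs/environment_maisi_diff_model.json \
+                --model_config ./configs/config_maisi_diff_model.json \
+                --model_def ./configs/config_maisi.json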
+ """ + args = load_config(env_config_path, model_config_path, model_def_path) + local_rank, world_size, device = initialize_distributed() + logger = setup_logging("creating training data") + logger.info(f"Using device {device}") + + autoencoder = define_instance(args, "autoencoder_def").to(device) + try: + checkpoint_autoencoder = load_autoencoder_ckpt(args.trained_autoencoder_path) + autoencoder.load_state_dict(checkpoint_autoencoder) + except Exception: + logger.error("The trained_autoencoder_path does not exist!") + + Path(args.embedding_base_dir).mkdir(parents=True, exist_ok=True) + + filenames_raw = load_filenames(args.json_data_list) + logger.info(f"filenames_raw: {filenames_raw}") + + plain_transforms = create_transforms(dim=None) + + for _iter in range(len(filenames_raw)): + if _iter % world_size != local_rank: + continue + + filepath = filenames_raw[_iter] + new_dim = tuple( + round_number( + int(plain_transforms({"image": os.path.join(args.data_base_dir, filepath)})["image"].meta["dim"][_i]) + ) + for _i in range(1, 4) + ) + new_transforms = create_transforms(new_dim) + + process_file(filepath, args, autoencoder, device, plain_transforms, new_transforms, logger) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Diffusion Model Training Data Creation") + parser.add_argument( + "--env_config", + type=str, + default="./configs/environment_maisi_diff_model_train.json", + help="Path to environment configuration file", + ) + parser.add_argument( + "--model_config", + type=str, + default="./configs/config_maisi_diff_model_train.json", + help="Path to model training/inference configuration", + ) + parser.add_argument( + "--model_def", type=str, default="./configs/config_maisi.json", help="Path to model definition file" + ) + + args = parser.parse_args() + diff_model_create_training_data(args.env_config, args.model_config, args.model_def) diff --git a/generative/maisi/scripts/diff_model_infer.py b/generative/maisi/scripts/diff_model_infer.py new file mode 100644 index 0000000000..aa4938164b --- /dev/null +++ b/generative/maisi/scripts/diff_model_infer.py @@ -0,0 +1,296 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import argparse +import logging +import os +import random +from datetime import datetime + +import nibabel as nib +import numpy as np +import torch +from tqdm import tqdm + +from monai.inferers import sliding_window_inference +from monai.utils import set_determinism + +from .diff_model_setting import initialize_distributed, load_config, setup_logging +from .sample import ReconModel +from .utils import define_instance, load_autoencoder_ckpt + + +def set_random_seed(seed: int) -> int: + """ + Set random seed for reproducibility. + + Args: + seed (int): Random seed. + + Returns: + int: Set random seed. 
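+
+    Example (illustrative):
+        >>> set_random_seed(1234)  # a fixed seed is echoed back
+        1234
+        >>> seed = set_random_seed(None)  # None draws a random seed in [0, 99999]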
+ """ + random_seed = random.randint(0, 99999) if seed is None else seed + set_determinism(random_seed) + return random_seed + + +def load_models(args: argparse.Namespace, device: torch.device, logger: logging.Logger) -> tuple: + """ + Load the autoencoder and UNet models. + + Args: + args (argparse.Namespace): Configuration arguments. + device (torch.device): Device to load models on. + logger (logging.Logger): Logger for logging information. + + Returns: + tuple: Loaded autoencoder, UNet model, and scale factor. + """ + autoencoder = define_instance(args, "autoencoder_def").to(device) + try: + checkpoint_autoencoder = load_autoencoder_ckpt(args.trained_autoencoder_path) + autoencoder.load_state_dict(checkpoint_autoencoder) + except Exception: + logger.error("The trained_autoencoder_path does not exist!") + + unet = define_instance(args, "diffusion_unet_def").to(device) + checkpoint = torch.load(f"{args.model_dir}/{args.model_filename}", map_location=device) + unet.load_state_dict(checkpoint["unet_state_dict"], strict=True) + logger.info(f"checkpoints {args.model_dir}/{args.model_filename} loaded.") + + scale_factor = checkpoint["scale_factor"] + logger.info(f"scale_factor -> {scale_factor}.") + + return autoencoder, unet, scale_factor + + +def prepare_tensors(args: argparse.Namespace, device: torch.device) -> tuple: + """ + Prepare necessary tensors for inference. + + Args: + args (argparse.Namespace): Configuration arguments. + device (torch.device): Device to load tensors on. + + Returns: + tuple: Prepared top_region_index_tensor, bottom_region_index_tensor, and spacing_tensor. + """ + top_region_index_tensor = np.array(args.diffusion_unet_inference["top_region_index"]).astype(float) * 1e2 + bottom_region_index_tensor = np.array(args.diffusion_unet_inference["bottom_region_index"]).astype(float) * 1e2 + spacing_tensor = np.array(args.diffusion_unet_inference["spacing"]).astype(float) * 1e2 + + top_region_index_tensor = torch.from_numpy(top_region_index_tensor[np.newaxis, :]).half().to(device) + bottom_region_index_tensor = torch.from_numpy(bottom_region_index_tensor[np.newaxis, :]).half().to(device) + spacing_tensor = torch.from_numpy(spacing_tensor[np.newaxis, :]).half().to(device) + + return top_region_index_tensor, bottom_region_index_tensor, spacing_tensor + + +def run_inference( + args: argparse.Namespace, + device: torch.device, + autoencoder: torch.nn.Module, + unet: torch.nn.Module, + scale_factor: float, + top_region_index_tensor: torch.Tensor, + bottom_region_index_tensor: torch.Tensor, + spacing_tensor: torch.Tensor, + output_size: tuple, + divisor: int, + logger: logging.Logger, +) -> np.ndarray: + """ + Run the inference to generate synthetic images. + + Args: + args (argparse.Namespace): Configuration arguments. + device (torch.device): Device to run inference on. + autoencoder (torch.nn.Module): Autoencoder model. + unet (torch.nn.Module): UNet model. + scale_factor (float): Scale factor for the model. + top_region_index_tensor (torch.Tensor): Top region index tensor. + bottom_region_index_tensor (torch.Tensor): Bottom region index tensor. + spacing_tensor (torch.Tensor): Spacing tensor. + output_size (tuple): Output size of the synthetic image. + divisor (int): Divisor for downsample level. + logger (logging.Logger): Logger for logging information. + + Returns: + np.ndarray: Generated synthetic image data. 
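+
+    Example:
+        Illustrative call, assuming the models and conditioning tensors were prepared
+        with ``load_models`` and ``prepare_tensors`` (argument values are placeholders)::
+
+            data = run_inference(
+                args, device, autoencoder, unet, scale_factor,
+                top_region_index_tensor, bottom_region_index_tensor, spacing_tensor,
+                output_size=(128, 128, 128), divisor=4, logger=logger,
+            )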
+ """ + noise = torch.randn( + (1, args.latent_channels, output_size[0] // divisor, output_size[1] // divisor, output_size[2] // divisor), + device=device, + ) + logger.info(f"noise: {noise.device}, {noise.dtype}, {type(noise)}") + + image = noise + noise_scheduler = define_instance(args, "noise_scheduler") + noise_scheduler.set_timesteps(num_inference_steps=args.diffusion_unet_inference["num_inference_steps"]) + + recon_model = ReconModel(autoencoder=autoencoder, scale_factor=scale_factor).to(device) + autoencoder.eval() + unet.eval() + + with torch.cuda.amp.autocast(enabled=True): + for t in tqdm(noise_scheduler.timesteps, ncols=110): + model_output = unet( + x=image, + timesteps=torch.Tensor((t,)).to(device), + top_region_index_tensor=top_region_index_tensor, + bottom_region_index_tensor=bottom_region_index_tensor, + spacing_tensor=spacing_tensor, + ) + image, _ = noise_scheduler.step(model_output, t, image) + + synthetic_images = sliding_window_inference( + inputs=image, + roi_size=( + min(output_size[0] // divisor // 4 * 3, 96), + min(output_size[1] // divisor // 4 * 3, 96), + min(output_size[2] // divisor // 4 * 3, 96), + ), + sw_batch_size=1, + predictor=recon_model, + mode="gaussian", + overlap=2.0 / 3.0, + sw_device=device, + device=device, + ) + + data = synthetic_images.squeeze().cpu().detach().numpy() + a_min, a_max, b_min, b_max = -1000, 1000, 0, 1 + data = (data - b_min) / (b_max - b_min) * (a_max - a_min) + a_min + data = np.clip(data, a_min, a_max) + return np.int16(data) + + +def save_image( + data: np.ndarray, output_size: tuple, out_spacing: tuple, output_path: str, logger: logging.Logger +) -> None: + """ + Save the generated synthetic image to a file. + + Args: + data (np.ndarray): Synthetic image data. + output_size (tuple): Output size of the image. + out_spacing (tuple): Spacing of the output image. + output_path (str): Path to save the output image. + logger (logging.Logger): Logger for logging information. + """ + out_affine = np.eye(4) + for i in range(3): + out_affine[i, i] = out_spacing[i] + + new_image = nib.Nifti1Image(data, affine=out_affine) + os.makedirs(os.path.dirname(output_path), exist_ok=True) + nib.save(new_image, output_path) + logger.info(f"Saved {output_path}.") + + +@torch.inference_mode() +def diff_model_infer(env_config_path: str, model_config_path: str, model_def_path: str) -> None: + """ + Main function to run the diffusion model inference. + + Args: + env_config_path (str): Path to the environment configuration file. + model_config_path (str): Path to the model configuration file. + model_def_path (str): Path to the model definition file. 
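+
+    Example:
+        Typically launched with ``torchrun`` (the module path is assumed from this
+        file's location; config paths are illustrative placeholders)::
+
+            torchrun --nproc_per_node=1 --nnodes=1 -m scripts.diff_model_infer \
+                --env_config ./configs/environment_maisi_diff_model.json \
+                --model_config ./configs/config_maisi_diff_model.json \
+                --model_def ./configs/config_maisi.json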
+ """ + args = load_config(env_config_path, model_config_path, model_def_path) + local_rank, world_size, device = initialize_distributed() + logger = setup_logging("inference") + random_seed = set_random_seed( + args.diffusion_unet_inference["random_seed"] + local_rank + if args.diffusion_unet_inference["random_seed"] + else None + ) + logger.info(f"Using {device} of {world_size} with random seed: {random_seed}") + + output_size = tuple(args.diffusion_unet_inference["dim"]) + out_spacing = tuple(args.diffusion_unet_inference["spacing"]) + output_prefix = args.output_prefix + ckpt_filepath = f"{args.model_dir}/{args.model_filename}" + + if local_rank == 0: + logger.info(f"[config] ckpt_filepath -> {ckpt_filepath}.") + logger.info(f"[config] random_seed -> {random_seed}.") + logger.info(f"[config] output_prefix -> {output_prefix}.") + logger.info(f"[config] output_size -> {output_size}.") + logger.info(f"[config] out_spacing -> {out_spacing}.") + + autoencoder, unet, scale_factor = load_models(args, device, logger) + num_downsample_level = max( + 1, + ( + len(args.diffusion_unet_def["num_channels"]) + if isinstance(args.diffusion_unet_def["num_channels"], list) + else len(args.diffusion_unet_def["attention_levels"]) + ), + ) + divisor = 2 ** (num_downsample_level - 2) + logger.info(f"num_downsample_level -> {num_downsample_level}, divisor -> {divisor}.") + + top_region_index_tensor, bottom_region_index_tensor, spacing_tensor = prepare_tensors(args, device) + data = run_inference( + args, + device, + autoencoder, + unet, + scale_factor, + top_region_index_tensor, + bottom_region_index_tensor, + spacing_tensor, + output_size, + divisor, + logger, + ) + + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + output_path = "{0}/{1}_seed{2}_size{3:d}x{4:d}x{5:d}_spacing{6:.2f}x{7:.2f}x{8:.2f}_{9}.nii.gz".format( + args.output_dir, + output_prefix, + random_seed, + output_size[0], + output_size[1], + output_size[2], + out_spacing[0], + out_spacing[1], + out_spacing[2], + timestamp, + ) + save_image(data, output_size, out_spacing, output_path, logger) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Diffusion Model Inference") + parser.add_argument( + "--env_config", + type=str, + default="./configs/environment_maisi_diff_model_train.json", + help="Path to environment configuration file", + ) + parser.add_argument( + "--model_config", + type=str, + default="./configs/config_maisi_diff_model_train.json", + help="Path to model training/inference configuration", + ) + parser.add_argument( + "--model_def", type=str, default="./configs/config_maisi.json", help="Path to model definition file" + ) + + args = parser.parse_args() + diff_model_infer(args.env_config, args.model_config, args.model_def) diff --git a/generative/maisi/scripts/diff_model_setting.py b/generative/maisi/scripts/diff_model_setting.py new file mode 100644 index 0000000000..0dedabf532 --- /dev/null +++ b/generative/maisi/scripts/diff_model_setting.py @@ -0,0 +1,88 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import argparse +import json +import logging + +import torch +import torch.distributed as dist + +from monai.utils import RankFilter + + +def setup_logging(logger_name: str = "") -> logging.Logger: + """ + Setup the logging configuration. + + Args: + logger_name (str): logger name. + + Returns: + logging.Logger: Configured logger. + """ + logger = logging.getLogger(logger_name) + logger.addFilter(RankFilter()) + logging.basicConfig( + level=logging.INFO, + format="[%(asctime)s.%(msecs)03d][%(levelname)5s](%(name)s) - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + return logger + + +def load_config(env_config_path: str, model_config_path: str, model_def_path: str) -> argparse.Namespace: + """ + Load configuration from JSON files. + + Args: + env_config_path (str): Path to the environment configuration file. + model_config_path (str): Path to the model configuration file. + model_def_path (str): Path to the model definition file. + + Returns: + argparse.Namespace: Loaded configuration. + """ + args = argparse.Namespace() + + with open(env_config_path, "r") as f: + env_config = json.load(f) + for k, v in env_config.items(): + setattr(args, k, v) + + with open(model_config_path, "r") as f: + model_config = json.load(f) + for k, v in model_config.items(): + setattr(args, k, v) + + with open(model_def_path, "r") as f: + model_def = json.load(f) + for k, v in model_def.items(): + setattr(args, k, v) + + return args + + +def initialize_distributed() -> tuple: + """ + Initialize distributed training. + + Returns: + tuple: local_rank, world_size, and device. + """ + dist.init_process_group(backend="nccl", init_method="env://") + local_rank = dist.get_rank() + world_size = dist.get_world_size() + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + return local_rank, world_size, device diff --git a/generative/maisi/scripts/diff_model_train.py b/generative/maisi/scripts/diff_model_train.py new file mode 100644 index 0000000000..094b12fcac --- /dev/null +++ b/generative/maisi/scripts/diff_model_train.py @@ -0,0 +1,426 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import argparse +import json +import logging +import os +from datetime import datetime +from pathlib import Path + +import torch +import torch.distributed as dist +from torch.cuda.amp import GradScaler, autocast +from torch.nn.parallel import DistributedDataParallel + +import monai +from monai.data import ThreadDataLoader, partition_dataset +from monai.transforms import Compose +from monai.utils import first + +from .diff_model_setting import initialize_distributed, load_config, setup_logging +from .utils import define_instance + + +def load_filenames(data_list_path: str) -> list: + """ + Load filenames from the JSON data list. + + Args: + data_list_path (str): Path to the JSON data list file. 
+ + Returns: + list: List of filenames. + """ + with open(data_list_path, "r") as file: + json_data = json.load(file) + filenames_train = json_data["training"] + return [_item["image"].replace(".nii.gz", "_emb.nii.gz") for _item in filenames_train] + + +def prepare_data( + train_files: list, device: torch.device, cache_rate: float, num_workers: int = 2, batch_size: int = 1 +) -> ThreadDataLoader: + """ + Prepare training data. + + Args: + train_files (list): List of training files. + device (torch.device): Device to use for training. + cache_rate (float): Cache rate for dataset. + num_workers (int): Number of workers for data loading. + batch_size (int): Mini-batch size. + + Returns: + ThreadDataLoader: Data loader for training. + """ + train_transforms = Compose( + [ + monai.transforms.LoadImaged(keys=["image"]), + monai.transforms.EnsureChannelFirstd(keys=["image"]), + monai.transforms.Lambdad( + keys="top_region_index", func=lambda x: torch.FloatTensor(json.load(open(x))["top_region_index"]) + ), + monai.transforms.Lambdad( + keys="bottom_region_index", func=lambda x: torch.FloatTensor(json.load(open(x))["bottom_region_index"]) + ), + monai.transforms.Lambdad(keys="spacing", func=lambda x: torch.FloatTensor(json.load(open(x))["spacing"])), + monai.transforms.Lambdad(keys="top_region_index", func=lambda x: x * 1e2), + monai.transforms.Lambdad(keys="bottom_region_index", func=lambda x: x * 1e2), + monai.transforms.Lambdad(keys="spacing", func=lambda x: x * 1e2), + ] + ) + + train_ds = monai.data.CacheDataset( + data=train_files, transform=train_transforms, cache_rate=cache_rate, num_workers=num_workers + ) + + return ThreadDataLoader(train_ds, num_workers=6, batch_size=batch_size, shuffle=True) + + +def load_unet(args: argparse.Namespace, device: torch.device, logger: logging.Logger) -> torch.nn.Module: + """ + Load the UNet model. + + Args: + args (argparse.Namespace): Configuration arguments. + device (torch.device): Device to load the model on. + logger (logging.Logger): Logger for logging information. + + Returns: + torch.nn.Module: Loaded UNet model. + """ + unet = define_instance(args, "diffusion_unet_def").to(device) + unet = torch.nn.SyncBatchNorm.convert_sync_batchnorm(unet) + + if torch.cuda.device_count() > 1: + unet = DistributedDataParallel(unet, device_ids=[device], find_unused_parameters=True) + + if args.existing_ckpt_filepath is None: + logger.info("Training from scratch.") + else: + checkpoint_unet = torch.load(f"{args.existing_ckpt_filepath}", map_location=device) + if torch.cuda.device_count() > 1: + unet.module.load_state_dict(checkpoint_unet["unet_state_dict"], strict=True) + else: + unet.load_state_dict(checkpoint_unet["unet_state_dict"], strict=True) + logger.info(f"Pretrained checkpoint {args.existing_ckpt_filepath} loaded.") + + return unet + + +def calculate_scale_factor( + train_loader: ThreadDataLoader, device: torch.device, logger: logging.Logger +) -> torch.Tensor: + """ + Calculate the scaling factor for the dataset. + + Args: + train_loader (ThreadDataLoader): Data loader for training. + device (torch.device): Device to use for calculation. + logger (logging.Logger): Logger for logging information. + + Returns: + torch.Tensor: Calculated scaling factor. 
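+
+    Example:
+        Illustrative usage during training setup (assumes an initialized process group
+        and a prepared ``train_loader``)::
+
+            scale_factor = calculate_scale_factor(train_loader, device, logger)
+            # the latent images are later multiplied by this factor so that they have
+            # roughly unit standard deviation before being passed to the diffusion UNet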
+ """ + check_data = first(train_loader) + z = check_data["image"].to(device) + scale_factor = 1 / torch.std(z) + logger.info(f"Scaling factor set to {scale_factor}.") + + dist.barrier() + dist.all_reduce(scale_factor, op=torch.distributed.ReduceOp.AVG) + logger.info(f"scale_factor -> {scale_factor}.") + return scale_factor + + +def create_optimizer(model: torch.nn.Module, lr: float) -> torch.optim.Optimizer: + """ + Create optimizer for training. + + Args: + model (torch.nn.Module): Model to optimize. + lr (float): Learning rate. + + Returns: + torch.optim.Optimizer: Created optimizer. + """ + return torch.optim.Adam(params=model.parameters(), lr=lr) + + +def create_lr_scheduler(optimizer: torch.optim.Optimizer, total_steps: int) -> torch.optim.lr_scheduler.PolynomialLR: + """ + Create learning rate scheduler. + + Args: + optimizer (torch.optim.Optimizer): Optimizer to schedule. + total_steps (int): Total number of training steps. + + Returns: + torch.optim.lr_scheduler.PolynomialLR: Created learning rate scheduler. + """ + return torch.optim.lr_scheduler.PolynomialLR(optimizer, total_iters=total_steps, power=2.0) + + +def train_one_epoch( + epoch: int, + unet: torch.nn.Module, + train_loader: ThreadDataLoader, + optimizer: torch.optim.Optimizer, + lr_scheduler: torch.optim.lr_scheduler.PolynomialLR, + loss_pt: torch.nn.L1Loss, + scaler: GradScaler, + scale_factor: torch.Tensor, + noise_scheduler: torch.nn.Module, + num_images_per_batch: int, + num_train_timesteps: int, + device: torch.device, + logger: logging.Logger, + local_rank: int, +) -> torch.Tensor: + """ + Train the model for one epoch. + + Args: + epoch (int): Current epoch number. + unet (torch.nn.Module): UNet model. + train_loader (ThreadDataLoader): Data loader for training. + optimizer (torch.optim.Optimizer): Optimizer. + lr_scheduler (torch.optim.lr_scheduler.PolynomialLR): Learning rate scheduler. + loss_pt (torch.nn.L1Loss): Loss function. + scaler (GradScaler): Gradient scaler for mixed precision training. + scale_factor (torch.Tensor): Scaling factor. + noise_scheduler (torch.nn.Module): Noise scheduler. + num_images_per_batch (int): Number of images per batch. + num_train_timesteps (int): Number of training timesteps. + device (torch.device): Device to use for training. + logger (logging.Logger): Logger for logging information. + local_rank (int): Local rank for distributed training. + + Returns: + torch.Tensor: Training loss for the epoch. 
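+
+    Example:
+        Illustrative per-epoch call from the main training loop; all objects are
+        assumed to have been created as in ``diff_model_train``::
+
+            for epoch in range(n_epochs):
+                loss_torch = train_one_epoch(
+                    epoch, unet, train_loader, optimizer, lr_scheduler, loss_pt, scaler,
+                    scale_factor, noise_scheduler, batch_size, num_train_timesteps,
+                    device, logger, local_rank,
+                )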
+ """ + if local_rank == 0: + current_lr = optimizer.param_groups[0]["lr"] + logger.info(f"Epoch {epoch + 1}, lr {current_lr}.") + + _iter = 0 + loss_torch = torch.zeros(2, dtype=torch.float, device=device) + + unet.train() + for train_data in train_loader: + current_lr = optimizer.param_groups[0]["lr"] + + _iter += 1 + images = train_data["image"].to(device) + images = images * scale_factor + + top_region_index_tensor = train_data["top_region_index"].to(device) + bottom_region_index_tensor = train_data["bottom_region_index"].to(device) + spacing_tensor = train_data["spacing"].to(device) + + optimizer.zero_grad(set_to_none=True) + + with autocast(enabled=True): + noise = torch.randn( + (num_images_per_batch, 4, images.size(-3), images.size(-2), images.size(-1)), device=device + ) + + timesteps = torch.randint(0, num_train_timesteps, (images.shape[0],), device=images.device).long() + + noisy_latent = noise_scheduler.add_noise(original_samples=images, noise=noise, timesteps=timesteps) + + noise_pred = unet( + x=noisy_latent, + timesteps=timesteps, + top_region_index_tensor=top_region_index_tensor, + bottom_region_index_tensor=bottom_region_index_tensor, + spacing_tensor=spacing_tensor, + ) + + loss = loss_pt(noise_pred.float(), noise.float()) + + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + + lr_scheduler.step() + + loss_torch[0] += loss.item() + loss_torch[1] += 1.0 + + if local_rank == 0: + logger.info( + "[{0}] epoch {1}, iter {2}/{3}, loss: {4:.4f}, lr: {5:.12f}.".format( + str(datetime.now())[:19], epoch + 1, _iter, len(train_loader), loss.item(), current_lr + ) + ) + + if torch.cuda.device_count() > 1: + dist.all_reduce(loss_torch, op=torch.distributed.ReduceOp.SUM) + + return loss_torch + + +def save_checkpoint( + epoch: int, + unet: torch.nn.Module, + loss_torch_epoch: float, + num_train_timesteps: int, + scale_factor: torch.Tensor, + ckpt_folder: str, + args: argparse.Namespace, +) -> None: + """ + Save checkpoint. + + Args: + epoch (int): Current epoch number. + unet (torch.nn.Module): UNet model. + loss_torch_epoch (float): Training loss for the epoch. + num_train_timesteps (int): Number of training timesteps. + scale_factor (torch.Tensor): Scaling factor. + ckpt_folder (str): Checkpoint folder path. + args (argparse.Namespace): Configuration arguments. + """ + unet_state_dict = unet.module.state_dict() if torch.cuda.device_count() > 1 else unet.state_dict() + torch.save( + { + "epoch": epoch + 1, + "loss": loss_torch_epoch, + "num_train_timesteps": num_train_timesteps, + "scale_factor": scale_factor, + "unet_state_dict": unet_state_dict, + }, + f"{ckpt_folder}/{args.model_filename}", + ) + + +def diff_model_train(env_config_path: str, model_config_path: str, model_def_path: str) -> None: + """ + Main function to train a diffusion model. + + Args: + env_config_path (str): Path to the environment configuration file. + model_config_path (str): Path to the model configuration file. + model_def_path (str): Path to the model definition file. 
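+
+    Example:
+        Usually run through ``torchrun`` as shown in the tutorial notebook; the config
+        paths below are illustrative placeholders::
+
+            torchrun --nproc_per_node=1 --nnodes=1 -m scripts.diff_model_train \
+                --env_config ./configs/environment_maisi_diff_model.json \
+                --model_config ./configs/config_maisi_diff_model.json \
+                --model_def ./configs/config_maisi.json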
+ """ + args = load_config(env_config_path, model_config_path, model_def_path) + local_rank, world_size, device = initialize_distributed() + logger = setup_logging("training") + + logger.info(f"Using {device} of {world_size}") + + if local_rank == 0: + logger.info(f"[config] ckpt_folder -> {args.model_dir}.") + logger.info(f"[config] data_root -> {args.embedding_base_dir}.") + logger.info(f"[config] data_list -> {args.json_data_list}.") + logger.info(f"[config] lr -> {args.diffusion_unet_train['lr']}.") + logger.info(f"[config] num_epochs -> {args.diffusion_unet_train['n_epochs']}.") + logger.info(f"[config] num_train_timesteps -> {args.noise_scheduler['num_train_timesteps']}.") + + Path(args.model_dir).mkdir(parents=True, exist_ok=True) + + filenames_train = load_filenames(args.json_data_list) + if local_rank == 0: + logger.info(f"num_files_train: {len(filenames_train)}") + + train_files = [] + for _i in range(len(filenames_train)): + str_img = os.path.join(args.embedding_base_dir, filenames_train[_i]) + if not os.path.exists(str_img): + continue + + str_info = os.path.join(args.embedding_base_dir, filenames_train[_i]) + ".json" + train_files.append( + {"image": str_img, "top_region_index": str_info, "bottom_region_index": str_info, "spacing": str_info} + ) + + train_files = partition_dataset( + data=train_files, shuffle=True, num_partitions=dist.get_world_size(), even_divisible=True + )[local_rank] + + train_loader = prepare_data( + train_files, device, args.diffusion_unet_train["cache_rate"], args.diffusion_unet_train["batch_size"] + ) + + unet = load_unet(args, device, logger) + noise_scheduler = define_instance(args, "noise_scheduler") + + scale_factor = calculate_scale_factor(train_loader, device, logger) + optimizer = create_optimizer(unet, args.diffusion_unet_train["lr"]) + + total_steps = (args.diffusion_unet_train["n_epochs"] * len(train_loader.dataset)) / args.diffusion_unet_train[ + "batch_size" + ] + lr_scheduler = create_lr_scheduler(optimizer, total_steps) + loss_pt = torch.nn.L1Loss() + scaler = GradScaler() + + torch.set_float32_matmul_precision("highest") + logger.info("torch.set_float32_matmul_precision -> highest.") + + for epoch in range(args.diffusion_unet_train["n_epochs"]): + loss_torch = train_one_epoch( + epoch, + unet, + train_loader, + optimizer, + lr_scheduler, + loss_pt, + scaler, + scale_factor, + noise_scheduler, + args.diffusion_unet_train["batch_size"], + args.noise_scheduler["num_train_timesteps"], + device, + logger, + local_rank, + ) + + loss_torch = loss_torch.tolist() + if torch.cuda.device_count() == 1 or local_rank == 0: + loss_torch_epoch = loss_torch[0] / loss_torch[1] + logger.info(f"epoch {epoch + 1} average loss: {loss_torch_epoch:.4f}.") + + save_checkpoint( + epoch, + unet, + loss_torch_epoch, + args.noise_scheduler["num_train_timesteps"], + scale_factor, + args.model_dir, + args, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Diffusion Model Training") + parser.add_argument( + "--env_config", + type=str, + default="./configs/environment_maisi_diff_model_train.json", + help="Path to environment configuration file", + ) + parser.add_argument( + "--model_config", + type=str, + default="./configs/config_maisi_diff_model_train.json", + help="Path to model training/inference configuration", + ) + parser.add_argument( + "--model_def", type=str, default="./configs/config_maisi.json", help="Path to model definition file" + ) + + args = parser.parse_args() + diff_model_train(args.env_config, args.model_config, 
args.model_def) diff --git a/generative/maisi/scripts/sample.py b/generative/maisi/scripts/sample.py index bcbd662fec..0c6ad50086 100644 --- a/generative/maisi/scripts/sample.py +++ b/generative/maisi/scripts/sample.py @@ -810,14 +810,12 @@ def prepare_anatomy_size_condtion( diff += abs(provide_size - db_size) candidate_list.append((size, diff)) candidate_condition = sorted(candidate_list, key=lambda x: x[1])[0][0] - # logging.info("provide_anatomy_size:", provide_anatomy_size) - # logging.info("candidate_condition:", candidate_condition) # overwrite the anatomy size provided by users for element in controllable_anatomy_size: anatomy_name, anatomy_size = element candidate_condition[anatomy_size_idx[anatomy_name]] = anatomy_size - # logging.info("final candidate_condition:", candidate_condition) + return candidate_condition def prepare_one_mask_and_meta_info(self, anatomy_size_condtion):