From f0bb2e5e4cac791981e10760ae405f3919a7234e Mon Sep 17 00:00:00 2001
From: Virginia
Date: Mon, 11 Nov 2024 13:30:08 +0000
Subject: [PATCH] Updates

Signed-off-by: Virginia
---
 .../2d_diffusion_autoencoder_tutorial.ipynb   | 64 ++++++++++---------
 1 file changed, 36 insertions(+), 28 deletions(-)

diff --git a/generation/2d_diffusion_autoencoder/2d_diffusion_autoencoder_tutorial.ipynb b/generation/2d_diffusion_autoencoder/2d_diffusion_autoencoder_tutorial.ipynb
index 7c1b1d549..30a098fe9 100644
--- a/generation/2d_diffusion_autoencoder/2d_diffusion_autoencoder_tutorial.ipynb
+++ b/generation/2d_diffusion_autoencoder/2d_diffusion_autoencoder_tutorial.ipynb
@@ -104,7 +104,6 @@
     "import torch\n",
     "import torch.nn.functional as F\n",
     "import torchvision\n",
-    "import sys\n",
     "from monai import transforms\n",
     "from monai.apps import DecathlonDataset\n",
     "from monai.config import print_config\n",
@@ -191,8 +190,9 @@
     "2. `EnsureChannelFirstd` ensures the original data to construct \"channel first\" shape.\n",
     "3. The first `Lambdad` transform chooses the first channel of the image, which is the Flair image.\n",
     "4. `Spacingd` resamples the image to the specified voxel spacing, we use 3,3,2 mm.\n",
-    "5. `ScaleIntensityRangePercentilesd` Apply range scaling to a numpy array based on the intensity distribution of the input. Transform is very common with MRI images.\n",
-    "6. `RandSpatialCropd` randomly crop out a 2D patch from the 3D image.\n",
-    "6. The last `Lambdad` transform obtains `slice_label` by summing up the label to have a single scalar value (healthy `=1` or not `=2` )."
+    "5. `CenterSpatialCropd` crops the 3D images to a fixed spatial size.\n",
+    "6. `ScaleIntensityRangePercentilesd` applies range scaling to the image based on the intensity distribution of the input; this transform is very common with MRI images.\n",
+    "7. `RandSpatialCropd` randomly crops a 2D patch from the 3D image.\n",
+    "8. The last `Lambdad` transform obtains `slice_label` by summing up the label to have a single scalar value (healthy `=1` or not `=2`)."
    ]
   },
@@ -388,7 +388,7 @@
    },
   "outputs": [],
   "source": [
-    "class Diffusion_AE(torch.nn.Module):\n",
+    "class DiffusionAE(torch.nn.Module):\n",
     "    def __init__(self, embedding_dimension=64):\n",
     "        super().__init__()\n",
     "        self.unet = DiffusionModelUNet(\n",
@@ -413,7 +413,7 @@
     "\n",
     "\n",
     "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
-    "model = Diffusion_AE(embedding_dimension=512).to(device)\n",
+    "model = DiffusionAE(embedding_dimension=512).to(device)\n",
     "scheduler = DDIMScheduler(num_train_timesteps=1000)\n",
     "optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-5)\n",
     "inferer = DiffusionInferer(scheduler)"
@@ -492,7 +492,8 @@
     "        # Create timesteps\n",
     "        timesteps = torch.randint(0, inferer.scheduler.num_train_timesteps, (batch_size,)).to(device).long()\n",
     "        # Get model prediction\n",
-    "        # cross attention expects shape [batch size, sequence length, channels], we are use channels = latent dimension and sequence length = 1\n",
+    "        # cross attention expects shape [batch size, sequence length, channels];\n",
+    "        # we use channels = latent dimension and sequence length = 1\n",
     "        latent = model.semantic_encoder(images)\n",
     "        noise_pred = inferer(\n",
     "            inputs=images, diffusion_model=model.unet, noise=noise, timesteps=timesteps, condition=latent.unsqueeze(2)\n",
@@ -509,7 +510,7 @@
     "    if epoch % val_interval == 0:\n",
     "        model.eval()\n",
     "        val_iter_loss = 0\n",
-    "        for val_step, val_batch in enumerate(val_loader):\n",
+    "        for val_batch in val_loader:\n",
     "            with torch.no_grad():\n",
     "                images = val_batch[\"image\"].to(device)\n",
     "                timesteps = torch.randint(0, inferer.scheduler.num_train_timesteps, (batch_size,)).to(device).long()\n",
@@ -526,10 +527,11 @@
     "\n",
     "            val_iter_loss += val_loss.item()\n",
     "        iter_loss_list.append(iter_loss / val_interval)\n",
-    "        val_iter_loss_list.append(val_iter_loss / (val_step + 1))\n",
+    "        val_iter_loss_list.append(val_iter_loss / len(val_loader))\n",
     "        iter_loss = 0\n",
     "        print(\n",
-    "            f\"Iteration {epoch} - Interval Loss {iter_loss_list[-1]:.4f}, Interval Loss Val {val_iter_loss_list[-1]:.4f}\"\n",
+    "            f\"Iteration {epoch} - Interval Loss {iter_loss_list[-1]:.4f}, \"\n",
+    "            f\"Interval Loss Val {val_iter_loss_list[-1]:.4f}\"\n",
     "        )\n",
     "\n",
     "total_time = time.time() - total_start\n",
@@ -566,8 +568,10 @@
     "plt.title(\"Learning Curves Diffusion Model\", fontsize=20)\n",
     "plt.plot(list(range(len(iter_loss_list))), iter_loss_list, color=\"C0\", linewidth=2.0, label=\"Train\")\n",
     "plt.plot(list(range(len(iter_loss_list))), val_iter_loss_list, color=\"C4\", linewidth=2.0, label=\"Validation\")\n",
-    "plt.yticks(fontsize=12), plt.xticks(fontsize=12)\n",
-    "plt.xlabel(\"Iterations\", fontsize=16), plt.ylabel(\"Loss\", fontsize=16)\n",
+    "plt.yticks(fontsize=12)\n",
+    "plt.xticks(fontsize=12)\n",
+    "plt.xlabel(\"Iterations\", fontsize=16)\n",
+    "plt.ylabel(\"Loss\", fontsize=16)\n",
     "plt.legend(prop={\"size\": 14})\n",
     "plt.show()"
@@ -713,7 +717,8 @@
    }
   ],
   "source": [
-    "latents_train.shape, classes_train.shape"
+    "print(latents_train.shape)\n",
+    "print(classes_train.shape)"
   ]
  },
  {
@@ -735,17 +740,8 @@
   ],
   "source": [
     "clf = LogisticRegression(solver=\"newton-cg\", random_state=0).fit(latents_train, classes_train)\n",
-    "clf.score(latents_train, classes_train), clf.score(latents_val, classes_val)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 22,
-   "id": "73df71e0",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "w = torch.Tensor(clf.coef_).float().to(device)"
+    "print(clf.score(latents_train, classes_train))\n",
classes_train))\n", + "print(clf.score(latents_val, classes_val))" ] }, { @@ -777,6 +773,7 @@ "source": [ "s = -1.5\n", "\n", + "w = torch.Tensor(clf.coef_).float().to(device)\n", "scheduler.set_timesteps(num_inference_steps=100)\n", "batch = next(iter(val_loader))\n", "images = batch[\"image\"].to(device)\n", @@ -802,6 +799,14 @@ ")" ] }, + { + "cell_type": "markdown", + "id": "525702b5", + "metadata": {}, + "source": [ + "Although not perfectly, the manipulated slices do not present a tumour (unlike the middle - \"reconstructed\" - ones), because we tweaked the latents to move away from the abnormality cluster: " + ] + }, { "cell_type": "code", "execution_count": 28, @@ -831,15 +836,18 @@ "plt.figure(figsize=(15, 5))\n", "plt.imshow(grid.detach().cpu().numpy()[0], cmap=\"gray\")\n", "plt.axis(\"off\")\n", - "plt.title(f\"Original (top), Reconstruction (middle), Manipulated (bottom) s = {s}\");" + "plt.title(f\"Original (top), Reconstruction (middle), Manipulated (bottom) s = {s}\")" ] }, { - "cell_type": "markdown", - "id": "b5ac0b8c-0f9d-43ba-9959-488ab62e892e", + "cell_type": "code", + "execution_count": null, + "id": "9cf8fbf9", "metadata": {}, + "outputs": [], "source": [ - "Although not perfectly, the manipulated slices do not present a tumour (unlike the middle - \"reconstructed\" - ones), because we tweaked the latents to move away from the abnormality cluster." + "if directory is None:\n", + " shutil.rmtree(root_dir)" ] } ],