
Commit f0bb2e5
Updates
Signed-off-by: Virginia <[email protected]>
Virginia committed Nov 11, 2024
1 parent 04f4d21 commit f0bb2e5
Showing 1 changed file with 35 additions and 27 deletions.
@@ -104,7 +104,6 @@
"import torch\n",
"import torch.nn.functional as F\n",
"import torchvision\n",
"import sys\n",
"from monai import transforms\n",
"from monai.apps import DecathlonDataset\n",
"from monai.config import print_config\n",
@@ -191,8 +190,9 @@
"2. `EnsureChannelFirstd` ensures the original data to construct \"channel first\" shape.\n",
"3. The first `Lambdad` transform chooses the first channel of the image, which is the Flair image.\n",
"4. `Spacingd` resamples the image to the specified voxel spacing, we use 3,3,2 mm.\n",
"5. `ScaleIntensityRangePercentilesd` Apply range scaling to a numpy array based on the intensity distribution of the input. Transform is very common with MRI images.\n",
"6. `RandSpatialCropd` randomly crop out a 2D patch from the 3D image.\n",
"5. `CenterSpatialCropd`: we crop the 3D images to a specific size\n",
"6. `ScaleIntensityRangePercentilesd` Apply range scaling to a numpy array based on the intensity distribution of the input. Transform is very common with MRI images.\n",
"7. `RandSpatialCropd` randomly crop out a 2D patch from the 3D image.\n",
"6. The last `Lambdad` transform obtains `slice_label` by summing up the label to have a single scalar value (healthy `=1` or not `=2` )."
]
},
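The transforms above, assembled end to end, might look like the following sketch; the keys, crop sizes, and percentile bounds here are illustrative assumptions, not necessarily the notebook's exact values:

    from monai import transforms

    train_transforms = transforms.Compose(
        [
            transforms.LoadImaged(keys=["image", "label"]),
            transforms.EnsureChannelFirstd(keys=["image", "label"]),
            # keep only the first channel of the image (Flair)
            transforms.Lambdad(keys=["image"], func=lambda x: x[0:1]),
            transforms.Spacingd(keys=["image", "label"], pixdim=(3.0, 3.0, 2.0), mode=("bilinear", "nearest")),
            transforms.CenterSpatialCropd(keys=["image", "label"], roi_size=(64, 64, 44)),
            transforms.ScaleIntensityRangePercentilesd(keys=["image"], lower=0, upper=99.5, b_min=0, b_max=1),
            transforms.RandSpatialCropd(keys=["image", "label"], roi_size=(64, 64, 1), random_size=False),
            # collapse the label map to a scalar slice label (healthy = 1, not healthy = 2)
            transforms.Lambdad(keys=["label"], func=lambda x: 2.0 if x.sum() > 0 else 1.0),
        ]
    )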
@@ -388,7 +388,7 @@
},
"outputs": [],
"source": [
"class Diffusion_AE(torch.nn.Module):\n",
"class DiffusionAE(torch.nn.Module):\n",
" def __init__(self, embedding_dimension=64):\n",
" super().__init__()\n",
" self.unet = DiffusionModelUNet(\n",
@@ -413,7 +413,7 @@
"\n",
"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"model = Diffusion_AE(embedding_dimension=512).to(device)\n",
"model = DiffusionAE(embedding_dimension=512).to(device)\n",
"scheduler = DDIMScheduler(num_train_timesteps=1000)\n",
"optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-5)\n",
"inferer = DiffusionInferer(scheduler)"
@@ -492,7 +492,8 @@
" # Create timesteps\n",
" timesteps = torch.randint(0, inferer.scheduler.num_train_timesteps, (batch_size,)).to(device).long()\n",
" # Get model prediction\n",
" # cross attention expects shape [batch size, sequence length, channels], we are use channels = latent dimension and sequence length = 1\n",
" # cross attention expects shape [batch size, sequence length, channels], \n",
" #we are use channels = latent dimension and sequence length = 1\n",
" latent = model.semantic_encoder(images)\n",
" noise_pred = inferer(\n",
" inputs=images, diffusion_model=model.unet, noise=noise, timesteps=timesteps, condition=latent.unsqueeze(2)\n",
@@ -509,7 +510,7 @@
" if epoch % val_interval == 0:\n",
" model.eval()\n",
" val_iter_loss = 0\n",
" for val_step, val_batch in enumerate(val_loader):\n",
" for _, val_batch in enumerate(val_loader):\n",
" with torch.no_grad():\n",
" images = val_batch[\"image\"].to(device)\n",
" timesteps = torch.randint(0, inferer.scheduler.num_train_timesteps, (batch_size,)).to(device).long()\n",
@@ -526,10 +527,11 @@
"\n",
" val_iter_loss += val_loss.item()\n",
" iter_loss_list.append(iter_loss / val_interval)\n",
" val_iter_loss_list.append(val_iter_loss / (val_step + 1))\n",
" val_iter_loss_list.append(val_iter_loss / len(val_loader))\n",
" iter_loss = 0\n",
" print(\n",
" f\"Iteration {epoch} - Interval Loss {iter_loss_list[-1]:.4f}, Interval Loss Val {val_iter_loss_list[-1]:.4f}\"\n",
" f\"Iteration {epoch} - Interval Loss {iter_loss_list[-1]:.4f}, \n",
" Interval Loss Val {val_iter_loss_list[-1]:.4f}\"\n",
" )\n",
"\n",
"total_time = time.time() - total_start\n",
@@ -566,8 +568,10 @@
"plt.title(\"Learning Curves Diffusion Model\", fontsize=20)\n",
"plt.plot(list(range(len(iter_loss_list))), iter_loss_list, color=\"C0\", linewidth=2.0, label=\"Train\")\n",
"plt.plot(list(range(len(iter_loss_list))), val_iter_loss_list, color=\"C4\", linewidth=2.0, label=\"Validation\")\n",
"plt.yticks(fontsize=12), plt.xticks(fontsize=12)\n",
"plt.xlabel(\"Iterations\", fontsize=16), plt.ylabel(\"Loss\", fontsize=16)\n",
"plt.yticks(fontsize=12)\n",
"plt.xticks(fontsize=12)\n",
"plt.xlabel(\"Iterations\", fontsize=16)\n",
"plt.ylabel(\"Loss\", fontsize=16)\n",
"plt.legend(prop={\"size\": 14})\n",
"plt.show()"
]
@@ -713,7 +717,8 @@
}
],
"source": [
"latents_train.shape, classes_train.shape"
"print(latents_train.shape)\n",
"print(classes_train.shape)"
]
},
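The `latents_train` / `classes_train` arrays above come from running the trained semantic encoder over the dataset; a hedged sketch of that collection step (the loader name, key names, and flattening are assumptions):

    import numpy as np

    latents, classes = [], []
    model.eval()
    with torch.no_grad():
        for batch in train_loader:
            z = model.semantic_encoder(batch["image"].to(device))
            latents.append(z.reshape(z.shape[0], -1).cpu().numpy())
            classes.append(np.asarray(batch["slice_label"]))
    latents_train = np.concatenate(latents)
    classes_train = np.concatenate(classes)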
{
@@ -735,17 +740,8 @@
],
"source": [
"clf = LogisticRegression(solver=\"newton-cg\", random_state=0).fit(latents_train, classes_train)\n",
"clf.score(latents_train, classes_train), clf.score(latents_val, classes_val)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "73df71e0",
"metadata": {},
"outputs": [],
"source": [
"w = torch.Tensor(clf.coef_).float().to(device)"
"print(clf.score(latents_train, classes_train))\n",
"print(clf.score(latents_val, classes_val))"
]
},
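A hedged reading of the probe above: its coefficient vector is the normal of the decision boundary, pointing from the healthy side of latent space toward the tumour side, which is what makes it usable as an editing direction in the next cell:

    import numpy as np
    w_dir = clf.coef_ / np.linalg.norm(clf.coef_)  # unit direction, shape (1, latent_dim)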
{
@@ -777,6 +773,7 @@
"source": [
"s = -1.5\n",
"\n",
"w = torch.Tensor(clf.coef_).float().to(device)\n",
"scheduler.set_timesteps(num_inference_steps=100)\n",
"batch = next(iter(val_loader))\n",
"images = batch[\"image\"].to(device)\n",
@@ -802,6 +799,14 @@
")"
]
},
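The body of the manipulation cell is partly elided in this hunk; a hedged sketch of the step it performs (the exact scaling of `w` is an assumption): shift each latent along the classifier direction by `s`, then decode as before:

    with torch.no_grad():
        latent = model.semantic_encoder(images)
        latent_manip = latent + s * w      # s = -1.5 pushes latents away from the abnormality side
        noise = torch.randn_like(images)
        manipulated = inferer.sample(input_noise=noise, diffusion_model=model.unet,
                                     scheduler=scheduler, conditioning=latent_manip.unsqueeze(2))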
{
"cell_type": "markdown",
"id": "525702b5",
"metadata": {},
"source": [
"Although not perfectly, the manipulated slices do not present a tumour (unlike the middle - \"reconstructed\" - ones), because we tweaked the latents to move away from the abnormality cluster: "
]
},
{
"cell_type": "code",
"execution_count": 28,
@@ -831,15 +836,18 @@
"plt.figure(figsize=(15, 5))\n",
"plt.imshow(grid.detach().cpu().numpy()[0], cmap=\"gray\")\n",
"plt.axis(\"off\")\n",
"plt.title(f\"Original (top), Reconstruction (middle), Manipulated (bottom) s = {s}\");"
"plt.title(f\"Original (top), Reconstruction (middle), Manipulated (bottom) s = {s}\")"
]
},
{
"cell_type": "markdown",
"id": "b5ac0b8c-0f9d-43ba-9959-488ab62e892e",
"cell_type": "code",
"execution_count": null,
"id": "9cf8fbf9",
"metadata": {},
"outputs": [],
"source": [
"Although not perfectly, the manipulated slices do not present a tumour (unlike the middle - \"reconstructed\" - ones), because we tweaked the latents to move away from the abnormality cluster."
"if directory is None:\n",
" shutil.rmtree(root_dir)"
]
}
],
