From 98447af49012fef636207dc09090a52d3cfb4fd8 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Thu, 18 Jan 2024 14:32:32 +0800 Subject: [PATCH] revert bundle tutorial Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- bundle/01_bundle_intro.ipynb | 6 +- bundle/02_mednist_classification.ipynb | 96 +++---- bundle/03_mednist_classification_v2.ipynb | 234 +++++++++--------- bundle/04_integrating_code.ipynb | 41 +-- bundle/05_spleen_segmentation_lightning.ipynb | 183 +++++++------- 5 files changed, 296 insertions(+), 264 deletions(-) diff --git a/bundle/01_bundle_intro.ipynb b/bundle/01_bundle_intro.ipynb index 56d601072f..976e61e71e 100644 --- a/bundle/01_bundle_intro.ipynb +++ b/bundle/01_bundle_intro.ipynb @@ -402,9 +402,9 @@ "datadicts: '$[{i: (i * i)} for i in range(10)]' # create a fake dataset as a list of dicts\n", "\n", "test_dataset: # creates an instance of an object because _target_ is present\n", - " _target_: Dataset # name of type to create is monai.data.Dataset (loaded implicitly from MONAI)\n", - " data: '@datadicts' # argument data provided by a definition\n", - " transform: '$None' # argument transform provided by a Python expression\n", + " _target_: Dataset # name of type to create is monai.data.Dataset (loaded implicitly from MONAI)\n", + " data: '@datadicts' # argument data provided by a definition\n", + " transform: '$None' # argument transform provided by a Python expression\n", "\n", "test:\n", "- '$print(\"Dataset\", @test_dataset)'\n", diff --git a/bundle/02_mednist_classification.ipynb b/bundle/02_mednist_classification.ipynb index b1dda0d784..95a3178c11 100644 --- a/bundle/02_mednist_classification.ipynb +++ b/bundle/02_mednist_classification.ipynb @@ -294,25 +294,25 @@ "\n", "# define the network separately, don't need to refer to MONAI types by name or import MONAI\n", "network_def:\n", - " _target_: densenet121\n", - " spatial_dims: 2\n", - " in_channels: 1\n", - " out_channels: 6\n", + " _target_: densenet121\n", + " spatial_dims: 2\n", + " in_channels: 1\n", + " out_channels: 6\n", "\n", "# define the network to be the given definition moved to the device\n", "net: '$@network_def.to(@device)'\n", "\n", "# define a transform sequence by instantiating a Compose instance with a transform sequence\n", "transform:\n", - " _target_: Compose\n", - " transforms:\n", - " - _target_: LoadImaged\n", - " keys: 'image'\n", - " image_only: true\n", - " - _target_: EnsureChannelFirstd\n", - " keys: 'image'\n", - " - _target_: ScaleIntensityd\n", - " keys: 'image'" + " _target_: Compose\n", + " transforms:\n", + " - _target_: LoadImaged\n", + " keys: 'image'\n", + " image_only: true\n", + " - _target_: EnsureChannelFirstd\n", + " keys: 'image'\n", + " - _target_: ScaleIntensityd\n", + " keys: 'image'" ] }, { @@ -356,32 +356,32 @@ "max_epochs: 25\n", "\n", "dataset:\n", - " _target_: MedNISTDataset\n", - " root_dir: '@root_dir'\n", - " transform: '@transform'\n", - " section: training\n", - " download: true\n", + " _target_: MedNISTDataset\n", + " root_dir: '@root_dir'\n", + " transform: '@transform'\n", + " section: training\n", + " download: true\n", "\n", "train_dl:\n", - " _target_: DataLoader\n", - " dataset: '@dataset'\n", - " batch_size: 512\n", - " shuffle: true\n", - " num_workers: 4\n", + " _target_: DataLoader\n", + " dataset: '@dataset'\n", + " batch_size: 512\n", + " shuffle: true\n", + " num_workers: 4\n", "\n", "trainer:\n", - " _target_: SupervisedTrainer\n", - " device: '@device'\n", - " max_epochs: '@max_epochs'\n", - " train_data_loader: '@train_dl'\n", - " network: '@net'\n", - " optimizer: \n", - " _target_: torch.optim.Adam\n", - " params: '$@net.parameters()'\n", - " lr: 0.00001 # learning rate set slow so that you can see network improvement over epochs\n", - " loss_function: \n", - " _target_: torch.nn.CrossEntropyLoss\n", - " inferer: \n", + " _target_: SupervisedTrainer\n", + " device: '@device'\n", + " max_epochs: '@max_epochs'\n", + " train_data_loader: '@train_dl'\n", + " network: '@net'\n", + " optimizer: \n", + " _target_: torch.optim.Adam\n", + " params: '$@net.parameters()'\n", + " lr: 0.00001 # learning rate set slow so that you can see network improvement over epochs\n", + " loss_function: \n", + " _target_: torch.nn.CrossEntropyLoss\n", + " inferer: \n", " _target_: SimpleInferer\n", "\n", "train:\n", @@ -519,6 +519,7 @@ "source": [ "%%writefile MedNISTClassifier/scripts/__init__.py\n", "\n", + "from monai.networks.utils import eval_mode\n", "\n", "def evaluate(net, dataloader, class_names, device):\n", " with eval_mode(net):\n", @@ -527,7 +528,7 @@ " prob = result.detach().to(\"cpu\")[0]\n", " pred = class_names[prob.argmax()]\n", " gt = item[\"class_name\"][0]\n", - " print(f\"Prediction: {pred}. Ground-truth: {gt}\")" + " print(f\"Prediction: {pred}. Ground-truth: {gt}\")\n" ] }, { @@ -556,6 +557,7 @@ ], "source": [ "%%writefile MedNISTClassifier/configs/evaluate.yaml\n", + "\n", "imports: \n", "- $import scripts\n", "\n", @@ -564,23 +566,23 @@ "ckpt_file: \"\"\n", "\n", "testdata:\n", - " _target_: MedNISTDataset\n", - " root_dir: '@root_dir'\n", - " transform: '@transform'\n", - " section: test\n", - " download: false\n", - " runtime_cache: true\n", + " _target_: MedNISTDataset\n", + " root_dir: '@root_dir'\n", + " transform: '@transform'\n", + " section: test\n", + " download: false\n", + " runtime_cache: true\n", "\n", "eval_dl:\n", - " _target_: DataLoader\n", - " dataset: '$@testdata[:@max_items_to_print]'\n", - " batch_size: 1\n", - " num_workers: 0\n", + " _target_: DataLoader\n", + " dataset: '$@testdata[:@max_items_to_print]'\n", + " batch_size: 1\n", + " num_workers: 0\n", "\n", "# loads the weights from the given file (which needs to be set on the command line) then calls \"evaluate\"\n", "evaluate:\n", "- '$@net.load_state_dict(torch.load(@ckpt_file))'\n", - "- '$scripts.evaluate(@net, @eval_dl, @class_names, @device)'" + "- '$scripts.evaluate(@net, @eval_dl, @class_names, @device)'\n" ] }, { diff --git a/bundle/03_mednist_classification_v2.ipynb b/bundle/03_mednist_classification_v2.ipynb index b14a974a80..a7edbc6322 100644 --- a/bundle/03_mednist_classification_v2.ipynb +++ b/bundle/03_mednist_classification_v2.ipynb @@ -201,28 +201,28 @@ "%%writefile MedNISTClassifier_v2/configs/logging.conf\n", "\n", "[loggers]\n", - "keys = root\n", + "keys=root\n", "\n", "[handlers]\n", - "keys = consoleHandler\n", + "keys=consoleHandler\n", "\n", "[formatters]\n", - "keys = fullFormatter\n", + "keys=fullFormatter\n", "\n", "[logger_root]\n", - "level = INFO\n", - "handlers = consoleHandler\n", + "level=INFO\n", + "handlers=consoleHandler\n", "\n", "[handler_consoleHandler]\n", - "class = StreamHandler\n", + "class=StreamHandler\n", "\n", "\n", - "level = INFO\n", - "formatter = fullFormatter\n", - "args = (sys.stdout,)\n", + "level=INFO\n", + "formatter=fullFormatter\n", + "args=(sys.stdout,)\n", "\n", "[formatter_fullFormatter]\n", - "format = %(asctime)s - %(name)s - %(levelname)s - %(message)s" + "format=%(asctime)s - %(name)s - %(levelname)s - %(message)s" ] }, { @@ -267,7 +267,7 @@ "\n", "# these are added definitions\n", "bundle_root: .\n", - "ckpt_path: $@ bundle_root + '/models/model.pt'\n", + "ckpt_path: $@bundle_root + '/models/model.pt'\n", "\n", "# define a device for the network\n", "device: '$torch.device(''cuda:0'')'\n", @@ -277,10 +277,10 @@ "\n", "# define the network separately, don't need to refer to MONAI types by name or import MONAI\n", "network_def:\n", - " _target_: densenet121\n", - " spatial_dims: 2\n", - " in_channels: 1\n", - " out_channels: 6\n", + " _target_: densenet121\n", + " spatial_dims: 2\n", + " in_channels: 1\n", + " out_channels: 6\n", "\n", "# define the network to be the given definition moved to the device\n", "net: '$@network_def.to(@device)'\n", @@ -289,11 +289,11 @@ "train_transforms:\n", "- _target_: LoadImaged\n", " keys: '@image'\n", - " image_only: true\n", + " image_only: true\n", "- _target_: EnsureChannelFirstd\n", " keys: '@image'\n", "- _target_: ScaleIntensityd\n", - " keys: '@image'" + " keys: '@image'\n" ] }, { @@ -335,128 +335,128 @@ "output_dir: '$datetime.datetime.now().strftime(@root_dir+''/output/output_%y%m%d_%H%M%S'')'\n", "\n", "train_dataset:\n", - " _target_: MedNISTDataset\n", - " root_dir: '@root_dir'\n", - " transform: \n", - " _target_: Compose\n", - " transforms: '@train_transforms'\n", - " section: training\n", - " download: true\n", + " _target_: MedNISTDataset\n", + " root_dir: '@root_dir'\n", + " transform: \n", + " _target_: Compose\n", + " transforms: '@train_transforms'\n", + " section: training\n", + " download: true\n", "\n", "train_dl:\n", - " _target_: DataLoader\n", - " dataset: '@train_dataset'\n", - " batch_size: 512\n", - " shuffle: true\n", - " num_workers: 4\n", + " _target_: DataLoader\n", + " dataset: '@train_dataset'\n", + " batch_size: 512\n", + " shuffle: true\n", + " num_workers: 4\n", "\n", "# separate dataset taking from the \"validation\" section\n", "eval_dataset:\n", - " _target_: MedNISTDataset\n", - " root_dir: '@root_dir'\n", - " transform: \n", - " _target_: Compose\n", - " transforms: '$@train_transforms'\n", - " section: validation\n", - " download: true\n", + " _target_: MedNISTDataset\n", + " root_dir: '@root_dir'\n", + " transform: \n", + " _target_: Compose\n", + " transforms: '$@train_transforms'\n", + " section: validation\n", + " download: true\n", "\n", "# separate dataloader for evaluation\n", "eval_dl:\n", - " _target_: DataLoader\n", - " dataset: '@eval_dataset'\n", - " batch_size: 512\n", - " shuffle: false\n", - " num_workers: 4\n", + " _target_: DataLoader\n", + " dataset: '@eval_dataset'\n", + " batch_size: 512\n", + " shuffle: false\n", + " num_workers: 4\n", "\n", "# transforms applied to network output, in this case applying activation, argmax, and one-hot-encoding\n", "post_transform:\n", - " _target_: Compose\n", - " transforms:\n", - " - _target_: Activationsd\n", - " keys: '@pred'\n", - " softmax: true # apply softmax to the prediction to emphasize the most likely value\n", - " - _target_: AsDiscreted\n", - " keys: ['@label', '@pred']\n", - " argmax: [false, true] # apply argmax to the prediction only to get a class index number\n", - " to_onehot: 6 # convert both prediction and label to one-hot format so that both have shape (6,)\n", + " _target_: Compose\n", + " transforms:\n", + " - _target_: Activationsd\n", + " keys: '@pred'\n", + " softmax: true # apply softmax to the prediction to emphasize the most likely value\n", + " - _target_: AsDiscreted\n", + " keys: ['@label', '@pred']\n", + " argmax: [false, true] # apply argmax to the prediction only to get a class index number\n", + " to_onehot: 6 # convert both prediction and label to one-hot format so that both have shape (6,)\n", "\n", "# separating out loss, inferer, and optimizer definitions\n", "\n", "loss_function:\n", - " _target_: torch.nn.CrossEntropyLoss\n", + " _target_: torch.nn.CrossEntropyLoss\n", "\n", "inferer: \n", - " _target_: SimpleInferer\n", + " _target_: SimpleInferer\n", "\n", "optimizer: \n", - " _target_: torch.optim.Adam\n", - " params: '$@net.parameters()'\n", - " lr: '@learning_rate'\n", + " _target_: torch.optim.Adam\n", + " params: '$@net.parameters()'\n", + " lr: '@learning_rate'\n", "\n", "# Handlers to load the checkpoint if present, run validation at the chosen interval, save the checkpoint\n", "# at the chosen interval, log stats, and write the log to a file in the output directory.\n", "handlers:\n", "- _target_: CheckpointLoader\n", " _disabled_: '$not os.path.exists(@ckpt_path)'\n", - " load_path: '@ckpt_path'\n", - " load_dict:\n", - " model: '@net'\n", + " load_path: '@ckpt_path'\n", + " load_dict:\n", + " model: '@net'\n", "- _target_: ValidationHandler\n", " validator: '@evaluator'\n", - " epoch_level: true\n", - " interval: '@val_interval'\n", + " epoch_level: true\n", + " interval: '@val_interval'\n", "- _target_: CheckpointSaver\n", " save_dir: '@output_dir'\n", - " save_dict:\n", - " model: '@net'\n", - " save_interval: '@save_interval'\n", - " save_final: true # save the final weights, either when the run finishes or is interrupted somehow\n", + " save_dict:\n", + " model: '@net'\n", + " save_interval: '@save_interval'\n", + " save_final: true # save the final weights, either when the run finishes or is interrupted somehow\n", "- _target_: StatsHandler\n", " name: train_loss\n", - " tag_name: train_loss\n", - " output_transform: '$monai.handlers.from_engine([''loss''], first=True)' # print per-iteration loss\n", + " tag_name: train_loss\n", + " output_transform: '$monai.handlers.from_engine([''loss''], first=True)' # print per-iteration loss\n", "- _target_: LogfileHandler\n", " output_dir: '@output_dir'\n", "\n", "trainer:\n", - " _target_: SupervisedTrainer\n", - " device: '@device'\n", - " max_epochs: '@max_epochs'\n", - " train_data_loader: '@train_dl'\n", - " network: '@net'\n", - " optimizer: '@optimizer'\n", - " loss_function: '@loss_function'\n", - " inferer: '@inferer'\n", - " train_handlers: '@handlers'\n", + " _target_: SupervisedTrainer\n", + " device: '@device'\n", + " max_epochs: '@max_epochs'\n", + " train_data_loader: '@train_dl'\n", + " network: '@net'\n", + " optimizer: '@optimizer'\n", + " loss_function: '@loss_function'\n", + " inferer: '@inferer'\n", + " train_handlers: '@handlers'\n", "\n", "# validation handlers which log stats and direct the log to a file\n", "val_handlers:\n", "- _target_: StatsHandler\n", " name: val_stats\n", - " output_transform: '$lambda x: None'\n", + " output_transform: '$lambda x: None'\n", "- _target_: LogfileHandler\n", " output_dir: '@output_dir'\n", "\n", "# Metrics to assess validation results, you can have more than one here but may \n", "# need to adapt the format of pred and label.\n", "metrics:\n", - " accuracy:\n", - " _target_: 'ignite.metrics.Accuracy'\n", - " output_transform: '$monai.handlers.from_engine([@pred, @label])'\n", + " accuracy:\n", + " _target_: 'ignite.metrics.Accuracy'\n", + " output_transform: '$monai.handlers.from_engine([@pred, @label])'\n", "\n", "# runs the evaluation process, invoked by trainer via the ValidationHandler object\n", "evaluator:\n", - " _target_: SupervisedEvaluator\n", - " device: '@device'\n", - " val_data_loader: '@eval_dl'\n", - " network: '@net'\n", - " inferer: '@inferer'\n", - " postprocessing: '@post_transform'\n", - " key_val_metric: '@metrics'\n", - " val_handlers: '@val_handlers'\n", + " _target_: SupervisedEvaluator\n", + " device: '@device'\n", + " val_data_loader: '@eval_dl'\n", + " network: '@net'\n", + " inferer: '@inferer'\n", + " postprocessing: '@post_transform'\n", + " key_val_metric: '@metrics'\n", + " val_handlers: '@val_handlers'\n", "\n", "train:\n", - "- '$@trainer.run()'" + "- '$@trainer.run()'\n" ] }, { @@ -651,51 +651,51 @@ "input_files: '$[{@image: f} for f in sorted(glob.glob(@input_dir+''/*.*''))]'\n", "\n", "infer_dataset:\n", - " _target_: Dataset\n", - " data: '@input_files'\n", - " transform: \n", - " _target_: Compose\n", - " transforms: '@train_transforms'\n", + " _target_: Dataset\n", + " data: '@input_files'\n", + " transform: \n", + " _target_: Compose\n", + " transforms: '@train_transforms'\n", "\n", "infer_dl:\n", - " _target_: DataLoader\n", - " dataset: '@infer_dataset'\n", - " batch_size: 1\n", - " shuffle: false\n", - " num_workers: 0\n", + " _target_: DataLoader\n", + " dataset: '@infer_dataset'\n", + " batch_size: 1\n", + " shuffle: false\n", + " num_workers: 0\n", "\n", "# transforms applied to network output, same as those in training except \"label\" isn't present\n", "post_transform:\n", - " _target_: Compose\n", - " transforms:\n", - " - _target_: Activationsd\n", - " keys: '@pred'\n", - " softmax: true \n", - " - _target_: AsDiscreted\n", - " keys: ['@pred']\n", - " argmax: true \n", + " _target_: Compose\n", + " transforms:\n", + " - _target_: Activationsd\n", + " keys: '@pred'\n", + " softmax: true \n", + " - _target_: AsDiscreted\n", + " keys: ['@pred']\n", + " argmax: true \n", "\n", "# handlers to load the checkpoint file (and fail if a file isn't found), and save classification results to a csv file\n", "handlers:\n", "- _target_: CheckpointLoader\n", " load_path: '@ckpt_path'\n", - " load_dict:\n", - " model: '@net'\n", + " load_dict:\n", + " model: '@net'\n", "- _target_: ClassificationSaver\n", " batch_transform: '$lambda batch: batch[0][@image].meta'\n", - " output_transform: '$monai.handlers.from_engine([''pred''])'\n", + " output_transform: '$monai.handlers.from_engine([''pred''])'\n", "\n", "inferer: \n", - " _target_: SimpleInferer\n", + " _target_: SimpleInferer\n", "\n", "evaluator:\n", - " _target_: SupervisedEvaluator\n", - " device: '@device'\n", - " val_data_loader: '@infer_dl'\n", - " network: '@net'\n", - " inferer: '@inferer'\n", - " postprocessing: '@post_transform'\n", - " val_handlers: '@handlers'\n", + " _target_: SupervisedEvaluator\n", + " device: '@device'\n", + " val_data_loader: '@infer_dl'\n", + " network: '@net'\n", + " inferer: '@inferer'\n", + " postprocessing: '@post_transform'\n", + " val_handlers: '@handlers'\n", "\n", "inference:\n", "- '$@evaluator.run()'" diff --git a/bundle/04_integrating_code.ipynb b/bundle/04_integrating_code.ipynb index 34000ce130..7815720017 100644 --- a/bundle/04_integrating_code.ipynb +++ b/bundle/04_integrating_code.ipynb @@ -221,6 +221,9 @@ "source": [ "%%writefile IntegrationBundle/scripts/net.py\n", "\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", "\n", "class Net(nn.Module):\n", " def __init__(self):\n", @@ -269,6 +272,7 @@ "source": [ "%%writefile IntegrationBundle/scripts/transforms.py\n", "\n", + "import torchvision.transforms as transforms\n", "\n", "transform = transforms.Compose(\n", " [transforms.ToTensor(),\n", @@ -295,6 +299,8 @@ "source": [ "%%writefile IntegrationBundle/scripts/dataloaders.py\n", "\n", + "import torch\n", + "import torchvision\n", "\n", "batch_size = 4\n", "\n", @@ -342,6 +348,8 @@ "source": [ "%%writefile IntegrationBundle/scripts/train.py\n", "\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", "\n", "def train(net, trainloader):\n", " criterion = nn.CrossEntropyLoss()\n", @@ -365,7 +373,7 @@ " print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')\n", " running_loss = 0.0\n", "\n", - " print('Finished Training')" + " print('Finished Training')\n" ] }, { @@ -399,6 +407,7 @@ ], "source": [ "%%writefile IntegrationBundle/configs/train.yaml\n", + "\n", "imports:\n", "- $import torch\n", "- $import scripts\n", @@ -408,15 +417,15 @@ "- $import scripts.dataloaders\n", "\n", "net:\n", - " _target_: scripts.net.Net\n", + " _target_: scripts.net.Net\n", "\n", "transforms: '$scripts.transforms.transform'\n", "\n", "dataloader: '$scripts.dataloaders.get_dataloader(True, @transforms)'\n", "\n", "train:\n", - "- $scripts.train.train(@ net, @ dataloader)\n", - "- $torch.save( @ net.state_dict(), './cifar_net.pth')" + "- $scripts.train.train(@net, @dataloader)\n", + "- $torch.save(@net.state_dict(), './cifar_net.pth')\n" ] }, { @@ -535,6 +544,7 @@ "source": [ "%%writefile IntegrationBundle/scripts/test.py\n", "\n", + "import torch\n", "\n", "def test(net, testloader):\n", " correct = 0\n", @@ -548,7 +558,7 @@ " total += labels.size(0)\n", " correct += (predicted == labels).sum().item()\n", "\n", - " print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')" + " print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')\n" ] }, { @@ -569,6 +579,7 @@ ], "source": [ "%%writefile IntegrationBundle/configs/test.yaml\n", + "\n", "imports:\n", "- $import torch\n", "- $import scripts\n", @@ -577,15 +588,15 @@ "- $import scripts.dataloaders\n", "\n", "net:\n", - " _target_: scripts.net.Net\n", + " _target_: scripts.net.Net\n", "\n", "transforms: '$scripts.transforms.transform'\n", "\n", "dataloader: '$scripts.dataloaders.get_dataloader(False, @transforms)'\n", "\n", "test:\n", - "- $@ net.load_state_dict(torch.load('./cifar_net.pth'))\n", - "- $scripts.test.test(@ net, @ dataloader)" + "- $@net.load_state_dict(torch.load('./cifar_net.pth'))\n", + "- $scripts.test.test(@net, @dataloader)\n" ] }, { @@ -668,13 +679,15 @@ "source": [ "%%writefile IntegrationBundle/scripts/inference.py\n", "\n", + "import torch\n", + "from PIL import Image\n", "\n", "def inference(net, transforms, filenames):\n", " for fn in filenames:\n", " with Image.open(fn) as im:\n", - " tim = transforms(im)\n", - " outputs = net(tim[None])\n", - " _, predictions = torch.max(outputs, 1)\n", + " tim=transforms(im)\n", + " outputs=net(tim[None])\n", + " _, predictions=torch.max(outputs, 1)\n", " print(fn, predictions[0].item())" ] }, @@ -710,13 +723,13 @@ "input_files: '$sorted(glob.glob(@input_dir+''/*.*''))'\n", "\n", "net:\n", - " _target_: scripts.net.Net\n", + " _target_: scripts.net.Net\n", "\n", "transforms: '$scripts.transforms.transform'\n", "\n", "inference:\n", - "- $@ net.load_state_dict(torch.load('./cifar_net.pth'))\n", - "- $scripts.inference.inference( @ net, @ transforms, @ input_files)" + "- $@net.load_state_dict(torch.load('./cifar_net.pth'))\n", + "- $scripts.inference.inference(@net, @transforms, @input_files)" ] }, { diff --git a/bundle/05_spleen_segmentation_lightning.ipynb b/bundle/05_spleen_segmentation_lightning.ipynb index eb49934f0a..de2adc0aeb 100644 --- a/bundle/05_spleen_segmentation_lightning.ipynb +++ b/bundle/05_spleen_segmentation_lightning.ipynb @@ -438,6 +438,21 @@ "source": [ "%%writefile SpleenSegLightning/scripts/model.py\n", "\n", + "import pytorch_lightning\n", + "from monai.utils import set_determinism\n", + "from monai.transforms import (\n", + " AsDiscrete,\n", + " Compose,\n", + " EnsureType,\n", + ")\n", + "from monai.networks.nets import UNet\n", + "from monai.networks.layers import Norm\n", + "from monai.metrics import DiceMetric\n", + "from monai.losses import DiceLoss\n", + "from monai.inferers import sliding_window_inference\n", + "from monai.data import decollate_batch\n", + "import torch\n", + "\n", "\n", "class MySegNet(pytorch_lightning.LightningModule):\n", " def __init__(self):\n", @@ -556,6 +571,8 @@ "source": [ "%%writefile SpleenSegLightning/scripts/main.py\n", "\n", + "from scripts.model import MySegNet\n", + "import pytorch_lightning\n", "\n", "def train(lightninig_param, train_dl, val_dl):\n", " net = MySegNet()\n", @@ -624,44 +641,44 @@ "\n", "# define hyperparameters for the lightning trainer\n", "max_epochs: 50\n", - "default_root_dir: $@ bundle_dir + \"/logs\"\n", + "default_root_dir: $@bundle_dir+\"/logs\"\n", "check_val_every_n_epoch: 1\n", "\n", - "lightninig_param: '${\n", - " ''max_epochs'': @ max_epochs,\n", - " ''default_root_dir'': @ default_root_dir,\n", - " ''check_val_every_n_epoch'': @ check_val_every_n_epoch,\n", + "lightninig_param: '${\n", + " ''max_epochs'': @max_epochs,\n", + " ''default_root_dir'': @default_root_dir,\n", + " ''check_val_every_n_epoch'': @check_val_every_n_epoch,\n", "}'\n", "\n", "\n", "# define a transform sequence by instantiating a Compose instance with a transform sequence\n", "train_transform:\n", - " _target_: Compose\n", - " transforms:\n", - " - _target_: LoadImaged\n", - " keys: ['@image', '@label']\n", + " _target_: Compose\n", + " transforms:\n", + " - _target_: LoadImaged\n", + " keys: ['@image','@label']\n", " image_only: true\n", - " - _target_: EnsureChannelFirstd\n", - " keys: ['@image', '@label']\n", - " - _target_: Orientationd\n", - " keys: ['@image', '@label']\n", + " - _target_: EnsureChannelFirstd\n", + " keys: ['@image','@label']\n", + " - _target_: Orientationd\n", + " keys: ['@image','@label']\n", " axcodes: 'RAS'\n", - " - _target_: Spacingd\n", - " keys: ['@image', '@label']\n", + " - _target_: Spacingd\n", + " keys: ['@image','@label']\n", " pixdim: [1.5, 1.5, 2.0]\n", - " - _target_: ScaleIntensityRanged\n", + " - _target_: ScaleIntensityRanged\n", " keys: '@image'\n", " a_min: -57\n", " a_max: 164\n", " b_min: 0.0\n", " b_max: 1.0\n", " clip: True\n", - " - _target_: CropForegroundd\n", - " keys: ['@image', '@label']\n", + " - _target_: CropForegroundd\n", + " keys: ['@image','@label']\n", " allow_smaller: False\n", " source_key: '@image'\n", - " - _target_: RandCropByPosNegLabeld\n", - " keys: ['@image', '@label']\n", + " - _target_: RandCropByPosNegLabeld\n", + " keys: ['@image','@label']\n", " label_key: '@label'\n", " spatial_size: [96, 96, 96]\n", " pos: 1\n", @@ -671,58 +688,58 @@ " image_threshold: 0\n", "\n", "val_transform:\n", - " _target_: Compose\n", - " transforms:\n", - " - _target_: LoadImaged\n", - " keys: ['@image', '@label']\n", + " _target_: Compose\n", + " transforms:\n", + " - _target_: LoadImaged\n", + " keys: ['@image','@label']\n", " image_only: true\n", - " - _target_: EnsureChannelFirstd\n", - " keys: ['@image', '@label']\n", - " - _target_: Orientationd\n", - " keys: ['@image', '@label']\n", + " - _target_: EnsureChannelFirstd\n", + " keys: ['@image','@label']\n", + " - _target_: Orientationd\n", + " keys: ['@image','@label']\n", " axcodes: 'RAS'\n", - " - _target_: Spacingd\n", - " keys: ['@image', '@label']\n", + " - _target_: Spacingd\n", + " keys: ['@image','@label']\n", " pixdim: [1.5, 1.5, 2.0]\n", - " - _target_: ScaleIntensityRanged\n", + " - _target_: ScaleIntensityRanged\n", " keys: '@image'\n", " a_min: -57\n", " a_max: 164\n", " b_min: 0.0\n", " b_max: 1.0\n", " clip: True\n", - " - _target_: CropForegroundd\n", - " keys: ['@image', '@label']\n", + " - _target_: CropForegroundd\n", + " keys: ['@image','@label']\n", " source_key: '@image'\n", " allow_smaller: False\n", "\n", "val_dataset:\n", - " _target_: CacheDataset\n", - " data: '@val_files'\n", - " transform: '@val_transform'\n", - " cache_rate: 1.0\n", - " num_workers: 4\n", + " _target_: CacheDataset\n", + " data: '@val_files'\n", + " transform: '@val_transform'\n", + " cache_rate: 1.0\n", + " num_workers: 4\n", "\n", "train_dataset:\n", - " _target_: CacheDataset\n", - " data: '@train_files'\n", - " transform: '@train_transform'\n", - " cache_rate: 1.0\n", - " num_workers: 4\n", - "\n", + " _target_: CacheDataset\n", + " data: '@train_files'\n", + " transform: '@train_transform'\n", + " cache_rate: 1.0\n", + " num_workers: 4\n", + " \n", "train_dl:\n", - " _target_: DataLoader\n", - " dataset: '@train_dataset'\n", - " batch_size: 1\n", - " shuffle: true\n", - " num_workers: 4\n", - "\n", + " _target_: DataLoader\n", + " dataset: '@train_dataset'\n", + " batch_size: 1\n", + " shuffle: true\n", + " num_workers: 4\n", + " \n", "val_dl:\n", - " _target_: DataLoader\n", - " dataset: '@val_dataset'\n", - " batch_size: 1\n", - " shuffle: false\n", - " num_workers: 4\n", + " _target_: DataLoader\n", + " dataset: '@val_dataset'\n", + " batch_size: 1\n", + " shuffle: false\n", + " num_workers: 4\n", "\n", "train:\n", "- '$train(@lightninig_param, @train_dl, @val_dl)'" @@ -972,51 +989,51 @@ "ckpt_file: \"\"\n", "\n", "# define hyperparameters for the lightning trainer\n", - "default_root_dir: $@ bundle_dir + \"/logs\"\n", - "lightninig_param: '${''default_root_dir'': @default_root_dir,}'\n", + "default_root_dir: $@bundle_dir+\"/logs\"\n", + "lightninig_param: '${''default_root_dir'': @default_root_dir,}'\n", "\n", "\n", "val_transform:\n", - " _target_: Compose\n", - " transforms:\n", - " - _target_: LoadImaged\n", - " keys: ['@image', '@label']\n", + " _target_: Compose\n", + " transforms:\n", + " - _target_: LoadImaged\n", + " keys: ['@image','@label']\n", " image_only: true\n", - " - _target_: EnsureChannelFirstd\n", - " keys: ['@image', '@label']\n", - " - _target_: Orientationd\n", - " keys: ['@image', '@label']\n", + " - _target_: EnsureChannelFirstd\n", + " keys: ['@image','@label']\n", + " - _target_: Orientationd\n", + " keys: ['@image','@label']\n", " axcodes: 'RAS'\n", - " - _target_: Spacingd\n", - " keys: ['@image', '@label']\n", + " - _target_: Spacingd\n", + " keys: ['@image','@label']\n", " pixdim: [1.5, 1.5, 2.0]\n", - " - _target_: ScaleIntensityRanged\n", + " - _target_: ScaleIntensityRanged\n", " keys: '@image'\n", " a_min: -57\n", " a_max: 164\n", " b_min: 0.0\n", " b_max: 1.0\n", " clip: True\n", - " - _target_: CropForegroundd\n", - " keys: ['@image', '@label']\n", + " - _target_: CropForegroundd\n", + " keys: ['@image','@label']\n", " source_key: '@image'\n", " allow_smaller: False\n", "\n", "val_dataset:\n", - " _target_: CacheDataset\n", - " data: '@val_files'\n", - " transform: '@val_transform'\n", - " cache_rate: 1.0\n", - " num_workers: 4\n", - "\n", + " _target_: CacheDataset\n", + " data: '@val_files'\n", + " transform: '@val_transform'\n", + " cache_rate: 1.0\n", + " num_workers: 4\n", + " \n", "val_dl:\n", - " _target_: DataLoader\n", - " dataset: '@val_dataset'\n", - " batch_size: 1\n", - " shuffle: false\n", - " num_workers: 4\n", - "\n", + " _target_: DataLoader\n", + " dataset: '@val_dataset'\n", + " batch_size: 1\n", + " shuffle: false\n", + " num_workers: 4\n", "\n", + " \n", "# loads the weights from the given file (which needs to be set on the command line) then calls \"evaluate\" script\n", "evaluate:\n", "- '$evaluate(@lightninig_param,@ckpt_file, @val_dl)'" @@ -1170,7 +1187,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.18" + "version": "3.10.12" } }, "nbformat": 4,