Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Feb 26, 2024
1 parent 24f03f9 commit 5112ffa
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
7 changes: 2 additions & 5 deletions 2d_classification/monai_101.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -264,10 +264,7 @@
" network=model,\n",
" inferer=SimpleInferer(),\n",
" key_val_metric={\"val_acc\": ignite.metrics.Accuracy(from_engine([\"pred\", \"label\"]))},\n",
" val_handlers=[\n",
" StatsHandler(iteration_log=False),\n",
" TensorBoardStatsHandler(iteration_log=False)\n",
" ],\n",
" val_handlers=[StatsHandler(iteration_log=False), TensorBoardStatsHandler(iteration_log=False)],\n",
")\n",
"trainer = SupervisedTrainer(\n",
" device=torch.device(\"cuda:0\"),\n",
Expand All @@ -280,7 +277,7 @@
" train_handlers=[\n",
" ValidationHandler(validator=evaluator, epoch_level=True, interval=1),\n",
" StatsHandler(),\n",
" TensorBoardStatsHandler(tag_name=\"train_loss\", output_transform=from_engine([\"loss\"], first=True))\n",
" TensorBoardStatsHandler(tag_name=\"train_loss\", output_transform=from_engine([\"loss\"], first=True)),\n",
" ],\n",
")"
]
Expand Down
23 changes: 13 additions & 10 deletions 2d_classification/monai_201.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,14 @@
"from monai.config import print_config\n",
"from monai.data import DataLoader\n",
"from monai.engines import SupervisedTrainer, SupervisedEvaluator\n",
"from monai.handlers import StatsHandler, TensorBoardStatsHandler, ValidationHandler, CheckpointSaver, CheckpointLoader, ClassificationSaver\n",
"from monai.handlers import (\n",
" StatsHandler,\n",
" TensorBoardStatsHandler,\n",
" ValidationHandler,\n",
" CheckpointSaver,\n",
" CheckpointLoader,\n",
" ClassificationSaver,\n",
")\n",
"from monai.handlers.utils import from_engine\n",
"from monai.inferers import SimpleInferer\n",
"from monai.networks import eval_mode\n",
Expand Down Expand Up @@ -179,7 +186,7 @@
"source": [
"max_epochs = 5\n",
"save_interval = 2\n",
"out_dir = './eval'\n",
"out_dir = \"./eval\"\n",
"model = densenet121(spatial_dims=2, in_channels=1, out_channels=6).to(\"cuda:0\")\n",
"\n",
"logging.basicConfig(stream=sys.stdout, level=logging.INFO)\n",
Expand All @@ -190,10 +197,7 @@
" network=model,\n",
" inferer=SimpleInferer(),\n",
" key_val_metric={\"val_acc\": ignite.metrics.Accuracy(from_engine([\"pred\", \"label\"]))},\n",
" val_handlers=[\n",
" StatsHandler(iteration_log=False),\n",
" TensorBoardStatsHandler(iteration_log=False)\n",
" ],\n",
" val_handlers=[StatsHandler(iteration_log=False), TensorBoardStatsHandler(iteration_log=False)],\n",
")\n",
"\n",
"trainer = SupervisedTrainer(\n",
Expand All @@ -214,7 +218,7 @@
" final_filename=\"checkpoint.pt\",\n",
" ),\n",
" StatsHandler(),\n",
" TensorBoardStatsHandler(tag_name=\"train_loss\", output_transform=from_engine([\"loss\"], first=True))\n",
" TensorBoardStatsHandler(tag_name=\"train_loss\", output_transform=from_engine([\"loss\"], first=True)),\n",
" ],\n",
")"
]
Expand Down Expand Up @@ -290,9 +294,8 @@
" val_handlers=[\n",
" CheckpointLoader(load_path=f\"{out_dir}/checkpoint.pt\", load_dict={\"model\": model}),\n",
" ClassificationSaver(\n",
" batch_transform=lambda batch: batch[0][\"image\"].meta,\n",
" output_transform=from_engine(['pred'])\n",
" )\n",
" batch_transform=lambda batch: batch[0][\"image\"].meta, output_transform=from_engine([\"pred\"])\n",
" ),\n",
" ],\n",
")\n",
"\n",
Expand Down

0 comments on commit 5112ffa

Please sign in to comment.