From 786a641e07cf22961dbd0edbf4a5ac40beb94d72 Mon Sep 17 00:00:00 2001 From: Antonios Sarikas Date: Sun, 29 Dec 2024 22:41:48 +0200 Subject: [PATCH] refactor(docs): simplify `resume.py` --- docs/source/examples/resume.py | 65 +++++----------------------------- 1 file changed, 8 insertions(+), 57 deletions(-) diff --git a/docs/source/examples/resume.py b/docs/source/examples/resume.py index daa6a6f..3d3ed48 100644 --- a/docs/source/examples/resume.py +++ b/docs/source/examples/resume.py @@ -15,67 +15,15 @@ # * Training was performed with :doc:`AIdsorb CLI <../cli>` or :ref:`AIdsorb + # PyTorch Lightning `. -import yaml import torch import lightning as L -from lightning.pytorch.cli import LightningArgumentParser -from aidsorb.datamodules import PCDDataModule from aidsorb.litmodels import PCDLit +from aidsorb.datamodules import PCDDataModule -# %% -# The following snipper let us instantiate: -# -# * :class:`~lightning.pytorch.trainer.trainer.Trainer` -# * :class:`~lightning.pytorch.core.LightningModule` -# * :class:`~lightning.pytorch.core.LightningDataModule` -# -# with the same settings as in the ``.yaml`` configuration file. For more -# information 👉 `here -# `_. - -# %% -# .. note:: -# You are responsible for restoring the model's state (the weights of the model). - -filename = 'path/to/config.yaml' -with open(filename, 'r') as f: - config_dict = yaml.safe_load(f) - -parser = LightningArgumentParser() -parser.add_lightning_class_args(PCDLit, 'model') -parser.add_lightning_class_args(PCDDataModule, 'data') -parser.add_class_arguments(L.Trainer, 'trainer') - -# Any other key present in the config file must also be added. -# parser.add_argument(--, ...) -# For more information 👉 https://jsonargparse.readthedocs.io/en/stable/#parsers -parser.add_argument('--seed_everything') -parser.add_argument('--ckpt_path') -parser.add_argument('--optimizer') -parser.add_argument('--lr_scheduler') - -config = parser.parse_object(config_dict) -objects = parser.instantiate_classes(config) - -# %% - -trainer, litmodel, dm = objects.trainer, objects.model, objects.data - -# %% The remaining part is to restore the model's state, i.e. load back the trained weights. - -# %% -# Restoring model's state -# ----------------------- - -# Load the the checkpoint. -ckpt = torch.load('path/to/checkpoints/checkpoint.ckpt') - -# %% - -# Load back the weights. -litmodel.load_state_dict(ckpt['state_dict']) - -# %% +# Restore lightning modules from checkpoint. +ckpt_path = 'path/to/checkpoint.ckpt' +litmodel = PCDLit.load_from_checkpoint(ckpt_path) +dm = PCDDataModule.load_from_checkpoint(ckpt_path) # Set the model for inference (disable grads & enable eval mode). litmodel.freeze() @@ -89,6 +37,9 @@ # Measure performance # ------------------- +# Instantiate a trainer object. +trainer = L.Trainer(...) + # Measure performance on test set. trainer.test(litmodel, dm)