diff --git a/docs/source/examples/resume.py b/docs/source/examples/resume.py index daa6a6f..3d3ed48 100644 --- a/docs/source/examples/resume.py +++ b/docs/source/examples/resume.py @@ -15,67 +15,15 @@ # * Training was performed with :doc:`AIdsorb CLI <../cli>` or :ref:`AIdsorb + # PyTorch Lightning `. -import yaml import torch import lightning as L -from lightning.pytorch.cli import LightningArgumentParser -from aidsorb.datamodules import PCDDataModule from aidsorb.litmodels import PCDLit +from aidsorb.datamodules import PCDDataModule -# %% -# The following snipper let us instantiate: -# -# * :class:`~lightning.pytorch.trainer.trainer.Trainer` -# * :class:`~lightning.pytorch.core.LightningModule` -# * :class:`~lightning.pytorch.core.LightningDataModule` -# -# with the same settings as in the ``.yaml`` configuration file. For more -# information 👉 `here -# `_. - -# %% -# .. note:: -# You are responsible for restoring the model's state (the weights of the model). - -filename = 'path/to/config.yaml' -with open(filename, 'r') as f: - config_dict = yaml.safe_load(f) - -parser = LightningArgumentParser() -parser.add_lightning_class_args(PCDLit, 'model') -parser.add_lightning_class_args(PCDDataModule, 'data') -parser.add_class_arguments(L.Trainer, 'trainer') - -# Any other key present in the config file must also be added. -# parser.add_argument(--, ...) -# For more information 👉 https://jsonargparse.readthedocs.io/en/stable/#parsers -parser.add_argument('--seed_everything') -parser.add_argument('--ckpt_path') -parser.add_argument('--optimizer') -parser.add_argument('--lr_scheduler') - -config = parser.parse_object(config_dict) -objects = parser.instantiate_classes(config) - -# %% - -trainer, litmodel, dm = objects.trainer, objects.model, objects.data - -# %% The remaining part is to restore the model's state, i.e. load back the trained weights. - -# %% -# Restoring model's state -# ----------------------- - -# Load the the checkpoint. 
-ckpt = torch.load('path/to/checkpoints/checkpoint.ckpt') - -# %% - -# Load back the weights. -litmodel.load_state_dict(ckpt['state_dict']) - -# %% +# Restore lightning modules from checkpoint. +ckpt_path = 'path/to/checkpoint.ckpt' +litmodel = PCDLit.load_from_checkpoint(ckpt_path) +dm = PCDDataModule.load_from_checkpoint(ckpt_path) # Set the model for inference (disable grads & enable eval mode). litmodel.freeze() @@ -89,6 +37,9 @@ # Measure performance # ------------------- +# Instantiate a trainer object (replace ``...`` with your desired Trainer arguments). +trainer = L.Trainer(...) + # Measure performance on test set. trainer.test(litmodel, dm)