Skip to content

Commit

Permalink
refactor(docs): simplify resume.py
Browse files Browse the repository at this point in the history
  • Loading branch information
adosar committed Dec 29, 2024
1 parent 5942280 commit 786a641
Showing 1 changed file with 8 additions and 57 deletions.
65 changes: 8 additions & 57 deletions docs/source/examples/resume.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,67 +15,15 @@
# * Training was performed with :doc:`AIdsorb CLI <../cli>` or :ref:`AIdsorb +
# PyTorch Lightning <aidsorb_with_pytorch_and_lightning>`.

# Third-party imports.
import yaml
import torch
import lightning as L
from lightning.pytorch.cli import LightningArgumentParser

# Project imports (was imported twice: PCDDataModule appeared on two lines).
from aidsorb.datamodules import PCDDataModule
from aidsorb.litmodels import PCDLit

# %%
# The following snippet lets us instantiate:
#
# * :class:`~lightning.pytorch.trainer.trainer.Trainer`
# * :class:`~lightning.pytorch.core.LightningModule`
# * :class:`~lightning.pytorch.core.LightningDataModule`
#
# with the same settings as in the ``.yaml`` configuration file. For more
# information 👉 `here
# <https://github.com/Lightning-AI/pytorch-lightning/discussions/10363#discussioncomment-2326235>`_.

# %%
# .. note::
# You are responsible for restoring the model's state (the weights of the model).

# Parse the Lightning CLI configuration used during training.
config_file = 'path/to/config.yaml'
with open(config_file) as f:
    config_dict = yaml.safe_load(f)

# Build a parser that mirrors the structure of the Lightning CLI config file.
parser = LightningArgumentParser()
parser.add_lightning_class_args(PCDLit, 'model')
parser.add_lightning_class_args(PCDDataModule, 'data')
parser.add_class_arguments(L.Trainer, 'trainer')

# Any other key present in the config file must also be added.
# parser.add_argument('--<keyname>', ...)
# For more information 👉 https://jsonargparse.readthedocs.io/en/stable/#parsers
for extra_key in ('--seed_everything', '--ckpt_path', '--optimizer', '--lr_scheduler'):
    parser.add_argument(extra_key)

# Parse the config and instantiate trainer, model and datamodule from it.
config = parser.parse_object(config_dict)
objects = parser.instantiate_classes(config)

# %%

# Unpack the instantiated objects by their config-file section names.
trainer = objects.trainer
litmodel = objects.model
dm = objects.data

# %% The remaining part is to restore the model's state, i.e. load back the trained weights.

# %%
# Restoring model's state
# -----------------------

# Load the checkpoint.
# NOTE(review): ``torch.load`` unpickles the file — only load checkpoints you
# trust. Consider ``weights_only=True`` for untrusted sources; confirm it is
# compatible with Lightning checkpoint contents.
ckpt = torch.load('path/to/checkpoints/checkpoint.ckpt')

# %%

# Load back the weights (Lightning stores them under the 'state_dict' key).
litmodel.load_state_dict(ckpt['state_dict'])

# %%
# Restore lightning modules from checkpoint.
# Both the model and the datamodule are rebuilt from the same checkpoint file.
checkpoint = 'path/to/checkpoint.ckpt'
litmodel = PCDLit.load_from_checkpoint(checkpoint)
dm = PCDDataModule.load_from_checkpoint(checkpoint)

# Set the model for inference (disable grads & enable eval mode).
litmodel.freeze()
Expand All @@ -89,6 +37,9 @@
# Measure performance
# -------------------

# Instantiate a trainer object.
trainer = L.Trainer(...)

# Measure performance on test set.
trainer.test(litmodel, dm)

Expand Down

0 comments on commit 786a641

Please sign in to comment.