Add best_k_metrics parameter to the ModelCheckpoint #20457

Open · wants to merge 3 commits into base: master
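For context, a minimal usage sketch of what this PR proposes: alongside the existing `best_model_score`, the callback would expose a `best_model_metrics` dict holding everything that was logged when the best checkpoint was saved. The monitored metric name and the commented-out fit call are illustrative placeholders:

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint_cb = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=3)
trainer = Trainer(max_epochs=10, callbacks=[checkpoint_cb])
# trainer.fit(model, datamodule=datamodule)  # model/datamodule are placeholders

print(checkpoint_cb.best_model_path)     # existing: path of the best checkpoint
print(checkpoint_cb.best_model_score)    # existing: monitored value at that point
print(checkpoint_cb.best_model_metrics)  # proposed: all metrics logged at that point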
45 changes: 31 additions & 14 deletions src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -27,22 +27,29 @@
 from copy import deepcopy
 from datetime import timedelta
 from pathlib import Path
-from typing import Any, Literal, Optional, Union
+from typing import Any, Literal, Optional, Union, cast
 from weakref import proxy

+import pytorch_lightning as pl
 import torch
 import yaml
+from lightning_fabric.utilities.cloud_io import (
+    _is_dir,
+    _is_local_file_protocol,
+    get_filesystem,
+)

[Collaborator comment] we probably have different formatting options, please run `pre-commit run --all-files` from the base directory to make sure the PR conforms.

+from lightning_fabric.utilities.types import _PATH
+from pytorch_lightning.callbacks import Checkpoint

[Collaborator comment] change to `from lightning.pytorch.callbacks ...`, and same for the following lines.

+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.rank_zero import (
+    WarningCache,
+    rank_zero_info,
+    rank_zero_warn,
+)
+from pytorch_lightning.utilities.types import STEP_OUTPUT
 from torch import Tensor
 from typing_extensions import override

-import lightning.pytorch as pl
-from lightning.fabric.utilities.cloud_io import _is_dir, _is_local_file_protocol, get_filesystem
-from lightning.fabric.utilities.types import _PATH
-from lightning.pytorch.callbacks import Checkpoint
-from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_info, rank_zero_warn
-from lightning.pytorch.utilities.types import STEP_OUTPUT

 log = logging.getLogger(__name__)
 warning_cache = WarningCache()

@@ -241,9 +248,10 @@ def __init__(
         self._last_global_step_saved = 0  # no need to save when no steps were taken
         self._last_time_checked: Optional[float] = None
         self.current_score: Optional[Tensor] = None
-        self.best_k_models: dict[str, Tensor] = {}
+        self.best_k_models: dict[str, dict[str, Union[Tensor, dict[str, Tensor]]]] = {}
         self.kth_best_model_path = ""
         self.best_model_score: Optional[Tensor] = None
+        self.best_model_metrics: Optional[dict[str, Tensor]] = None
         self.best_model_path = ""
         self.last_model_path = ""
         self._last_checkpoint_saved = ""
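With this change each `best_k_models` entry carries both the monitored score and the metrics dict that was logged at save time, instead of a bare tensor. An illustrative sketch of the new value shape (the path and numbers are made up):

import torch

# Illustrative contents of self.best_k_models after this PR:
best_k_models = {
    "epoch=3-step=40.ckpt": {
        "score": torch.tensor(0.2310),  # monitored value, used for ranking
        "metrics": {"val_loss": torch.tensor(0.2310), "epoch": torch.tensor(3)},  # all monitor candidates
    },
}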
@@ -339,6 +347,7 @@ def state_dict(self) -> dict[str, Any]:
         return {
             "monitor": self.monitor,
             "best_model_score": self.best_model_score,
+            "best_model_metrics": self.best_model_metrics,
             "best_model_path": self.best_model_path,
             "current_score": self.current_score,
             "dirpath": self.dirpath,
@@ -354,6 +363,7 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:

         if self.dirpath == dirpath_from_ckpt:
             self.best_model_score = state_dict["best_model_score"]
+            # use .get so checkpoints saved before this change still load
+            self.best_model_metrics = state_dict.get("best_model_metrics", self.best_model_metrics)
             self.kth_best_model_path = state_dict.get("kth_best_model_path", self.kth_best_model_path)
             self.kth_value = state_dict.get("kth_value", self.kth_value)
             self.best_k_models = state_dict.get("best_k_models", self.best_k_models)
@@ -523,7 +533,9 @@ def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[Tensor] = None) -> bool:
             return True

         monitor_op = {"min": torch.lt, "max": torch.gt}[self.mode]
-        should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path])
+        should_update_best_and_save = monitor_op(
+            current, cast(Tensor, self.best_k_models[self.kth_best_model_path]["score"])
+        )

[Collaborator comment] this will stay as in the original if we avoid changing `best_k_models` (a sketch of that alternative follows this hunk).

         # If using multiple devices, make sure all processes are unanimous on the decision.
         should_update_best_and_save = trainer.strategy.reduce_boolean_decision(bool(should_update_best_and_save))
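A minimal sketch of the alternative the reviewer hints at: keep `best_k_models` mapping paths to bare score tensors, as on master, and track the metrics in a separate companion dict so the comparison above needs no `cast`. The `best_k_metrics` name and the values below are hypothetical:

import torch
from torch import Tensor

# Hypothetical companion-dict design: scores stay bare tensors (as on master),
# metrics live in a parallel dict keyed by the same checkpoint path.
best_k_models: dict[str, Tensor] = {}
best_k_metrics: dict[str, dict[str, Tensor]] = {}  # hypothetical

filepath = "epoch=3-step=40.ckpt"  # illustrative
current = torch.tensor(0.2310)
monitor_candidates = {"val_loss": current, "epoch": torch.tensor(3)}

best_k_models[filepath] = current
best_k_metrics[filepath] = dict(monitor_candidates)

# the top-k comparison then stays exactly as in the original:
should_update = torch.lt(current, best_k_models[filepath])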
@@ -735,17 +747,22 @@ def _update_best_and_save(

         # save the current score
         self.current_score = current
-        self.best_k_models[filepath] = current
+        self.best_k_models[filepath] = {
+            "score": current,
+            "metrics": monitor_candidates,
+        }

         if len(self.best_k_models) == k:
             # monitor dict has reached k elements
             _op = max if self.mode == "min" else min
-            self.kth_best_model_path = _op(self.best_k_models, key=self.best_k_models.get)  # type: ignore[arg-type]
-            self.kth_value = self.best_k_models[self.kth_best_model_path]
+            # rank by the stored score; keying on .get would try to order the value dicts
+            self.kth_best_model_path = _op(self.best_k_models, key=lambda p: self.best_k_models[p]["score"])
+            self.kth_value = cast(Tensor, self.best_k_models[self.kth_best_model_path]["score"])
+            self.kth_model_metrics = self.best_k_models[self.kth_best_model_path]["metrics"]

         _op = min if self.mode == "min" else max
-        self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get)  # type: ignore[arg-type]
-        self.best_model_score = self.best_k_models[self.best_model_path]
+        self.best_model_path = _op(self.best_k_models, key=lambda p: self.best_k_models[p]["score"])
+        self.best_model_score = cast(Tensor, self.best_k_models[self.best_model_path]["score"])
+        self.best_model_metrics = self.best_k_models[self.best_model_path]["metrics"]

         if self.verbose:
             epoch = monitor_candidates["epoch"]
@@ -762,7 +779,7 @@
     def to_yaml(self, filepath: Optional[_PATH] = None) -> None:
         """Saves the `best_k_models` dict containing the checkpoint paths with the corresponding scores to a YAML
         file."""
-        best_k = {k: v.item() for k, v in self.best_k_models.items()}
+        best_k = {k: v["score"].item() for k, v in self.best_k_models.items()}  # type: ignore[arg-type]
         if filepath is None:
             assert self.dirpath
             filepath = os.path.join(self.dirpath, "best_k_models.yaml")
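For reference, a short usage sketch: since `to_yaml` extracts only the `"score"` entry, it keeps writing the same flat path-to-score mapping as before (paths and numbers are illustrative):

# After training with the callback from the sketch above:
checkpoint_cb.to_yaml()  # writes {dirpath}/best_k_models.yaml
# The file maps checkpoint paths to monitored scores, e.g.:
#   epoch=1-step=10.ckpt: 0.3512
#   epoch=3-step=40.ckpt: 0.2310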
5 changes: 3 additions & 2 deletions tests/tests_pytorch/checkpointing/test_model_checkpoint.py
@@ -25,7 +25,6 @@
 from unittest.mock import Mock, call, patch

 import cloudpickle
-import lightning.pytorch as pl
 import pytest
 import torch
 import yaml
@@ -702,6 +701,7 @@ def test_model_checkpoint_save_last_none_monitor(tmp_path, caplog):
     assert checkpoint_callback.best_model_path == str(tmp_path / "epoch=1-step=20.ckpt")
     assert checkpoint_callback.last_model_path == str(tmp_path / "last.ckpt")
     assert checkpoint_callback.best_model_score is None
+    assert checkpoint_callback.best_model_metrics is None

[Collaborator comment] we need to add tests that exercise the new code (a possible shape is sketched after this hunk).

     assert checkpoint_callback.best_k_models == {}
     assert checkpoint_callback.kth_best_model_path == ""
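A sketch of the kind of test the reviewer asks for, in the style of the surrounding tests and assuming the test module's existing imports of `torch`, `Trainer`, `ModelCheckpoint`, and `BoringModel`. The `LoggingModel` subclass, the `val_loss` metric, and the trainer limits are illustrative assumptions, not part of the PR:

def test_model_checkpoint_best_model_metrics(tmp_path):
    """best_model_metrics should hold the metrics logged when the best checkpoint was saved."""

    class LoggingModel(BoringModel):
        def validation_step(self, batch, batch_idx):
            loss = self(batch).sum()
            self.log("val_loss", loss)
            return {"x": loss}

    checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, monitor="val_loss", save_top_k=1)
    trainer = Trainer(
        default_root_dir=tmp_path,
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        callbacks=[checkpoint_callback],
        enable_progress_bar=False,
    )
    trainer.fit(LoggingModel())

    assert checkpoint_callback.best_model_metrics is not None
    assert "val_loss" in checkpoint_callback.best_model_metrics
    assert torch.equal(
        checkpoint_callback.best_model_metrics["val_loss"], checkpoint_callback.best_model_score
    )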

@@ -808,6 +808,7 @@ def test_model_checkpoint_topk_zero(tmp_path):
     assert checkpoint_callback.monitor is None
     assert checkpoint_callback.best_model_path == ""
     assert checkpoint_callback.best_model_score is None
+    assert checkpoint_callback.best_model_metrics is None
     assert checkpoint_callback.best_k_models == {}
     assert checkpoint_callback.kth_best_model_path == ""
     # check that only the last ckpt was created
@@ -1073,7 +1074,7 @@ def assert_checkpoint_log_dir(idx):

     # load from checkpoint
     trainer_config["logger"] = TensorBoardLogger(tmp_path)
-    trainer = pl.Trainer(**trainer_config)
+    trainer = Trainer(**trainer_config)
     assert_trainer_init(trainer)

     model = ExtendedBoringModel()