diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py
index 85bfb65c0ea6e..2d5d7df3e7c74 100644
--- a/src/lightning/pytorch/callbacks/model_checkpoint.py
+++ b/src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -27,7 +27,7 @@
 from copy import deepcopy
 from datetime import timedelta
 from pathlib import Path
-from typing import Any, Literal, Optional, Union
+from typing import Any, Literal, Optional, Union, cast
 from weakref import proxy
 
 import torch
@@ -241,9 +241,10 @@ def __init__(
         self._last_global_step_saved = 0  # no need to save when no steps were taken
         self._last_time_checked: Optional[float] = None
         self.current_score: Optional[Tensor] = None
-        self.best_k_models: dict[str, Tensor] = {}
+        self.best_k_models: dict[str, dict[str, Union[Tensor, dict[str, Tensor]]]] = {}
         self.kth_best_model_path = ""
         self.best_model_score: Optional[Tensor] = None
+        self.best_model_metrics: Optional[dict[str, Tensor]] = None
         self.best_model_path = ""
         self.last_model_path = ""
         self._last_checkpoint_saved = ""
@@ -339,6 +340,7 @@ def state_dict(self) -> dict[str, Any]:
         return {
             "monitor": self.monitor,
             "best_model_score": self.best_model_score,
+            "best_model_metrics": self.best_model_metrics,
             "best_model_path": self.best_model_path,
             "current_score": self.current_score,
             "dirpath": self.dirpath,
@@ -354,6 +356,7 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
 
         if self.dirpath == dirpath_from_ckpt:
             self.best_model_score = state_dict["best_model_score"]
+            self.best_model_metrics = state_dict.get("best_model_metrics", self.best_model_metrics)
             self.kth_best_model_path = state_dict.get("kth_best_model_path", self.kth_best_model_path)
             self.kth_value = state_dict.get("kth_value", self.kth_value)
             self.best_k_models = state_dict.get("best_k_models", self.best_k_models)
@@ -523,7 +526,9 @@ def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[Tensor] = None) -> bool:
             return True
 
         monitor_op = {"min": torch.lt, "max": torch.gt}[self.mode]
-        should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path])
+        should_update_best_and_save = monitor_op(
+            current, cast(Tensor, self.best_k_models[self.kth_best_model_path]["score"])
+        )
 
         # If using multiple devices, make sure all processes are unanimous on the decision.
         should_update_best_and_save = trainer.strategy.reduce_boolean_decision(bool(should_update_best_and_save))
@@ -735,17 +740,26 @@ def _update_best_and_save(
 
         # save the current score
         self.current_score = current
-        self.best_k_models[filepath] = current
+        self.best_k_models[filepath] = {
+            "score": current,
+            "metrics": monitor_candidates,
+        }
 
         if len(self.best_k_models) == k:
             # monitor dict has reached k elements
             _op = max if self.mode == "min" else min
-            self.kth_best_model_path = _op(self.best_k_models, key=self.best_k_models.get)  # type: ignore[arg-type]
-            self.kth_value = self.best_k_models[self.kth_best_model_path]
+            self.kth_best_model_path = _op(
+                self.best_k_models, key=lambda path: cast(Tensor, self.best_k_models[path]["score"])
+            )
+            self.kth_value = cast(Tensor, self.best_k_models[self.kth_best_model_path]["score"])
+            self.kth_model_metrics = self.best_k_models[self.kth_best_model_path]["metrics"]
 
         _op = min if self.mode == "min" else max
-        self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get)  # type: ignore[arg-type]
-        self.best_model_score = self.best_k_models[self.best_model_path]
+        self.best_model_path = _op(
+            self.best_k_models, key=lambda path: cast(Tensor, self.best_k_models[path]["score"])
+        )
+        self.best_model_score = cast(Tensor, self.best_k_models[self.best_model_path]["score"])
+        self.best_model_metrics = self.best_k_models[self.best_model_path]["metrics"]  # type: ignore[assignment]
 
         if self.verbose:
             epoch = monitor_candidates["epoch"]
@@ -762,7 +776,7 @@
     def to_yaml(self, filepath: Optional[_PATH] = None) -> None:
         """Saves the `best_k_models` dict containing the checkpoint paths with the corresponding scores to a YAML
         file."""
-        best_k = {k: v.item() for k, v in self.best_k_models.items()}
+        best_k = {k: cast(Tensor, v["score"]).item() for k, v in self.best_k_models.items()}
         if filepath is None:
             assert self.dirpath
             filepath = os.path.join(self.dirpath, "best_k_models.yaml")
diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
index ef0abc7c463a8..031fd32b8bbb2 100644
--- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
+++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
@@ -25,7 +25,6 @@
 from unittest.mock import Mock, call, patch
 
 import cloudpickle
-import lightning.pytorch as pl
 import pytest
 import torch
 import yaml
@@ -702,6 +701,7 @@ def test_model_checkpoint_save_last_none_monitor(tmp_path, caplog):
     assert checkpoint_callback.best_model_path == str(tmp_path / "epoch=1-step=20.ckpt")
     assert checkpoint_callback.last_model_path == str(tmp_path / "last.ckpt")
     assert checkpoint_callback.best_model_score is None
+    assert checkpoint_callback.best_model_metrics is None
     assert checkpoint_callback.best_k_models == {}
     assert checkpoint_callback.kth_best_model_path == ""
@@ -808,6 +808,7 @@ def test_model_checkpoint_topk_zero(tmp_path):
     assert checkpoint_callback.monitor is None
     assert checkpoint_callback.best_model_path == ""
     assert checkpoint_callback.best_model_score is None
+    assert checkpoint_callback.best_model_metrics is None
     assert checkpoint_callback.best_k_models == {}
     assert checkpoint_callback.kth_best_model_path == ""
     # check that only the last ckpt was created
@@ -1073,7 +1074,7 @@ def assert_checkpoint_log_dir(idx):
 
     # load from checkpoint
     trainer_config["logger"] = TensorBoardLogger(tmp_path)
-    trainer = pl.Trainer(**trainer_config)
+    trainer = Trainer(**trainer_config)
     assert_trainer_init(trainer)
     model = ExtendedBoringModel()
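For reviewers, a minimal usage sketch of what this change surfaces once the patch is applied. `TinyModel` and the synthetic data below are illustrative assumptions, not part of the diff; only `best_model_metrics` (alongside the unchanged `best_model_score`) comes from the change itself.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint


class TinyModel(LightningModule):
    """Illustrative stand-in that logs the monitored key "val_loss"."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch[0]).pow(2).mean()

    def validation_step(self, batch, batch_idx):
        # logged metrics are what end up in the stored "metrics" snapshot
        self.log("val_loss", self.layer(batch[0]).pow(2).mean())

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=16)
ckpt = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=2)
trainer = Trainer(max_epochs=2, callbacks=[ckpt], logger=False)
trainer.fit(TinyModel(), train_dataloaders=data, val_dataloaders=data)

print(ckpt.best_model_score)    # unchanged: best monitored value as a tensor
print(ckpt.best_model_metrics)  # new: all metrics captured when the best checkpoint was saved
```

The design choice here is that each `best_k_models` entry keeps the full `monitor_candidates` dict next to the score, so the metrics snapshot survives `state_dict()`/`load_state_dict()` round-trips together with the rest of the callback state.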