Add best_k_metrics parameter to the ModelCheckpoint #20457
base: master
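
This PR extends `ModelCheckpoint` so that, alongside the scalar `best_model_score`, the callback retains the full metrics dict that was logged when each top-k checkpoint was saved, exposed as `best_model_metrics`. A minimal usage sketch (the model and the `val_loss` metric are hypothetical; only the `best_model_metrics` attribute comes from this PR):

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel


class DemoModel(BoringModel):  # hypothetical model that logs a monitored metric
    def validation_step(self, batch, batch_idx):
        out = super().validation_step(batch, batch_idx)
        self.log("val_loss", out["x"])  # BoringModel's validation_step returns {"x": loss}
        return out


checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=3)
trainer = Trainer(max_epochs=2, limit_train_batches=2, limit_val_batches=2, callbacks=[checkpoint_callback])
trainer.fit(DemoModel())

# Before this PR only the monitored scalar was kept:
print(checkpoint_callback.best_model_score)    # e.g. tensor(0.1250)
# With this PR the full metrics dict at save time is kept as well:
print(checkpoint_callback.best_model_metrics)  # e.g. {"val_loss": tensor(0.1250), ...}
```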
src/lightning/pytorch/callbacks/model_checkpoint.py
```diff
@@ -27,22 +27,29 @@
 from copy import deepcopy
 from datetime import timedelta
 from pathlib import Path
-from typing import Any, Literal, Optional, Union
+from typing import Any, Literal, Optional, Union, cast
 from weakref import proxy
 
+import pytorch_lightning as pl
 import torch
 import yaml
+from lightning_fabric.utilities.cloud_io import (
+    _is_dir,
+    _is_local_file_protocol,
+    get_filesystem,
+)
+from lightning_fabric.utilities.types import _PATH
+from pytorch_lightning.callbacks import Checkpoint
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.rank_zero import (
+    WarningCache,
+    rank_zero_info,
+    rank_zero_warn,
+)
+from pytorch_lightning.utilities.types import STEP_OUTPUT
 from torch import Tensor
 from typing_extensions import override
 
-import lightning.pytorch as pl
-from lightning.fabric.utilities.cloud_io import _is_dir, _is_local_file_protocol, get_filesystem
-from lightning.fabric.utilities.types import _PATH
-from lightning.pytorch.callbacks import Checkpoint
-from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_info, rank_zero_warn
-from lightning.pytorch.utilities.types import STEP_OUTPUT
-
 log = logging.getLogger(__name__)
 warning_cache = WarningCache()
```

Review comment (on `from pytorch_lightning.callbacks import Checkpoint`): change to …
```diff
@@ -241,9 +248,10 @@ def __init__(
         self._last_global_step_saved = 0  # no need to save when no steps were taken
         self._last_time_checked: Optional[float] = None
         self.current_score: Optional[Tensor] = None
-        self.best_k_models: dict[str, Tensor] = {}
+        self.best_k_models: dict[str, dict[str, Tensor | dict[str, Tensor]]] = {}
         self.kth_best_model_path = ""
         self.best_model_score: Optional[Tensor] = None
+        self.best_model_metrics: Optional[dict[str, Tensor]] = None
         self.best_model_path = ""
         self.last_model_path = ""
         self._last_checkpoint_saved = ""
```
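
Each `best_k_models` entry thus changes from a bare tensor to a small dict. A sketch of the new layout (the checkpoint path and metric names are hypothetical; the `"score"`/`"metrics"` keys come from the diff):

```python
import torch

best_k_models = {
    "epoch=3-step=400.ckpt": {
        "score": torch.tensor(0.25),  # monitored value, used for top-k ranking
        "metrics": {                  # full monitor_candidates captured at save time
            "val_loss": torch.tensor(0.25),
            "epoch": torch.tensor(3),
        },
    },
}
```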
```diff
@@ -339,6 +347,7 @@ def state_dict(self) -> dict[str, Any]:
         return {
             "monitor": self.monitor,
             "best_model_score": self.best_model_score,
+            "best_model_metrics": self.best_model_metrics,
             "best_model_path": self.best_model_path,
             "current_score": self.current_score,
             "dirpath": self.dirpath,
```
```diff
@@ -354,6 +363,7 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
 
         if self.dirpath == dirpath_from_ckpt:
             self.best_model_score = state_dict["best_model_score"]
+            self.best_model_metrics = state_dict["best_model_metrics"]
             self.kth_best_model_path = state_dict.get("kth_best_model_path", self.kth_best_model_path)
             self.kth_value = state_dict.get("kth_value", self.kth_value)
             self.best_k_models = state_dict.get("best_k_models", self.best_k_models)
```
```diff
@@ -523,7 +533,9 @@ def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[Tensor] = None) -> bool:
             return True
 
         monitor_op = {"min": torch.lt, "max": torch.gt}[self.mode]
-        should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path])
+        should_update_best_and_save = monitor_op(
+            current, cast(Tensor, self.best_k_models[self.kth_best_model_path]["score"])
+        )
 
         # If using multiple devices, make sure all processes are unanimous on the decision.
         should_update_best_and_save = trainer.strategy.reduce_boolean_decision(bool(should_update_best_and_save))
```

Review comment (on the wrapped `monitor_op(...)` call): this will stay as in the original if we avoid changing …
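
The comparison semantics are unchanged: `torch.lt` for `mode="min"` and `torch.gt` for `mode="max"`, now applied to the `"score"` entry of the kth-best record. A standalone sketch with hypothetical values:

```python
import torch

mode = "min"
monitor_op = {"min": torch.lt, "max": torch.gt}[mode]

current = torch.tensor(0.20)         # metric value for the candidate checkpoint
kth_best_score = torch.tensor(0.25)  # best_k_models[kth_best_model_path]["score"]

# In "min" mode, a lower current value displaces the current kth-best entry.
print(bool(monitor_op(current, kth_best_score)))  # True
```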
```diff
@@ -735,17 +747,22 @@ def _update_best_and_save(
 
         # save the current score
         self.current_score = current
-        self.best_k_models[filepath] = current
+        self.best_k_models[filepath] = {
+            "score": current,
+            "metrics": monitor_candidates,
+        }
 
         if len(self.best_k_models) == k:
             # monitor dict has reached k elements
             _op = max if self.mode == "min" else min
             self.kth_best_model_path = _op(self.best_k_models, key=self.best_k_models.get)  # type: ignore[arg-type]
             self.kth_value = self.best_k_models[self.kth_best_model_path]
+            self.kth_model_metrics = self.best_k_models[self.kth_best_model_path]["metrics"]
 
         _op = min if self.mode == "min" else max
         self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get)  # type: ignore[arg-type]
-        self.best_model_score = self.best_k_models[self.best_model_path]
+        self.best_model_score = self.best_k_models[self.best_model_path]["score"]
+        self.best_model_metrics = self.best_k_models[self.best_model_path]["metrics"]
 
         if self.verbose:
             epoch = monitor_candidates["epoch"]
```
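
Note that `min`/`max` with `key=self.best_k_models.get` now receives dict values, which Python cannot order; ranking by the stored `"score"` preserves the original behaviour. A sketch over the new structure (paths and values hypothetical):

```python
import torch

best_k_models = {
    "epoch=1.ckpt": {"score": torch.tensor(0.50), "metrics": {"val_loss": torch.tensor(0.50)}},
    "epoch=2.ckpt": {"score": torch.tensor(0.25), "metrics": {"val_loss": torch.tensor(0.25)}},
}
mode = "min"

_op = max if mode == "min" else min  # worst entry still inside the top k
kth_best_model_path = _op(best_k_models, key=lambda p: best_k_models[p]["score"])

_op = min if mode == "min" else max  # overall best entry
best_model_path = _op(best_k_models, key=lambda p: best_k_models[p]["score"])

print(kth_best_model_path, best_model_path)  # epoch=1.ckpt epoch=2.ckpt
```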
```diff
@@ -762,7 +779,7 @@ def _update_best_and_save(
     def to_yaml(self, filepath: Optional[_PATH] = None) -> None:
         """Saves the `best_k_models` dict containing the checkpoint paths with the corresponding scores to a YAML
         file."""
-        best_k = {k: v.item() for k, v in self.best_k_models.items()}
+        best_k = {k: v["score"].item() for k, v in self.best_k_models.items()}  # type: ignore[arg-type]
         if filepath is None:
             assert self.dirpath
             filepath = os.path.join(self.dirpath, "best_k_models.yaml")
```
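
`to_yaml` unwraps the stored records back to plain scalars, so the emitted file keeps its original path-to-score shape. A round-trip sketch with hypothetical entries:

```python
import torch
import yaml

best_k_models = {
    "epoch=1.ckpt": {"score": torch.tensor(0.50), "metrics": {}},
    "epoch=2.ckpt": {"score": torch.tensor(0.25), "metrics": {}},
}
best_k = {k: v["score"].item() for k, v in best_k_models.items()}
print(yaml.dump(best_k), end="")
# epoch=1.ckpt: 0.5
# epoch=2.ckpt: 0.25
```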
tests/tests_pytorch/checkpointing/test_model_checkpoint.py
```diff
@@ -25,7 +25,6 @@
 from unittest.mock import Mock, call, patch
 
 import cloudpickle
-import lightning.pytorch as pl
 import pytest
 import torch
 import yaml
```
```diff
@@ -702,6 +701,7 @@ def test_model_checkpoint_save_last_none_monitor(tmp_path, caplog):
     assert checkpoint_callback.best_model_path == str(tmp_path / "epoch=1-step=20.ckpt")
     assert checkpoint_callback.last_model_path == str(tmp_path / "last.ckpt")
     assert checkpoint_callback.best_model_score is None
+    assert checkpoint_callback.best_model_metrics is None
     assert checkpoint_callback.best_k_models == {}
     assert checkpoint_callback.kth_best_model_path == ""
```

Review comment: we need to add tests that exercise the new code
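
A sketch of a test in the direction the reviewer asks for, exercising the new attribute (the model, metric name, and test name are hypothetical, following the style of the surrounding tests):

```python
import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel


def test_model_checkpoint_best_model_metrics(tmp_path):
    """The callback should retain the full metrics dict of the best checkpoint."""

    class LoggingModel(BoringModel):
        def validation_step(self, batch, batch_idx):
            out = super().validation_step(batch, batch_idx)
            self.log("val_loss", out["x"])  # BoringModel's validation_step returns {"x": loss}
            return out

    checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, monitor="val_loss", save_top_k=2)
    trainer = Trainer(
        default_root_dir=tmp_path,
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        callbacks=[checkpoint_callback],
        logger=False,
        enable_progress_bar=False,
    )
    trainer.fit(LoggingModel())

    metrics = checkpoint_callback.best_model_metrics
    assert metrics is not None
    assert "val_loss" in metrics
    assert torch.equal(metrics["val_loss"], checkpoint_callback.best_model_score)
```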
```diff
@@ -808,6 +808,7 @@ def test_model_checkpoint_topk_zero(tmp_path):
     assert checkpoint_callback.monitor is None
     assert checkpoint_callback.best_model_path == ""
     assert checkpoint_callback.best_model_score is None
+    assert checkpoint_callback.best_model_metrics is None
     assert checkpoint_callback.best_k_models == {}
     assert checkpoint_callback.kth_best_model_path == ""
     # check that only the last ckpt was created
```
```diff
@@ -1073,7 +1074,7 @@ def assert_checkpoint_log_dir(idx):
 
     # load from checkpoint
     trainer_config["logger"] = TensorBoardLogger(tmp_path)
-    trainer = pl.Trainer(**trainer_config)
+    trainer = Trainer(**trainer_config)
     assert_trainer_init(trainer)
 
     model = ExtendedBoringModel()
```
Review comment: we probably have different formatting options, please run … from the base directory to make sure the PR conforms