From 19a5d2be26fa02039b2b69029ddf40a6f578de5f Mon Sep 17 00:00:00 2001
From: Martin Kozlovský
Date: Tue, 27 Aug 2024 10:51:27 -0400
Subject: [PATCH] Fail-Proof Checkpoint Usage in Callbacks (#65)

---
 luxonis_train/__main__.py                   |  3 +
 .../callbacks/archive_on_train_end.py       | 16 ++----
 .../callbacks/export_on_train_end.py        | 19 +++----
 luxonis_train/callbacks/needs_checkpoint.py | 56 +++++++++++++++++++
 luxonis_train/core/core.py                  |  4 ++
 luxonis_train/utils/config.py               | 33 +++++++++--
 6 files changed, 103 insertions(+), 28 deletions(-)
 create mode 100644 luxonis_train/callbacks/needs_checkpoint.py

diff --git a/luxonis_train/__main__.py b/luxonis_train/__main__.py
index cdc66954..454e9525 100644
--- a/luxonis_train/__main__.py
+++ b/luxonis_train/__main__.py
@@ -6,6 +6,9 @@
 import typer
 import yaml
 
+from luxonis_ml.utils import setup_logging
+
+setup_logging(use_rich=True)
 
 
 class _ViewType(str, Enum):
diff --git a/luxonis_train/callbacks/archive_on_train_end.py b/luxonis_train/callbacks/archive_on_train_end.py
index 7d6da67f..d9e7b298 100644
--- a/luxonis_train/callbacks/archive_on_train_end.py
+++ b/luxonis_train/callbacks/archive_on_train_end.py
@@ -5,11 +5,13 @@
 import luxonis_train
 from luxonis_train.utils.registry import CALLBACKS
 
+from .needs_checkpoint import NeedsCheckpoint
+
 logger = logging.getLogger(__name__)
 
 
 @CALLBACKS.register_module()
-class ArchiveOnTrainEnd(pl.Callback):
+class ArchiveOnTrainEnd(NeedsCheckpoint):
     def on_train_end(
         self,
         _: pl.Trainer,
@@ -21,17 +23,11 @@ def on_train_end(
         @param trainer: Pytorch Lightning trainer.
         @type pl_module: L{pl.LightningModule}
         @param pl_module: Pytorch Lightning module.
-        @raises RuntimeError: If no best model path is found.
         """
-        best_model_path = pl_module.core.get_min_loss_checkpoint_path()
-        if not best_model_path:
-            logger.error(
-                "No best model path found. "
-                "Please make sure that ModelCheckpoint callback is present "
-                "and at least one validation epoch has been performed. "
-                "Skipping model archiving."
-            )
+        path = self.get_checkpoint(pl_module)
+        if path is None:
+            logger.warning("Skipping model archiving.")
             return
 
         onnx_path = pl_module.core._exported_models.get("onnx")
diff --git a/luxonis_train/callbacks/export_on_train_end.py b/luxonis_train/callbacks/export_on_train_end.py
index 7e8f8a71..261c4ef6 100644
--- a/luxonis_train/callbacks/export_on_train_end.py
+++ b/luxonis_train/callbacks/export_on_train_end.py
@@ -5,11 +5,13 @@
 import luxonis_train
 from luxonis_train.utils.registry import CALLBACKS
 
+from .needs_checkpoint import NeedsCheckpoint
+
 logger = logging.getLogger(__name__)
 
 
 @CALLBACKS.register_module()
-class ExportOnTrainEnd(pl.Callback):
+class ExportOnTrainEnd(NeedsCheckpoint):
     def on_train_end(
         self,
         _: pl.Trainer,
@@ -21,17 +23,10 @@ def on_train_end(
         @param trainer: Pytorch Lightning trainer.
         @type pl_module: L{pl.LightningModule}
         @param pl_module: Pytorch Lightning module.
-        @raises RuntimeError: If no best model path is found.
         """
-
-        best_model_path = pl_module.core.get_best_metric_checkpoint_path()
-        if not best_model_path:
-            logger.error(
-                "No model checkpoint found. "
-                "Make sure that `ModelCheckpoint` callback is present "
-                "and at least one validation epoch has been performed. "
-                "Skipping model export."
-            )
+        path = self.get_checkpoint(pl_module)
+        if path is None:
+            logger.warning("Skipping model export.")
             return
 
-        pl_module.core.export(weights=best_model_path)
+        pl_module.core.export(weights=path)
diff --git a/luxonis_train/callbacks/needs_checkpoint.py b/luxonis_train/callbacks/needs_checkpoint.py
new file mode 100644
index 00000000..30355e82
--- /dev/null
+++ b/luxonis_train/callbacks/needs_checkpoint.py
@@ -0,0 +1,56 @@
+import logging
+from typing import Literal
+
+import lightning.pytorch as pl
+
+import luxonis_train
+
+logger = logging.getLogger(__name__)
+
+
+class NeedsCheckpoint(pl.Callback):
+    def __init__(
+        self, preferred_checkpoint: Literal["metric", "loss"] = "metric", **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.preferred_checkpoint = preferred_checkpoint
+
+    @staticmethod
+    def _get_checkpoint(
+        checkpoint_type: str,
+        pl_module: "luxonis_train.models.LuxonisLightningModule",
+    ) -> str | None:
+        if checkpoint_type == "loss":
+            path = pl_module.core.get_min_loss_checkpoint_path()
+            if not path:
+                logger.error(
+                    "No checkpoint for minimum loss found. "
+                    "Make sure that `ModelCheckpoint` callback is present "
+                    "and at least one validation epoch has been performed."
+                )
+            return path
+        else:
+            path = pl_module.core.get_best_metric_checkpoint_path()
+            if not path:
+                logger.error(
+                    "No checkpoint for best metric found. "
+                    "Make sure that `ModelCheckpoint` callback is present, "
+                    "at least one validation epoch has been performed and "
+                    "the model has at least one metric."
+                )
+            return path
+
+    def _get_other_type(self, checkpoint_type: str) -> str:
+        if checkpoint_type == "loss":
+            return "metric"
+        return "loss"
+
+    def get_checkpoint(
+        self, pl_module: "luxonis_train.models.LuxonisLightningModule"
+    ) -> str | None:
+        path = self._get_checkpoint(self.preferred_checkpoint, pl_module)
+        if path is not None:
+            return path
+        other_checkpoint = self._get_other_type(self.preferred_checkpoint)
+        logger.info(f"Attempting to use {other_checkpoint} checkpoint.")
+        return self._get_checkpoint(other_checkpoint, pl_module)
diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py
index e7bf35a2..2b1607ad 100644
--- a/luxonis_train/core/core.py
+++ b/luxonis_train/core/core.py
@@ -685,6 +685,8 @@ def get_min_loss_checkpoint_path(self) -> str | None:
         @rtype: str
         @return: Path to best checkpoint with respect to minimal validation loss
         """
+        if not self.pl_trainer.checkpoint_callbacks:
+            return None
         return self.pl_trainer.checkpoint_callbacks[0].best_model_path  # type: ignore
 
     @rank_zero_only
@@ -694,4 +696,6 @@ def get_best_metric_checkpoint_path(self) -> str | None:
         @rtype: str
         @return: Path to best checkpoint with respect to best validation metric
         """
+        if len(self.pl_trainer.checkpoint_callbacks) < 2:
+            return None
         return self.pl_trainer.checkpoint_callbacks[1].best_model_path  # type: ignore
diff --git a/luxonis_train/utils/config.py b/luxonis_train/utils/config.py
index 3739c61e..44c00637 100644
--- a/luxonis_train/utils/config.py
+++ b/luxonis_train/utils/config.py
@@ -77,6 +77,27 @@ class ModelConfig(BaseModelExtraForbid):
     visualizers: list[AttachedModuleConfig] = []
     outputs: list[str] = []
 
+    @model_validator(mode="after")
+    def check_main_metric(self) -> Self:
+        for metric in self.metrics:
+            if metric.is_main_metric:
+                logger.info(f"Main metric: `{metric.name}`")
+                return self
+
+        logger.warning("No main metric specified.")
+        if self.metrics:
+            metric = self.metrics[0]
+            metric.is_main_metric = True
+            name = metric.alias or metric.name
+            logger.info(f"Setting '{name}' as main metric.")
+        else:
+            logger.error(
+                "No metrics specified. "
+                "This is likely unintended unless "
+                "the configuration is not used for training."
+            )
+        return self
+
     @model_validator(mode="after")
     def check_predefined_model(self) -> Self:
         from luxonis_train.utils.registry import MODELS
@@ -351,12 +372,12 @@ class TunerConfig(BaseModelExtraForbid):
 
 
 class Config(LuxonisConfig):
-    model: ModelConfig = ModelConfig()
-    loader: LoaderConfig = LoaderConfig()
-    tracker: TrackerConfig = TrackerConfig()
-    trainer: TrainerConfig = TrainerConfig()
-    exporter: ExportConfig = ExportConfig()
-    archiver: ArchiveConfig = ArchiveConfig()
+    model: Annotated[ModelConfig, Field(default_factory=ModelConfig)]
+    loader: Annotated[LoaderConfig, Field(default_factory=LoaderConfig)]
+    tracker: Annotated[TrackerConfig, Field(default_factory=TrackerConfig)]
+    trainer: Annotated[TrainerConfig, Field(default_factory=TrainerConfig)]
+    exporter: Annotated[ExportConfig, Field(default_factory=ExportConfig)]
+    archiver: Annotated[ArchiveConfig, Field(default_factory=ArchiveConfig)]
     tuner: TunerConfig | None = None
 
     ENVIRON: Environ = Field(Environ(), exclude=True)
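
---

A minimal sketch of how a downstream callback can build on the new
`NeedsCheckpoint` base class. `UploadOnTrainEnd` and its logged "upload" are
hypothetical; `get_checkpoint()` and the `preferred_checkpoint` argument are
the API introduced in `needs_checkpoint.py` above:

    import logging

    import lightning.pytorch as pl

    import luxonis_train
    from luxonis_train.callbacks.needs_checkpoint import NeedsCheckpoint

    logger = logging.getLogger(__name__)


    class UploadOnTrainEnd(NeedsCheckpoint):
        def on_train_end(
            self,
            _: pl.Trainer,
            pl_module: "luxonis_train.models.LuxonisLightningModule",
        ) -> None:
            # get_checkpoint() tries the preferred checkpoint type first
            # ("metric" unless overridden) and falls back to the other one,
            # so a single missing checkpoint no longer aborts the callback.
            path = self.get_checkpoint(pl_module)
            if path is None:
                logger.warning("Skipping checkpoint upload.")
                return
            logger.info(f"Uploading checkpoint '{path}'.")  # stand-in for real work


    # Preferring the minimum-loss checkpoint is a constructor argument:
    callback = UploadOnTrainEnd(preferred_checkpoint="loss")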
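The guards added to `core.py` exist because `Trainer.checkpoint_callbacks`
can hold fewer `ModelCheckpoint` entries than the getters index into. A small
illustration of the failure mode they prevent (assuming a plain
`lightning.pytorch` Trainer with checkpointing disabled):

    import lightning.pytorch as pl

    # With checkpointing disabled, no ModelCheckpoint callback is attached:
    trainer = pl.Trainer(enable_checkpointing=False)
    assert trainer.checkpoint_callbacks == []

    # Before this patch, get_min_loss_checkpoint_path() indexed
    # checkpoint_callbacks[0] unconditionally, raising IndexError here;
    # with the guards, both getters return None and the callbacks above
    # log a warning and skip their work instead.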