Fail-Proof Checkpoint Usage in Callbacks (#65)
kozlov721 authored Aug 27, 2024
1 parent 54fc144 commit 19a5d2b
Showing 6 changed files with 103 additions and 28 deletions.
3 changes: 3 additions & 0 deletions luxonis_train/__main__.py
@@ -6,6 +6,9 @@
 
 import typer
 import yaml
+from luxonis_ml.utils import setup_logging
+
+setup_logging(use_rich=True)
 
 
 class _ViewType(str, Enum):
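The CLI now configures logging once, at import time, before any subcommand runs. A minimal sketch of the same setup in another entry point (only `setup_logging` and its `use_rich` flag are taken from the diff above; `use_rich=True` presumably enables rich-formatted console output, and no other parameters of the function are assumed):

```python
import logging

from luxonis_ml.utils import setup_logging

# Mirror of the CLI's import-time setup.
setup_logging(use_rich=True)

logging.getLogger(__name__).info("Logging configured.")
```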
16 changes: 6 additions & 10 deletions luxonis_train/callbacks/archive_on_train_end.py
@@ -5,11 +5,13 @@
 import luxonis_train
 from luxonis_train.utils.registry import CALLBACKS
 
+from .needs_checkpoint import NeedsCheckpoint
+
 logger = logging.getLogger(__name__)
 
 
 @CALLBACKS.register_module()
-class ArchiveOnTrainEnd(pl.Callback):
+class ArchiveOnTrainEnd(NeedsCheckpoint):
     def on_train_end(
         self,
         _: pl.Trainer,
@@ -21,17 +23,11 @@ def on_train_end(
         @param trainer: Pytorch Lightning trainer.
         @type pl_module: L{pl.LightningModule}
         @param pl_module: Pytorch Lightning module.
-        @raises RuntimeError: If no best model path is found.
         """
 
-        best_model_path = pl_module.core.get_min_loss_checkpoint_path()
-        if not best_model_path:
-            logger.error(
-                "No best model path found. "
-                "Please make sure that ModelCheckpoint callback is present "
-                "and at least one validation epoch has been performed. "
-                "Skipping model archiving."
-            )
+        path = self.get_checkpoint(pl_module)
+        if path is None:
+            logger.warning("Skipping model archiving.")
             return
 
         onnx_path = pl_module.core._exported_models.get("onnx")
19 changes: 7 additions & 12 deletions luxonis_train/callbacks/export_on_train_end.py
@@ -5,11 +5,13 @@
 import luxonis_train
 from luxonis_train.utils.registry import CALLBACKS
 
+from .needs_checkpoint import NeedsCheckpoint
+
 logger = logging.getLogger(__name__)
 
 
 @CALLBACKS.register_module()
-class ExportOnTrainEnd(pl.Callback):
+class ExportOnTrainEnd(NeedsCheckpoint):
     def on_train_end(
         self,
         _: pl.Trainer,
@@ -21,17 +23,10 @@ def on_train_end(
         @param trainer: Pytorch Lightning trainer.
         @type pl_module: L{pl.LightningModule}
         @param pl_module: Pytorch Lightning module.
-        @raises RuntimeError: If no best model path is found.
         """
 
-        best_model_path = pl_module.core.get_best_metric_checkpoint_path()
-        if not best_model_path:
-            logger.error(
-                "No model checkpoint found. "
-                "Make sure that `ModelCheckpoint` callback is present "
-                "and at least one validation epoch has been performed. "
-                "Skipping model export."
-            )
+        path = self.get_checkpoint(pl_module)
+        if path is None:
+            logger.warning("Skipping model export.")
             return
 
-        pl_module.core.export(weights=best_model_path)
+        pl_module.core.export(weights=path)
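With the shared base class in place (defined in the next file), a missing checkpoint now downgrades to a logged warning instead of a hard failure. A minimal sketch of the new failure mode, using stubs (the MagicMock-based module is illustrative, not part of luxonis_train's test suite):

```python
import logging
from unittest.mock import MagicMock

from luxonis_train.callbacks.export_on_train_end import ExportOnTrainEnd

logging.basicConfig(level=logging.INFO)

# Stub a Lightning module whose core has no saved checkpoints at all.
pl_module = MagicMock()
pl_module.core.get_best_metric_checkpoint_path.return_value = None
pl_module.core.get_min_loss_checkpoint_path.return_value = None

callback = ExportOnTrainEnd()
# Both lookups fail, so the callback logs the errors, falls back once,
# then logs "Skipping model export." and returns -- no exception raised.
callback.on_train_end(MagicMock(), pl_module)

assert not pl_module.core.export.called
```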
56 changes: 56 additions & 0 deletions luxonis_train/callbacks/needs_checkpoint.py
@@ -0,0 +1,56 @@
+import logging
+from typing import Literal
+
+import lightning.pytorch as pl
+
+import luxonis_train
+
+logger = logging.getLogger(__name__)
+
+
+class NeedsCheckpoint(pl.Callback):
+    def __init__(
+        self, preferred_checkpoint: Literal["metric", "loss"] = "metric", **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.preferred_checkpoint = preferred_checkpoint
+
+    @staticmethod
+    def _get_checkpoint(
+        checkpoint_type: str,
+        pl_module: "luxonis_train.models.LuxonisLightningModule",
+    ) -> str | None:
+        if checkpoint_type == "loss":
+            path = pl_module.core.get_min_loss_checkpoint_path()
+            if not path:
+                logger.error(
+                    "No checkpoint for minimum loss found. "
+                    "Make sure that `ModelCheckpoint` callback is present "
+                    "and at least one validation epoch has been performed."
+                )
+            return path or None
+        else:
+            path = pl_module.core.get_best_metric_checkpoint_path()
+            if not path:
+                logger.error(
+                    "No checkpoint for best metric found. "
+                    "Make sure that `ModelCheckpoint` callback is present, "
+                    "at least one validation epoch has been performed and "
+                    "the model has at least one metric."
+                )
+            return path or None
+
+    def _get_other_type(self, checkpoint_type: str) -> str:
+        if checkpoint_type == "loss":
+            return "metric"
+        return "loss"
+
+    def get_checkpoint(
+        self, pl_module: "luxonis_train.models.LuxonisLightningModule"
+    ) -> str | None:
+        path = self._get_checkpoint(self.preferred_checkpoint, pl_module)
+        if path is not None:
+            return path
+        other_checkpoint = self._get_other_type(self.preferred_checkpoint)
+        logger.info(f"Attempting to use {other_checkpoint} checkpoint.")
+        return self._get_checkpoint(other_checkpoint, pl_module)
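Any callback that needs a trained checkpoint can now inherit NeedsCheckpoint and get the preference-with-fallback lookup for free. A hypothetical subclass for illustration (UploadOnTrainEnd and its log messages are invented; NeedsCheckpoint, get_checkpoint, preferred_checkpoint, and the CALLBACKS registry are from this commit):

```python
import logging

import lightning.pytorch as pl

import luxonis_train
from luxonis_train.callbacks.needs_checkpoint import NeedsCheckpoint
from luxonis_train.utils.registry import CALLBACKS

logger = logging.getLogger(__name__)


@CALLBACKS.register_module()
class UploadOnTrainEnd(NeedsCheckpoint):
    """Hypothetical callback: uploads the best checkpoint after training."""

    def on_train_end(
        self,
        _: pl.Trainer,
        pl_module: "luxonis_train.models.LuxonisLightningModule",
    ) -> None:
        # Tries the preferred checkpoint type first ("metric" unless the
        # callback was built with preferred_checkpoint="loss"), then the
        # other type, and only then gives up.
        path = self.get_checkpoint(pl_module)
        if path is None:
            logger.warning("Skipping checkpoint upload.")
            return
        logger.info(f"Uploading checkpoint '{path}'...")  # placeholder action
```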
4 changes: 4 additions & 0 deletions luxonis_train/core/core.py
@@ -685,6 +685,8 @@ def get_min_loss_checkpoint_path(self) -> str | None:
         @rtype: str
         @return: Path to best checkpoint with respect to minimal validation loss
         """
+        if not self.pl_trainer.checkpoint_callbacks:
+            return None
         return self.pl_trainer.checkpoint_callbacks[0].best_model_path  # type: ignore
 
     @rank_zero_only
@@ -694,4 +696,6 @@ def get_best_metric_checkpoint_path(self) -> str | None:
         @rtype: str
         @return: Path to best checkpoint with respect to best validation metric
         """
+        if len(self.pl_trainer.checkpoint_callbacks) < 2:
+            return None
         return self.pl_trainer.checkpoint_callbacks[1].best_model_path  # type: ignore
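The two new guards encode an ordering assumption: checkpoint_callbacks[0] is the ModelCheckpoint tracking minimum validation loss, and checkpoint_callbacks[1] the one tracking the best validation metric. A sketch of that convention in plain Lightning (the monitor keys and filenames are illustrative, not luxonis_train's actual ones):

```python
import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint

# Index 0: minimum validation loss; index 1: best validation metric.
min_loss_ckpt = ModelCheckpoint(monitor="val/loss", mode="min", filename="min_loss")
best_metric_ckpt = ModelCheckpoint(monitor="val/metric", mode="max", filename="best_metric")

trainer = pl.Trainer(callbacks=[min_loss_ckpt, best_metric_ckpt])

# Lightning keeps the relative order of checkpoint callbacks, and each
# `best_model_path` stays an empty string until a checkpoint is saved --
# hence the guards above and the `if not path` checks in NeedsCheckpoint.
assert trainer.checkpoint_callbacks == [min_loss_ckpt, best_metric_ckpt]
assert min_loss_ckpt.best_model_path == ""
```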
33 changes: 27 additions & 6 deletions luxonis_train/utils/config.py
@@ -77,6 +77,27 @@ class ModelConfig(BaseModelExtraForbid):
     visualizers: list[AttachedModuleConfig] = []
     outputs: list[str] = []
 
+    @model_validator(mode="after")
+    def check_main_metric(self) -> Self:
+        for metric in self.metrics:
+            if metric.is_main_metric:
+                logger.info(f"Main metric: `{metric.name}`")
+                return self
+
+        logger.warning("No main metric specified.")
+        if self.metrics:
+            metric = self.metrics[0]
+            metric.is_main_metric = True
+            name = metric.alias or metric.name
+            logger.info(f"Setting '{name}' as main metric.")
+        else:
+            logger.error(
+                "No metrics specified. "
+                "This is likely unintended unless "
+                "the configuration is not used for training."
+            )
+        return self
+
     @model_validator(mode="after")
     def check_predefined_model(self) -> Self:
         from luxonis_train.utils.registry import MODELS
@@ -351,12 +372,12 @@ class TunerConfig(BaseModelExtraForbid):
 
 
 class Config(LuxonisConfig):
-    model: ModelConfig = ModelConfig()
-    loader: LoaderConfig = LoaderConfig()
-    tracker: TrackerConfig = TrackerConfig()
-    trainer: TrainerConfig = TrainerConfig()
-    exporter: ExportConfig = ExportConfig()
-    archiver: ArchiveConfig = ArchiveConfig()
+    model: Annotated[ModelConfig, Field(default_factory=ModelConfig)]
+    loader: Annotated[LoaderConfig, Field(default_factory=LoaderConfig)]
+    tracker: Annotated[TrackerConfig, Field(default_factory=TrackerConfig)]
+    trainer: Annotated[TrainerConfig, Field(default_factory=TrainerConfig)]
+    exporter: Annotated[ExportConfig, Field(default_factory=ExportConfig)]
+    archiver: Annotated[ArchiveConfig, Field(default_factory=ArchiveConfig)]
     tuner: TunerConfig | None = None
     ENVIRON: Environ = Field(Environ(), exclude=True)
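The Config fields switch from defaults constructed eagerly at class-definition time (ModelConfig() runs, and is validated, when the module is imported) to default_factory, which builds a fresh sub-config per Config instance. One likely motivation: the new check_main_metric validator logs on every construction, so an eager default would already emit warnings at import. A standalone Pydantic sketch of the difference (toy models, not the luxonis_train classes):

```python
from typing import Annotated

from pydantic import BaseModel, Field, model_validator


class SubConfig(BaseModel):
    metrics: list[str] = []

    @model_validator(mode="after")
    def warn_if_empty(self) -> "SubConfig":
        # Runs every time a SubConfig is constructed.
        if not self.metrics:
            print("warning: no metrics specified")
        return self


class EagerConfig(BaseModel):
    sub: SubConfig = SubConfig()  # validator fires here, at import time


class LazyConfig(BaseModel):
    # The factory runs per instance, so the validator fires only when a
    # LazyConfig is actually created.
    sub: Annotated[SubConfig, Field(default_factory=SubConfig)]


a, b = LazyConfig(), LazyConfig()
assert a.sub is not b.sub  # every instance gets its own default sub-config
```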
