From d0740d09f4cde19aaba8c3541c852a59a502e80e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Martin=20Kozlovsk=C3=BD?=
Date: Mon, 15 Apr 2024 20:22:14 +0200
Subject: [PATCH] SIGTERM Handling (#21)

* handling SIGTERM signal

* resume argument takes path
---
 luxonis_train/__main__.py                     | 10 ++++--
 .../callbacks/luxonis_progress_bar.py         |  2 +-
 luxonis_train/core/trainer.py                 | 36 ++++++++++++++++++-
 3 files changed, 44 insertions(+), 4 deletions(-)

diff --git a/luxonis_train/__main__.py b/luxonis_train/__main__.py
index b1fd3971..94276b60 100644
--- a/luxonis_train/__main__.py
+++ b/luxonis_train/__main__.py
@@ -45,11 +45,17 @@ def __str__(self):
 
 
 @app.command()
-def train(config: ConfigType = None, opts: OptsType = None):
+def train(
+    config: ConfigType = None,
+    resume: Annotated[
+        Optional[str], typer.Option(help="Resume training from this checkpoint.")
+    ] = None,
+    opts: OptsType = None,
+):
     """Start training."""
     from luxonis_train.core import Trainer
 
-    Trainer(str(config), opts).train()
+    Trainer(str(config), opts, resume=resume).train()
 
 
 @app.command()
diff --git a/luxonis_train/callbacks/luxonis_progress_bar.py b/luxonis_train/callbacks/luxonis_progress_bar.py
index fcc130cd..16d173e7 100644
--- a/luxonis_train/callbacks/luxonis_progress_bar.py
+++ b/luxonis_train/callbacks/luxonis_progress_bar.py
@@ -28,7 +28,7 @@ def get_metrics(
     ) -> dict[str, int | str | float | dict[str, float]]:
         # NOTE: there might be a cleaner way of doing this
         items = super().get_metrics(trainer, pl_module)
-        if trainer.training:
+        if trainer.training and pl_module.training_step_outputs:
             items["Loss"] = pl_module.training_step_outputs[-1]["loss"].item()
         return items
 
diff --git a/luxonis_train/core/trainer.py b/luxonis_train/core/trainer.py
index 2b3d6a78..8326ce48 100644
--- a/luxonis_train/core/trainer.py
+++ b/luxonis_train/core/trainer.py
@@ -1,3 +1,5 @@
+import os.path as osp
+import signal
 import threading
 from logging import getLogger
 from typing import Any, Literal
@@ -21,6 +23,7 @@ def __init__(
         self,
         cfg: str | dict[str, Any] | Config,
         opts: list[str] | tuple[str, ...] | dict[str, Any] | None = None,
+        resume: str | None = None,
     ):
         """Constructs a new Trainer instance.
 
@@ -30,9 +33,17 @@ def __init__(
         @type opts: list[str] | tuple[str, ...] | dict[str, Any] | None
         @param opts: Argument dict provided through command line, used for
             config overriding.
+
+        @type resume: str | None
+        @param resume: Training will resume from this checkpoint.
""" super().__init__(cfg, opts) + if resume is not None: + self.resume = str(LuxonisFileSystem.download(resume, self.run_save_dir)) + else: + self.resume = None + self.lightning_module = LuxonisModel( cfg=self.cfg, dataset_metadata=self.dataset_metadata, @@ -40,6 +51,29 @@ def __init__( input_shape=self.loader_train.input_shape, ) + def graceful_exit(signum, frame): + logger.info("SIGTERM received, stopping training...") + ckpt_path = osp.join(self.run_save_dir, "resume.ckpt") + self.pl_trainer.save_checkpoint(ckpt_path) + self._upload_logs() + + if self.cfg.tracker.is_mlflow: + logger.info("Uploading checkpoint to MLFlow.") + fs = LuxonisFileSystem( + "mlflow://", + allow_active_mlflow_run=True, + allow_local=False, + ) + fs.put_file( + local_path=ckpt_path, + remote_path="resume.ckpt", + mlflow_instance=self.tracker.experiment.get("mlflow", None), + ) + + exit(0) + + signal.signal(signal.SIGTERM, graceful_exit) + def _upload_logs(self) -> None: if self.cfg.tracker.is_mlflow: logger.info("Uploading logs to MLFlow.") @@ -56,7 +90,7 @@ def _upload_logs(self) -> None: def _trainer_fit(self, *args, **kwargs): try: - self.pl_trainer.fit(*args, **kwargs) + self.pl_trainer.fit(*args, ckpt_path=self.resume, **kwargs) except Exception: logger.exception("Encountered exception during training.") finally: