SIGTERM Handling (#21)
* handle the SIGTERM signal

* the resume argument takes a checkpoint path
kozlov721 committed Oct 9, 2024
1 parent 9a72eb2 commit d0740d0
Showing 3 changed files with 44 additions and 4 deletions.
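
The diffs below thread a new resume option from the CLI into the Trainer and register a SIGTERM handler that saves a resume.ckpt checkpoint before exiting. For orientation, here is a minimal sketch of the resulting flow when driven from Python rather than the command line; the config and checkpoint paths are placeholders, not files from this repository:

    # Hedged sketch: resume training from a checkpoint via the new API.
    # "configs/model.yaml" and "path/to/resume.ckpt" are hypothetical paths.
    from luxonis_train.core import Trainer

    trainer = Trainer(
        "configs/model.yaml",          # training config (placeholder)
        opts=None,                     # no command-line overrides
        resume="path/to/resume.ckpt",  # e.g. the resume.ckpt saved on SIGTERM
    )
    trainer.train()

The equivalent CLI call should be along the lines of `python -m luxonis_train train --config configs/model.yaml --resume path/to/resume.ckpt`.
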
10 changes: 8 additions & 2 deletions luxonis_train/__main__.py
@@ -45,11 +45,17 @@ def __str__(self):
 
 
 @app.command()
-def train(config: ConfigType = None, opts: OptsType = None):
+def train(
+    config: ConfigType = None,
+    resume: Annotated[
+        Optional[str], typer.Option(help="Resume training from this checkpoint.")
+    ] = None,
+    opts: OptsType = None,
+):
     """Start training."""
     from luxonis_train.core import Trainer
 
-    Trainer(str(config), opts).train()
+    Trainer(str(config), opts, resume=resume).train()
 
 
 @app.command()
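
The new signature relies on typer's Annotated-style options. Below is a self-contained sketch of the same pattern, runnable on its own with any typer version recent enough to support Annotated parameters; the file and command here are illustrative, not part of the repository:

    # Illustrative sketch of the Annotated/typer.Option pattern used above.
    from typing import Annotated, Optional

    import typer

    app = typer.Typer()

    @app.command()
    def train(
        resume: Annotated[
            Optional[str], typer.Option(help="Resume training from this checkpoint.")
        ] = None,
    ):
        """Start training."""
        print(f"resume = {resume}")

    if __name__ == "__main__":
        app()

Run as `python sketch.py --resume path/to/ckpt`; omitting the option leaves resume as None, which the trainer treats as a fresh run.
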
2 changes: 1 addition & 1 deletion luxonis_train/callbacks/luxonis_progress_bar.py
@@ -28,7 +28,7 @@ def get_metrics(
     ) -> dict[str, int | str | float | dict[str, float]]:
         # NOTE: there might be a cleaner way of doing this
         items = super().get_metrics(trainer, pl_module)
-        if trainer.training:
+        if trainer.training and pl_module.training_step_outputs:
             items["Loss"] = pl_module.training_step_outputs[-1]["loss"].item()
         return items
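
The added guard matters because training_step_outputs[-1] raises IndexError when the list is empty, which presumably can happen when the progress bar refreshes before the first training step has appended an output. A tiny illustrative sketch of the failure mode and the fix:

    # Illustrative only: why the truthiness check prevents a crash.
    training_step_outputs: list[dict] = []

    # Unguarded access raises IndexError on an empty list:
    # training_step_outputs[-1]["loss"]

    # The guarded form short-circuits safely:
    if training_step_outputs:
        print(training_step_outputs[-1]["loss"])
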
36 changes: 35 additions & 1 deletion luxonis_train/core/trainer.py
@@ -1,3 +1,5 @@
+import os.path as osp
+import signal
 import threading
 from logging import getLogger
 from typing import Any, Literal
@@ -21,6 +23,7 @@ def __init__(
         self,
         cfg: str | dict[str, Any] | Config,
         opts: list[str] | tuple[str, ...] | dict[str, Any] | None = None,
+        resume: str | None = None,
     ):
         """Constructs a new Trainer instance.
@@ -30,16 +33,47 @@ def __init__(
         @type opts: list[str] | tuple[str, ...] | dict[str, Any] | None
         @param opts: Argument dict provided through command line,
             used for config overriding.
+        @type resume: str | None
+        @param resume: Training will resume from this checkpoint.
         """
         super().__init__(cfg, opts)
 
+        if resume is not None:
+            self.resume = str(LuxonisFileSystem.download(resume, self.run_save_dir))
+        else:
+            self.resume = None
+
         self.lightning_module = LuxonisModel(
             cfg=self.cfg,
             dataset_metadata=self.dataset_metadata,
             save_dir=self.run_save_dir,
             input_shape=self.loader_train.input_shape,
         )
 
+        def graceful_exit(signum, frame):
+            logger.info("SIGTERM received, stopping training...")
+            ckpt_path = osp.join(self.run_save_dir, "resume.ckpt")
+            self.pl_trainer.save_checkpoint(ckpt_path)
+            self._upload_logs()
+
+            if self.cfg.tracker.is_mlflow:
+                logger.info("Uploading checkpoint to MLFlow.")
+                fs = LuxonisFileSystem(
+                    "mlflow://",
+                    allow_active_mlflow_run=True,
+                    allow_local=False,
+                )
+                fs.put_file(
+                    local_path=ckpt_path,
+                    remote_path="resume.ckpt",
+                    mlflow_instance=self.tracker.experiment.get("mlflow", None),
+                )
+
+            exit(0)
+
+        signal.signal(signal.SIGTERM, graceful_exit)
+
     def _upload_logs(self) -> None:
         if self.cfg.tracker.is_mlflow:
             logger.info("Uploading logs to MLFlow.")
@@ -56,7 +90,7 @@ def _upload_logs(self) -> None:
 
     def _trainer_fit(self, *args, **kwargs):
         try:
-            self.pl_trainer.fit(*args, **kwargs)
+            self.pl_trainer.fit(*args, ckpt_path=self.resume, **kwargs)
         except Exception:
             logger.exception("Encountered exception during training.")
         finally:
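
The handler itself is plain standard-library signal handling: register a callback for SIGTERM (the signal that schedulers and orchestrators commonly send before a hard kill), persist enough state to resume, then exit. A self-contained sketch of the same shape, with save_state standing in for pl_trainer.save_checkpoint:

    # Hedged sketch of the graceful-exit pattern added in this commit.
    # save_state is a stand-in for pl_trainer.save_checkpoint; no Lightning here.
    import os
    import signal
    import sys
    import time

    def save_state(path: str) -> None:
        # Placeholder for serializing real training state.
        with open(path, "w") as f:
            f.write("checkpoint placeholder")

    def graceful_exit(signum, frame):
        print("SIGTERM received, saving state...")
        save_state("resume.ckpt")
        sys.exit(0)

    signal.signal(signal.SIGTERM, graceful_exit)

    if __name__ == "__main__":
        print(f"PID {os.getpid()}: send `kill -TERM {os.getpid()}` from another shell")
        while True:
            time.sleep(1)  # stand-in for the training loop

One caveat: CPython only delivers signals on the main thread, so the registration in __init__ works as long as the Trainer is constructed there.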
