diff --git a/luxonis_train/callbacks/export_on_train_end.py b/luxonis_train/callbacks/export_on_train_end.py
index 923267c1..5d7bf6da 100644
--- a/luxonis_train/callbacks/export_on_train_end.py
+++ b/luxonis_train/callbacks/export_on_train_end.py
@@ -8,6 +8,8 @@
 from luxonis_train.utils.registry import CALLBACKS
 from luxonis_train.utils.tracker import LuxonisTrackerPL
 
+logger = logging.getLogger(__name__)
+
 
 @CALLBACKS.register_module()
 class ExportOnTrainEnd(pl.Callback):
@@ -41,11 +43,13 @@ def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
         # NOTE: assume that first checkpoint callback is based on val loss
         best_model_path = model_checkpoint_callbacks[0].best_model_path
         if not best_model_path:
-            raise RuntimeError(
-                "No best model path found. "
-                "Please make sure that ModelCheckpoint callback is present "
-                "and at least one validation epoch has been performed."
+            logger.error(
+                "No model checkpoint found. "
+                "Make sure that `ModelCheckpoint` callback is present "
+                "and at least one validation epoch has been performed. "
+                "Skipping model export."
             )
+            return
         cfg: Config = pl_module.cfg
         cfg.model.weights = best_model_path
         if self.upload_to_mlflow:
@@ -54,9 +58,9 @@ def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
                 new_upload_url = f"mlflow://{tracker.project_id}/{tracker.run_id}"
                 cfg.exporter.upload_url = new_upload_url
             else:
-                logging.getLogger(__name__).warning(
+                logger.error(
                     "`upload_to_mlflow` is set to True, "
-                    "but there is no MLFlow active run, skipping."
+                    "but there is no MLFlow active run, skipping."
                 )
         exporter = Exporter(cfg=cfg)
         onnx_path = str(Path(best_model_path).parent.with_suffix(".onnx"))
diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py
index 75bd1d2a..86b63600 100644
--- a/luxonis_train/core/core.py
+++ b/luxonis_train/core/core.py
@@ -79,12 +79,14 @@ def __init__(
         self.run_save_dir = os.path.join(
             self.cfg.tracker.save_directory, self.tracker.run_name
         )
+        self.log_file = osp.join(self.run_save_dir, "luxonis_train.log")
+
         # NOTE: to add the file handler (we only get the save dir now,
         # but we want to use the logger before)
         reset_logging()
         setup_logging(
             use_rich=self.cfg.use_rich_text,
-            file=osp.join(self.run_save_dir, "luxonis_train.log"),
+            file=self.log_file,
         )
 
         # NOTE: overriding logger in pl so it uses our logger to log device info
diff --git a/luxonis_train/core/trainer.py b/luxonis_train/core/trainer.py
index cb2c5a2c..2b3d6a78 100644
--- a/luxonis_train/core/trainer.py
+++ b/luxonis_train/core/trainer.py
@@ -3,6 +3,7 @@
 from typing import Any, Literal
 
 from lightning.pytorch.utilities import rank_zero_only  # type: ignore
+from luxonis_ml.utils import LuxonisFileSystem
 
 from luxonis_train.models import LuxonisModel
 from luxonis_train.utils.config import Config
@@ -39,6 +40,28 @@ def __init__(
             input_shape=self.loader_train.input_shape,
         )
 
+    def _upload_logs(self) -> None:
+        if self.cfg.tracker.is_mlflow:
+            logger.info("Uploading logs to MLFlow.")
+            fs = LuxonisFileSystem(
+                "mlflow://",
+                allow_active_mlflow_run=True,
+                allow_local=False,
+            )
+            fs.put_file(
+                local_path=self.log_file,
+                remote_path="luxonis_train.log",
+                mlflow_instance=self.tracker.experiment.get("mlflow", None),
+            )
+
+    def _trainer_fit(self, *args, **kwargs):
+        try:
+            self.pl_trainer.fit(*args, **kwargs)
+        except Exception:
+            logger.exception("Encountered exception during training.")
+        finally:
+            self._upload_logs()
+
     def train(self, new_thread: bool = False) -> None:
         """Runs training.
 
@@ -48,13 +71,14 @@ def train(self, new_thread: bool = False) -> None:
         if not new_thread:
             logger.info(f"Checkpoints will be saved in: {self.get_save_dir()}")
             logger.info("Starting training...")
-            self.pl_trainer.fit(
+            self._trainer_fit(
                 self.lightning_module,
                 self.pytorch_loader_train,
                 self.pytorch_loader_val,
             )
             logger.info("Training finished")
             logger.info(f"Checkpoints saved in: {self.get_save_dir()}")
+
         else:
             # Every time exception happens in the Thread, this hook will activate
             def thread_exception_hook(args):
@@ -63,7 +87,7 @@ def thread_exception_hook(args):
         threading.excepthook = thread_exception_hook
 
         self.thread = threading.Thread(
-            target=self.pl_trainer.fit,
+            target=self._trainer_fit,
             args=(
                 self.lightning_module,
                 self.pytorch_loader_train,
diff --git a/luxonis_train/models/luxonis_model.py b/luxonis_train/models/luxonis_model.py
index 88d4fa28..7cd396f9 100644
--- a/luxonis_train/models/luxonis_model.py
+++ b/luxonis_train/models/luxonis_model.py
@@ -681,7 +681,9 @@ def load_checkpoint(self, path: str | None) -> None:
         """
         if path is None:
             return
+
        checkpoint = torch.load(path, map_location=self.device)
+
         if "state_dict" not in checkpoint:
             raise ValueError("Checkpoint does not contain state_dict.")
         state_dict = {}
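Note on the `_trainer_fit` wrapper introduced above: the try/except/finally shape is what guarantees the log upload on every outcome, since `finally` runs whether `fit` returns normally, raises, or is interrupted. A minimal standalone sketch of the same pattern (the `fit_fn`/`upload_fn` names are illustrative only, not part of this diff):

```python
import logging

logger = logging.getLogger(__name__)


def fit_with_log_upload(fit_fn, upload_fn, *args, **kwargs) -> None:
    """Run training and always upload logs afterwards.

    Mirrors `_trainer_fit` above: the exception is logged and swallowed
    rather than re-raised, and `upload_fn` runs in `finally`, i.e. on
    success, on failure, and even on KeyboardInterrupt (which
    `except Exception` does not catch but `finally` still sees).
    """
    try:
        fit_fn(*args, **kwargs)
    except Exception:
        logger.exception("Encountered exception during training.")
    finally:
        upload_fn()
```

One design consequence worth flagging: because the exception is swallowed, `train()` returns normally after a failed run. Callers that need to distinguish success from failure have to inspect the logs, or the wrapper would need to re-raise after `_upload_logs()`.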