From 2c654a51b887aab0dda8449fe377154f09bd6b27 Mon Sep 17 00:00:00 2001
From: KlemenSkrlj <47853619+klemen1999@users.noreply.github.com>
Date: Tue, 20 Feb 2024 03:44:24 +0100
Subject: [PATCH] MLFlow Upload Fix (#10)

* fixed incorrect class property call

* fixed exporter uploading

* UploadCheckpoint uploads on every checkpoint epoch

* fixed temp file names

* updated callback readme

* pre-commit run
---
 luxonis_train/callbacks/README.md             |  9 +++
 luxonis_train/callbacks/__init__.py           |  4 +-
 .../callbacks/export_on_train_end.py          |  4 +-
 luxonis_train/callbacks/upload_checkpoint.py  | 61 +++++++++++++++++++
 .../upload_checkpoint_on_train_end.py         | 41 -------------
 luxonis_train/core/exporter.py                |  6 +-
 6 files changed, 78 insertions(+), 47 deletions(-)
 create mode 100644 luxonis_train/callbacks/upload_checkpoint.py
 delete mode 100644 luxonis_train/callbacks/upload_checkpoint_on_train_end.py

diff --git a/luxonis_train/callbacks/README.md b/luxonis_train/callbacks/README.md
index d8e3da74..be441017 100644
--- a/luxonis_train/callbacks/README.md
+++ b/luxonis_train/callbacks/README.md
@@ -9,6 +9,7 @@ List of all supported callbacks.
 - [LuxonisProgressBar](#luxonisprogressbar)
 - [MetadataLogger](#metadatalogger)
 - [TestOnTrainEnd](#testontrainend)
+- [UploadCheckpoint](#uploadcheckpoint)
 
 ## PytorchLightning Callbacks
 
@@ -51,3 +52,11 @@ Metadata include all defined hyperparameters together with git hashes of `luxoni
 ## TestOnTrainEnd
 
 Callback to perform a test run at the end of the training.
+
+## UploadCheckpoint
+
+Callback that uploads the currently best checkpoint (based on validation loss) to the specified cloud directory after every validation epoch.
+
+| Key              | Type | Default value | Description                                                                                                              |
+| ---------------- | ---- | ------------- | ------------------------------------------------------------------------------------------------------------------------ |
+| upload_directory | str  | /             | Path to the cloud directory where checkpoints should be uploaded. To use the current MLFlow run, set it to `mlflow://`. |
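
For reference, a minimal usage sketch (not part of this patch) wiring `UploadCheckpoint` into a Lightning `Trainer` directly. The `model` and `datamodule` objects and the `val/loss` monitor key are placeholders, and the `mlflow://` scheme requires an active MLFlow run:

```python
# Hypothetical usage sketch -- names below are placeholders, not from this patch.
import lightning.pytorch as pl

from luxonis_train.callbacks import UploadCheckpoint

trainer = pl.Trainer(
    max_epochs=10,
    callbacks=[
        # Must come first: UploadCheckpoint reads best_model_path from the
        # first ModelCheckpoint callback it finds.
        pl.callbacks.ModelCheckpoint(monitor="val/loss", mode="min"),
        UploadCheckpoint(upload_directory="mlflow://"),
    ],
)
trainer.fit(model, datamodule=datamodule)  # `model`/`datamodule` assumed to exist
```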
diff --git a/luxonis_train/callbacks/__init__.py b/luxonis_train/callbacks/__init__.py
index 4be94600..cec9e000 100644
--- a/luxonis_train/callbacks/__init__.py
+++ b/luxonis_train/callbacks/__init__.py
@@ -13,7 +13,7 @@
 from .metadata_logger import MetadataLogger
 from .module_freezer import ModuleFreezer
 from .test_on_train_end import TestOnTrainEnd
-from .upload_checkpoint_on_train_end import UploadCheckpointOnTrainEnd
+from .upload_checkpoint import UploadCheckpoint
 
 CALLBACKS.register_module(module=EarlyStopping)
 CALLBACKS.register_module(module=LearningRateMonitor)
@@ -28,5 +28,5 @@
     "MetadataLogger",
     "ModuleFreezer",
     "TestOnTrainEnd",
-    "UploadCheckpointOnTrainEnd",
+    "UploadCheckpoint",
 ]
diff --git a/luxonis_train/callbacks/export_on_train_end.py b/luxonis_train/callbacks/export_on_train_end.py
index de5fde88..923267c1 100644
--- a/luxonis_train/callbacks/export_on_train_end.py
+++ b/luxonis_train/callbacks/export_on_train_end.py
@@ -51,8 +51,8 @@ def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> No
         if self.upload_to_mlflow:
             if cfg.tracker.is_mlflow:
                 tracker = cast(LuxonisTrackerPL, trainer.logger)
-                new_upload_directory = f"mlflow://{tracker.project_id}/{tracker.run_id}"
-                cfg.exporter.upload_directory = new_upload_directory
+                new_upload_url = f"mlflow://{tracker.project_id}/{tracker.run_id}"
+                cfg.exporter.upload_url = new_upload_url
             else:
                 logging.getLogger(__name__).warning(
                     "`upload_to_mlflow` is set to True, "
diff --git a/luxonis_train/callbacks/upload_checkpoint.py b/luxonis_train/callbacks/upload_checkpoint.py
new file mode 100644
index 00000000..a0fa137a
--- /dev/null
+++ b/luxonis_train/callbacks/upload_checkpoint.py
@@ -0,0 +1,61 @@
+import logging
+import os
+from typing import Any
+
+import lightning.pytorch as pl
+import torch
+from luxonis_ml.utils.filesystem import LuxonisFileSystem
+
+from luxonis_train.utils.registry import CALLBACKS
+
+
+@CALLBACKS.register_module()
+class UploadCheckpoint(pl.Callback):
+    """Callback that uploads best checkpoint based on the validation loss."""
+
+    def __init__(self, upload_directory: str):
+        """Constructs `UploadCheckpoint`.
+
+        @type upload_directory: str
+        @param upload_directory: Path used as upload directory
+        """
+        super().__init__()
+        self.fs = LuxonisFileSystem(
+            upload_directory, allow_active_mlflow_run=True, allow_local=False
+        )
+        self.logger = logging.getLogger(__name__)
+        self.last_logged_epoch = None
+        self.last_best_checkpoint = None
+
+    def on_save_checkpoint(
+        self,
+        trainer: pl.Trainer,
+        pl_module: pl.LightningModule,
+        checkpoint: dict[str, Any],
+    ) -> None:
+        # Upload only once per epoch in case there are multiple ModelCheckpoint callbacks
+        if self.last_logged_epoch != trainer.current_epoch:
+            model_checkpoint_callbacks = [
+                c
+                for c in trainer.callbacks  # type: ignore
+                if isinstance(c, pl.callbacks.ModelCheckpoint)  # type: ignore
+            ]
+            # NOTE: assumes that the first checkpoint callback is based on val loss
+            curr_best_checkpoint = model_checkpoint_callbacks[0].best_model_path
+
+            if self.last_best_checkpoint != curr_best_checkpoint:
+                self.logger.info(f"Started checkpoint upload to {self.fs.full_path}...")
+                temp_filename = "curr_best_val_loss.ckpt"
+                torch.save(checkpoint, temp_filename)
+                self.fs.put_file(
+                    local_path=temp_filename,
+                    remote_path=temp_filename,
+                    mlflow_instance=trainer.logger.experiment.get(  # type: ignore
+                        "mlflow", None
+                    ),
+                )
+                os.remove(temp_filename)
+                self.logger.info("Checkpoint upload finished")
+                self.last_best_checkpoint = curr_best_checkpoint
+
+            self.last_logged_epoch = trainer.current_epoch
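
The guard above deduplicates on two levels: at most one attempt per epoch (several `ModelCheckpoint` callbacks may save in the same epoch), and an actual upload only when the best-checkpoint path has changed. A standalone distillation of that logic, for illustration only:

```python
# Standalone distillation of the dedup guard in UploadCheckpoint (illustrative).
class DedupGuard:
    def __init__(self) -> None:
        self.last_logged_epoch: int | None = None
        self.last_best_checkpoint: str | None = None

    def should_upload(self, epoch: int, best_path: str) -> bool:
        if self.last_logged_epoch == epoch:
            return False  # epoch already handled (multiple ModelCheckpoint callbacks)
        self.last_logged_epoch = epoch
        if self.last_best_checkpoint == best_path:
            return False  # best checkpoint unchanged since the last upload
        self.last_best_checkpoint = best_path
        return True


guard = DedupGuard()
assert guard.should_upload(0, "epoch=0.ckpt")      # new best -> upload
assert not guard.should_upload(0, "epoch=0.ckpt")  # same epoch -> skip
assert not guard.should_upload(1, "epoch=0.ckpt")  # best unchanged -> skip
assert guard.should_upload(2, "epoch=2.ckpt")      # new best -> upload
```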
diff --git a/luxonis_train/callbacks/upload_checkpoint_on_train_end.py b/luxonis_train/callbacks/upload_checkpoint_on_train_end.py
deleted file mode 100644
index 86879ec9..00000000
--- a/luxonis_train/callbacks/upload_checkpoint_on_train_end.py
+++ /dev/null
@@ -1,41 +0,0 @@
-import logging
-
-import lightning.pytorch as pl
-from luxonis_ml.utils.filesystem import LuxonisFileSystem
-
-from luxonis_train.utils.registry import CALLBACKS
-
-
-@CALLBACKS.register_module()
-class UploadCheckpointOnTrainEnd(pl.Callback):
-    """Callback that uploads best checkpoint based on the validation loss."""
-
-    def __init__(self, upload_directory: str):
-        """Constructs `UploadCheckpointOnTrainEnd`.
-
-        @type upload_directory: str
-        @param upload_directory: Path used as upload directory
-        """
-        super().__init__()
-        self.fs = LuxonisFileSystem(
-            upload_directory, allow_active_mlflow_run=True, allow_local=False
-        )
-
-    def on_train_end(self, trainer: pl.Trainer, _: pl.LightningModule) -> None:
-        logger = logging.getLogger(__name__)
-        logger.info(f"Started checkpoint upload to {self.fs.full_path()}...")
-        model_checkpoint_callbacks = [
-            c
-            for c in trainer.callbacks  # type: ignore
-            if isinstance(c, pl.callbacks.ModelCheckpoint)  # type: ignore
-        ]
-        # NOTE: assume that first checkpoint callback is based on val loss
-        local_path = model_checkpoint_callbacks[0].best_model_path
-        self.fs.put_file(
-            local_path=local_path,
-            remote_path=local_path.split("/")[-1],
-            mlflow_instance=trainer.logger.experiment.get(  # type: ignore
-                "mlflow", None
-            ),
-        )
-        logger.info("Checkpoint upload finished")
diff --git a/luxonis_train/core/exporter.py b/luxonis_train/core/exporter.py
index ab73ce72..7ed94f45 100644
--- a/luxonis_train/core/exporter.py
+++ b/luxonis_train/core/exporter.py
@@ -200,7 +200,7 @@ def _upload(self, files_to_upload: list[str]):
                 remote_path=self.cfg.exporter.export_model_name + suffix,
             )
 
-        with tempfile.TemporaryFile() as f:
+        with tempfile.NamedTemporaryFile(prefix="config", suffix=".yaml") as f:
             self.cfg.save_data(f.name)
             fs.put_file(local_path=f.name, remote_path="config.yaml")
 
@@ -209,7 +209,9 @@ def _upload(self, files_to_upload: list[str]):
         )
         modelconverter_config = self._get_modelconverter_config(onnx_path)
 
-        with tempfile.TemporaryFile() as f:
+        with tempfile.NamedTemporaryFile(
+            prefix="config_export", suffix=".yaml", mode="w+"
+        ) as f:
             yaml.dump(modelconverter_config, f, default_flow_style=False)
             fs.put_file(local_path=f.name, remote_path="config_export.yaml")
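
The `tempfile` change above is the substance of the "fixed temp file names" commit: `tempfile.TemporaryFile` creates an anonymous file, so on POSIX `f.name` is just the integer file descriptor and cannot be passed to path-based APIs such as `put_file`, while `NamedTemporaryFile` keeps a real path on disk. A small sketch of the fixed pattern (illustrative, not verbatim from the patch); one caveat worth noting is that buffered writes should be flushed before the file is re-read by name:

```python
# Illustrative sketch; `modelconverter_config` is placeholder data.
import tempfile

import yaml

modelconverter_config = {"input_model": "model.onnx"}

with tempfile.TemporaryFile() as f:
    print(f.name)  # on POSIX: an integer fd, unusable as a filesystem path

with tempfile.NamedTemporaryFile(prefix="config_export", suffix=".yaml", mode="w+") as f:
    yaml.dump(modelconverter_config, f, default_flow_style=False)
    f.flush()      # ensure buffered YAML is on disk before reading f.name by path
    print(f.name)  # a real path, e.g. /tmp/config_exportab12cd.yaml
    # fs.put_file(local_path=f.name, remote_path="config_export.yaml")
```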