Uploading logs to MLFlow (#16)
* upload logs to mlflow

* added mlflow instance

* multithread log upload

* fixed upload logs

* fixed log file path

* removed exceptions

* logging exceptions

* fixed typo

* reverted exception

* moved line

* replaced warning with error log

* Update trainer.py
kozlov721 committed Oct 9, 2024
1 parent be74f5c commit 27f9ab2
Showing 4 changed files with 41 additions and 9 deletions.
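
Taken together, the four files implement one pattern: the trainer remembers where its own log file lives, wraps `pl_trainer.fit` in `try`/`except`/`finally`, and pushes that log to the active MLFlow run whether training finishes or crashes. Below is a minimal, generic sketch of the same idea using the plain `mlflow` client (`mlflow.active_run` and `mlflow.log_artifact`) instead of the repository's `LuxonisFileSystem` wrapper; `run_training` and `train_one_epoch` are placeholder names, not part of this codebase.

import logging

import mlflow

logger = logging.getLogger(__name__)


def run_training(train_one_epoch, log_file: str, epochs: int = 10) -> None:
    # Run training inside try/finally so the log file is uploaded even on failure.
    try:
        for _ in range(epochs):
            train_one_epoch()
    except Exception:
        logger.exception("Encountered exception during training.")
    finally:
        if mlflow.active_run() is not None:
            # Attach the local log file as an artifact of the active MLFlow run.
            mlflow.log_artifact(log_file)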
luxonis_train/callbacks/export_on_train_end.py (16 changes: 10 additions & 6 deletions)
@@ -8,6 +8,8 @@
 from luxonis_train.utils.registry import CALLBACKS
 from luxonis_train.utils.tracker import LuxonisTrackerPL
 
+logger = logging.getLogger(__name__)
+
 
 @CALLBACKS.register_module()
 class ExportOnTrainEnd(pl.Callback):
@@ -41,11 +43,13 @@ def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
         # NOTE: assume that first checkpoint callback is based on val loss
         best_model_path = model_checkpoint_callbacks[0].best_model_path
         if not best_model_path:
-            raise RuntimeError(
-                "No best model path found. "
-                "Please make sure that ModelCheckpoint callback is present "
-                "and at least one validation epoch has been performed."
+            logger.error(
+                "No model checkpoint found. "
+                "Make sure that `ModelCheckpoint` callback is present "
+                "and at least one validation epoch has been performed. "
+                "Skipping model export."
             )
+            return
         cfg: Config = pl_module.cfg
         cfg.model.weights = best_model_path
         if self.upload_to_mlflow:
@@ -54,9 +58,9 @@ def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
                 new_upload_url = f"mlflow://{tracker.project_id}/{tracker.run_id}"
                 cfg.exporter.upload_url = new_upload_url
             else:
-                logging.getLogger(__name__).warning(
+                logger.error(
                     "`upload_to_mlflow` is set to True, "
-                    "but there is no MLFlow active run, skipping."
+                    "but there is no MLFlow active run, skipping."
                 )
         exporter = Exporter(cfg=cfg)
         onnx_path = str(Path(best_model_path).parent.with_suffix(".onnx"))
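
For context on the checkpoint lookup changed in the first hunk above: Lightning's `ModelCheckpoint.best_model_path` is an empty string until at least one checkpoint has been written, which is exactly the case the new `logger.error(...)` plus `return` now handles instead of raising. A simplified sketch of that lookup, illustrative only and not the file's exact code:

import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint


def find_best_checkpoint(trainer: pl.Trainer) -> str | None:
    # Collect every ModelCheckpoint callback registered on the trainer.
    checkpoint_callbacks = [
        c for c in trainer.callbacks if isinstance(c, ModelCheckpoint)
    ]
    if not checkpoint_callbacks:
        return None
    # best_model_path stays "" until a checkpoint is saved, e.g. before
    # the first validation epoch has completed.
    best = checkpoint_callbacks[0].best_model_path
    return best or None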
luxonis_train/core/core.py (4 changes: 3 additions & 1 deletion)
@@ -79,12 +79,14 @@ def __init__(
         self.run_save_dir = os.path.join(
             self.cfg.tracker.save_directory, self.tracker.run_name
         )
+        self.log_file = osp.join(self.run_save_dir, "luxonis_train.log")
+
         # NOTE: to add the file handler (we only get the save dir now,
         # but we want to use the logger before)
         reset_logging()
         setup_logging(
             use_rich=self.cfg.use_rich_text,
-            file=osp.join(self.run_save_dir, "luxonis_train.log"),
+            file=self.log_file,
         )
 
         # NOTE: overriding logger in pl so it uses our logger to log device info
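
The NOTE in this hunk describes the ordering problem the new `self.log_file` attribute solves: logging starts before the run directory exists, so once `run_save_dir` is known the handlers are rebuilt with a file target, and keeping the path around lets the trainer upload that same file later. The `reset_logging` and `setup_logging` helpers come from `luxonis_ml`; the sketch below shows the equivalent idea with the standard `logging` module only, as an assumption-level illustration.

import logging
import os


def attach_file_handler(log_file: str) -> None:
    # Re-point the root logger at a file once the save directory is known.
    root = logging.getLogger()
    for handler in list(root.handlers):
        root.removeHandler(handler)  # drop handlers configured before the path existed
    os.makedirs(os.path.dirname(log_file) or ".", exist_ok=True)
    root.addHandler(logging.FileHandler(log_file))
    root.addHandler(logging.StreamHandler())  # keep console output as well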
luxonis_train/core/trainer.py (28 changes: 26 additions & 2 deletions)
@@ -3,6 +3,7 @@
 from typing import Any, Literal
 
 from lightning.pytorch.utilities import rank_zero_only  # type: ignore
+from luxonis_ml.utils import LuxonisFileSystem
 
 from luxonis_train.models import LuxonisModel
 from luxonis_train.utils.config import Config
@@ -39,6 +40,28 @@ def __init__(
             input_shape=self.loader_train.input_shape,
         )
 
+    def _upload_logs(self) -> None:
+        if self.cfg.tracker.is_mlflow:
+            logger.info("Uploading logs to MLFlow.")
+            fs = LuxonisFileSystem(
+                "mlflow://",
+                allow_active_mlflow_run=True,
+                allow_local=False,
+            )
+            fs.put_file(
+                local_path=self.log_file,
+                remote_path="luxonis_train.log",
+                mlflow_instance=self.tracker.experiment.get("mlflow", None),
+            )
+
+    def _trainer_fit(self, *args, **kwargs):
+        try:
+            self.pl_trainer.fit(*args, **kwargs)
+        except Exception:
+            logger.exception("Encountered exception during training.")
+        finally:
+            self._upload_logs()
+
     def train(self, new_thread: bool = False) -> None:
         """Runs training.
@@ -48,13 +71,14 @@ def train(self, new_thread: bool = False) -> None:
         if not new_thread:
             logger.info(f"Checkpoints will be saved in: {self.get_save_dir()}")
             logger.info("Starting training...")
-            self.pl_trainer.fit(
+            self._trainer_fit(
                 self.lightning_module,
                 self.pytorch_loader_train,
                 self.pytorch_loader_val,
             )
             logger.info("Training finished")
             logger.info(f"Checkpoints saved in: {self.get_save_dir()}")
+
         else:
             # Every time exception happens in the Thread, this hook will activate
             def thread_exception_hook(args):
@@ -63,7 +87,7 @@ def thread_exception_hook(args):
             threading.excepthook = thread_exception_hook
 
             self.thread = threading.Thread(
-                target=self.pl_trainer.fit,
+                target=self._trainer_fit,
                 args=(
                     self.lightning_module,
                     self.pytorch_loader_train,
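
The `thread_exception_hook` context above is truncated by the diff view; it relies on `threading.excepthook` (Python 3.8+), which receives any exception raised inside a worker thread, while the change here only retargets the thread at `_trainer_fit`. A self-contained sketch of how such a hook surfaces background-thread failures; the hook body is illustrative, not the repository's implementation.

import logging
import threading

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def thread_exception_hook(args: threading.ExceptHookArgs) -> None:
    # Called for any unhandled exception raised inside a Thread.
    logger.error(
        "Exception in thread %s",
        args.thread.name if args.thread else "<unknown>",
        exc_info=(args.exc_type, args.exc_value, args.exc_traceback),
    )


threading.excepthook = thread_exception_hook

worker = threading.Thread(target=lambda: 1 / 0, name="training-thread")
worker.start()
worker.join()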
luxonis_train/models/luxonis_model.py (2 changes: 2 additions & 0 deletions)
@@ -681,7 +681,9 @@ def load_checkpoint(self, path: str | None) -> None:
         """
         if path is None:
             return
+
         checkpoint = torch.load(path, map_location=self.device)
+
         if "state_dict" not in checkpoint:
             raise ValueError("Checkpoint does not contain state_dict.")
         state_dict = {}
