Uploading logs to MLFlow #16

Merged
merged 12 commits on Feb 28, 2024
16 changes: 10 additions & 6 deletions luxonis_train/callbacks/export_on_train_end.py
@@ -8,6 +8,8 @@
 from luxonis_train.utils.registry import CALLBACKS
 from luxonis_train.utils.tracker import LuxonisTrackerPL
 
+logger = logging.getLogger(__name__)
+
 
 @CALLBACKS.register_module()
 class ExportOnTrainEnd(pl.Callback):
@@ -41,11 +43,13 @@ def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
         # NOTE: assume that first checkpoint callback is based on val loss
         best_model_path = model_checkpoint_callbacks[0].best_model_path
         if not best_model_path:
-            raise RuntimeError(
-                "No best model path found. "
-                "Please make sure that ModelCheckpoint callback is present "
-                "and at least one validation epoch has been performed."
+            logger.error(
+                "No model checkpoint found. "
+                "Make sure that `ModelCheckpoint` callback is present "
+                "and at least one validation epoch has been performed. "
+                "Skipping model export."
             )
+            return
         cfg: Config = pl_module.cfg
         cfg.model.weights = best_model_path
         if self.upload_to_mlflow:
@@ -54,9 +58,9 @@ def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
                 new_upload_url = f"mlflow://{tracker.project_id}/{tracker.run_id}"
                 cfg.exporter.upload_url = new_upload_url
             else:
-                logging.getLogger(__name__).warning(
+                logger.error(
                     "`upload_to_mlflow` is set to True, "
                     "but there is no MLFlow active run, skipping."
                 )
         exporter = Exporter(cfg=cfg)
         onnx_path = str(Path(best_model_path).parent.with_suffix(".onnx"))
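
The change above replaces a hard failure with a soft exit: when no checkpoint is available, the callback now logs an error and returns instead of raising, so the rest of the training teardown still runs. Below is a minimal standalone sketch of that "log and skip" control flow; export_if_possible and the example paths are illustrative and not part of the PR.

import logging

logger = logging.getLogger(__name__)

def export_if_possible(best_model_path: str | None) -> None:
    # Mirror of the callback's new flow: log an error and bail out
    # instead of raising when no checkpoint exists.
    if not best_model_path:
        logger.error("No model checkpoint found. Skipping model export.")
        return
    logger.info("Exporting model from %s", best_model_path)

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    export_if_possible(None)          # logs an error; shutdown continues
    export_if_possible("best.ckpt")   # proceeds to the (stubbed) export step
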
4 changes: 3 additions & 1 deletion luxonis_train/core/core.py
@@ -79,12 +79,14 @@ def __init__(
         self.run_save_dir = os.path.join(
             self.cfg.tracker.save_directory, self.tracker.run_name
         )
+        self.log_file = osp.join(self.run_save_dir, "luxonis_train.log")
+
         # NOTE: to add the file handler (we only get the save dir now,
         # but we want to use the logger before)
         reset_logging()
         setup_logging(
             use_rich=self.cfg.use_rich_text,
-            file=osp.join(self.run_save_dir, "luxonis_train.log"),
+            file=self.log_file,
         )
 
         # NOTE: overriding logger in pl so it uses our logger to log device info
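
Keeping the path in self.log_file (rather than recomputing it inline) is what lets the new Trainer._upload_logs find and upload the very same file later. A rough sketch of the idea using only the standard logging module follows; attach_file_handler is a made-up name, and the real code goes through luxonis_ml's reset_logging/setup_logging helpers instead.

import logging
import os
import os.path as osp

def attach_file_handler(run_save_dir: str) -> str:
    # Once the run directory is known, remember the log-file path and
    # attach a file handler so everything logged from here on is captured.
    os.makedirs(run_save_dir, exist_ok=True)
    log_file = osp.join(run_save_dir, "luxonis_train.log")
    handler = logging.FileHandler(log_file)
    handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
    logging.getLogger().addHandler(handler)
    return log_file  # stored by the caller and reused when uploading
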
28 changes: 26 additions & 2 deletions luxonis_train/core/trainer.py
@@ -3,6 +3,7 @@
 from typing import Any, Literal
 
 from lightning.pytorch.utilities import rank_zero_only  # type: ignore
+from luxonis_ml.utils import LuxonisFileSystem
 
 from luxonis_train.models import LuxonisModel
 from luxonis_train.utils.config import Config
@@ -39,6 +40,28 @@ def __init__(
             input_shape=self.loader_train.input_shape,
         )
 
+    def _upload_logs(self) -> None:
+        if self.cfg.tracker.is_mlflow:
+            logger.info("Uploading logs to MLFlow.")
+            self.fs = LuxonisFileSystem(
+                "mlflow://",
+                allow_active_mlflow_run=True,
+                allow_local=False,
+            )
+            self.fs.put_file(
+                local_path=self.log_file,
+                remote_path="luxonis_train.log",
+                mlflow_instance=self.tracker.experiment.get("mlflow", None),
+            )
+
+    def _trainer_fit(self, *args, **kwargs):
+        try:
+            self.pl_trainer.fit(*args, **kwargs)
+        except Exception:
+            logger.exception("Encountered exception during training.")
+        finally:
+            self._upload_logs()
+
     def train(self, new_thread: bool = False) -> None:
         """Runs training.
 
@@ -48,13 +71,14 @@ def train(self, new_thread: bool = False) -> None:
         if not new_thread:
             logger.info(f"Checkpoints will be saved in: {self.get_save_dir()}")
             logger.info("Starting training...")
-            self.pl_trainer.fit(
+            self._trainer_fit(
                 self.lightning_module,
                 self.pytorch_loader_train,
                 self.pytorch_loader_val,
             )
             logger.info("Training finished")
             logger.info(f"Checkpoints saved in: {self.get_save_dir()}")
+
         else:
             # Every time exception happens in the Thread, this hook will activate
             def thread_exception_hook(args):
@@ -63,7 +87,7 @@ def thread_exception_hook(args):
             threading.excepthook = thread_exception_hook
 
             self.thread = threading.Thread(
-                target=self.pl_trainer.fit,
+                target=self._trainer_fit,
                 args=(
                     self.lightning_module,
                     self.pytorch_loader_train,
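
The heart of the PR is that pl_trainer.fit is no longer called directly: _trainer_fit wraps it so any exception is caught and logged, and _upload_logs runs in a finally block, so the log file reaches MLFlow even when training crashes. Here is a self-contained sketch of that wrap-and-upload pattern; upload_file and failing_fit are stand-ins, and the actual upload goes through LuxonisFileSystem.put_file as shown in the diff.

import logging

logger = logging.getLogger(__name__)

def upload_file(local_path: str, remote_path: str) -> None:
    # Stand-in for LuxonisFileSystem.put_file; here it only logs the transfer.
    logger.info("Uploading %s -> %s", local_path, remote_path)

def fit_and_upload(fit, *args, **kwargs) -> None:
    # Same shape as _trainer_fit: log the exception instead of letting it
    # propagate, then upload the log file no matter how training ended.
    try:
        fit(*args, **kwargs)
    except Exception:
        logger.exception("Encountered exception during training.")
    finally:
        upload_file("luxonis_train.log", "luxonis_train.log")

def failing_fit() -> None:
    raise RuntimeError("boom")

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    fit_and_upload(failing_fit)  # the exception is logged and the upload still runs
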
2 changes: 2 additions & 0 deletions luxonis_train/models/luxonis_model.py
@@ -681,7 +681,9 @@ def load_checkpoint(self, path: str | None) -> None:
         """
         if path is None:
             return
+
         checkpoint = torch.load(path, map_location=self.device)
+
         if "state_dict" not in checkpoint:
             raise ValueError("Checkpoint does not contain state_dict.")
         state_dict = {}
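
For reference, the load_checkpoint logic touched above boils down to loading the file with torch.load and validating that it actually carries a state dict before using it. A compact sketch of that check follows; load_state_dict is an illustrative helper, not the repository's API.

import torch

def load_state_dict(path: str, device: str = "cpu") -> dict:
    # Load the checkpoint onto the requested device and fail early if it
    # does not contain a "state_dict" entry.
    checkpoint = torch.load(path, map_location=device)
    if "state_dict" not in checkpoint:
        raise ValueError("Checkpoint does not contain state_dict.")
    return checkpoint["state_dict"]
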