-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Generate NN archive from training configs (#17)
* add archiver CLI * add archiver callback * add max_det parameter to EfficientBBoxHead * add enum to categorize tasks for the implemented heads * add archiver tests * adjust Archiver to new nn archive format * pre-comit formatting * add LDF creation and adjust to new nn archive format * update requirements.txt * add opencv-python to requirements.txt * add support for ImplicitKeypointBBoxHead * remove support for ObjectDetectionSSD * Update requirements.txt * Added mlflow and removed opencv * [Automated] Updated coverage badge * add support for SegmentationHead and BiSeNetHead * base archiver tests on model from luxonis-train instead of torchvision * adjust head parameters to changes in NN Archive * adjust keypoint detection head parameters to changes in NN Archive * bugfix - make sure self.max_det is used in nms * add max_det parameter to ImplicitKeypointBBoxHead * adjust task categorization for ImplicitKeypointBBoxHead * fixing Windows PermissionError occuring on file deletion * fixing Windows PermissionError occuring on file deletion due to unreleased logging handlers * add method to remove file handlers keeping the log file open * add a logging statement at the end of archiving * add optuna_integration to requirements.txt * add hard-coded solution to determining is_softmax parameter * added help --------- Co-authored-by: Martin Kozlovský <[email protected]> Co-authored-by: GitHub Actions <[email protected]>
- Loading branch information
1 parent
f42192c
commit e1ab39b
Showing
14 changed files
with
668 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import logging | ||
import os | ||
from pathlib import Path | ||
from typing import cast | ||
|
||
import lightning.pytorch as pl | ||
|
||
from luxonis_train.utils.config import Config | ||
from luxonis_train.utils.registry import CALLBACKS | ||
from luxonis_train.utils.tracker import LuxonisTrackerPL | ||
|
||
|
||
@CALLBACKS.register_module() | ||
class ArchiveOnTrainEnd(pl.Callback): | ||
def __init__(self, upload_to_mlflow: bool = False): | ||
"""Callback that performs archiving of onnx or exported model at the end of | ||
training/export. TODO: description. | ||
@type upload_to_mlflow: bool | ||
@param upload_to_mlflow: If set to True, overrides the upload url in Archiver | ||
with currently active MLFlow run (if present). | ||
""" | ||
super().__init__() | ||
self.upload_to_mlflow = upload_to_mlflow | ||
|
||
def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: | ||
"""Archives the model on train end. | ||
@type trainer: L{pl.Trainer} | ||
@param trainer: Pytorch Lightning trainer. | ||
@type pl_module: L{pl.LightningModule} | ||
@param pl_module: Pytorch Lightning module. | ||
@raises RuntimeError: If no best model path is found. | ||
""" | ||
from luxonis_train.core.archiver import Archiver | ||
|
||
model_checkpoint_callbacks = [ | ||
c | ||
for c in trainer.callbacks # type: ignore | ||
if isinstance(c, pl.callbacks.ModelCheckpoint) # type: ignore | ||
] | ||
|
||
# NOTE: assume that first checkpoint callback is based on val loss | ||
best_model_path = model_checkpoint_callbacks[0].best_model_path | ||
if not best_model_path: | ||
raise RuntimeError( | ||
"No best model path found. " | ||
"Please make sure that ModelCheckpoint callback is present " | ||
"and at least one validation epoch has been performed." | ||
) | ||
cfg: Config = pl_module.cfg | ||
cfg.model.weights = best_model_path | ||
if self.upload_to_mlflow: | ||
if cfg.tracker.is_mlflow: | ||
tracker = cast(LuxonisTrackerPL, trainer.logger) | ||
new_upload_url = f"mlflow://{tracker.project_id}/{tracker.run_id}" | ||
cfg.archiver.upload_url = new_upload_url | ||
else: | ||
logging.getLogger(__name__).warning( | ||
"`upload_to_mlflow` is set to True, " | ||
"but there is no MLFlow active run, skipping." | ||
) | ||
|
||
onnx_path = str(Path(best_model_path).parent.with_suffix(".onnx")) | ||
if not os.path.exists(onnx_path): | ||
raise FileNotFoundError( | ||
"Model executable not found. Make sure to run exporter callback before archiver callback" | ||
) | ||
|
||
archiver = Archiver(cfg=cfg) | ||
|
||
archiver.archive(onnx_path) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
from .archiver import Archiver | ||
from .exporter import Exporter | ||
from .inferer import Inferer | ||
from .trainer import Trainer | ||
from .tuner import Tuner | ||
|
||
__all__ = ["Exporter", "Trainer", "Tuner", "Inferer"] | ||
__all__ = ["Exporter", "Trainer", "Tuner", "Inferer", "Archiver"] |
Oops, something went wrong.