-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
51 changed files
with
1,374 additions
and
568 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
|
||
model: | ||
name: resnet50_classification | ||
nodes: | ||
- name: ResNet | ||
params: | ||
variant: "50" | ||
download_weights: True | ||
|
||
- name: ClassificationHead | ||
inputs: | ||
- ResNet | ||
|
||
losses: | ||
- name: CrossEntropyLoss | ||
attached_to: ClassificationHead | ||
|
||
metrics: | ||
- name: Accuracy | ||
is_main_metric: true | ||
attached_to: ClassificationHead | ||
|
||
visualizers: | ||
- name: ClassificationVisualizer | ||
attached_to: ClassificationHead | ||
params: | ||
font_scale: 0.5 | ||
color: [255, 0, 0] | ||
thickness: 2 | ||
include_plot: True | ||
|
||
dataset: | ||
name: cifar10_test | ||
|
||
trainer: | ||
batch_size: 4 | ||
epochs: &epochs 200 | ||
num_workers: 4 | ||
validation_interval: 10 | ||
num_log_images: 8 | ||
|
||
preprocessing: | ||
train_image_size: [&height 224, &width 224] | ||
keep_aspect_ratio: False | ||
normalize: | ||
active: True | ||
|
||
callbacks: | ||
- name: ExportOnTrainEnd | ||
- name: TestOnTrainEnd | ||
|
||
optimizer: | ||
name: SGD | ||
params: | ||
lr: 0.02 | ||
|
||
scheduler: | ||
name: ConstantLR |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import logging | ||
import os | ||
from pathlib import Path | ||
from typing import cast | ||
|
||
import lightning.pytorch as pl | ||
|
||
from luxonis_train.utils.config import Config | ||
from luxonis_train.utils.registry import CALLBACKS | ||
from luxonis_train.utils.tracker import LuxonisTrackerPL | ||
|
||
|
||
@CALLBACKS.register_module() | ||
class ArchiveOnTrainEnd(pl.Callback): | ||
def __init__(self, upload_to_mlflow: bool = False): | ||
"""Callback that performs archiving of onnx or exported model at the end of | ||
training/export. TODO: description. | ||
@type upload_to_mlflow: bool | ||
@param upload_to_mlflow: If set to True, overrides the upload url in Archiver | ||
with currently active MLFlow run (if present). | ||
""" | ||
super().__init__() | ||
self.upload_to_mlflow = upload_to_mlflow | ||
|
||
def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: | ||
"""Archives the model on train end. | ||
@type trainer: L{pl.Trainer} | ||
@param trainer: Pytorch Lightning trainer. | ||
@type pl_module: L{pl.LightningModule} | ||
@param pl_module: Pytorch Lightning module. | ||
@raises RuntimeError: If no best model path is found. | ||
""" | ||
from luxonis_train.core.archiver import Archiver | ||
|
||
model_checkpoint_callbacks = [ | ||
c | ||
for c in trainer.callbacks # type: ignore | ||
if isinstance(c, pl.callbacks.ModelCheckpoint) # type: ignore | ||
] | ||
|
||
# NOTE: assume that first checkpoint callback is based on val loss | ||
best_model_path = model_checkpoint_callbacks[0].best_model_path | ||
if not best_model_path: | ||
raise RuntimeError( | ||
"No best model path found. " | ||
"Please make sure that ModelCheckpoint callback is present " | ||
"and at least one validation epoch has been performed." | ||
) | ||
cfg: Config = pl_module.cfg | ||
cfg.model.weights = best_model_path | ||
if self.upload_to_mlflow: | ||
if cfg.tracker.is_mlflow: | ||
tracker = cast(LuxonisTrackerPL, trainer.logger) | ||
new_upload_url = f"mlflow://{tracker.project_id}/{tracker.run_id}" | ||
cfg.archiver.upload_url = new_upload_url | ||
else: | ||
logging.getLogger(__name__).warning( | ||
"`upload_to_mlflow` is set to True, " | ||
"but there is no MLFlow active run, skipping." | ||
) | ||
|
||
onnx_path = str(Path(best_model_path).parent.with_suffix(".onnx")) | ||
if not os.path.exists(onnx_path): | ||
raise FileNotFoundError( | ||
"Model executable not found. Make sure to run exporter callback before archiver callback" | ||
) | ||
|
||
archiver = Archiver(cfg=cfg) | ||
|
||
archiver.archive(onnx_path) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.