Updated Tuner #26

Merged

31 commits merged on Jun 13, 2024

Commits
7f6aa1f
updated optuna to newer version, added continue_existing_study parame…
klemen1999 May 7, 2024
e66d325
added logging to Tuner
klemen1999 May 13, 2024
02e6e99
Merge branch 'dev' into fix/tuner
klemen1999 May 13, 2024
074214a
formatting
kozlov721 May 15, 2024
b167cab
fixed nested view for mlflow
klemen1999 May 16, 2024
bc9560e
Merge branch 'fix/tuner' of https://github.com/luxonis/luxonis-train …
klemen1999 May 16, 2024
98a6309
Merge branch 'dev' into fix/tuner
klemen1999 May 16, 2024
75221ee
removed unused code
klemen1999 May 16, 2024
be26b89
removed unused code, added note
klemen1999 May 19, 2024
2cf7898
added finalize() to LuxonisTrackerPL for graceful tracker exit
klemen1999 May 23, 2024
88b5eab
Merge branch 'dev' into fix/tuner
kozlov721 May 23, 2024
4041c99
Multi-input test case - building a complex multi-input POSET model wi…
CaptainTrojan Jun 1, 2024
6809f98
Added a test for new collate_fn, which collates dicts of tensors inst…
CaptainTrojan Jun 1, 2024
f400246
Added a config for a multi-input model
CaptainTrojan Jun 1, 2024
f546608
Added necessary changes to config: ModelNodeConfig has a parameter in…
CaptainTrojan Jun 1, 2024
769aaab
Updated input_shape type and the return type of loader, the LuxonisLo…
CaptainTrojan Jun 1, 2024
9cf7ebb
Added images_name property to BaseNode class, removed redundant batch…
CaptainTrojan Jun 1, 2024
4284fba
Implemented multi-input support in node building and model export
CaptainTrojan Jun 1, 2024
0e1ceff
Compatibility changes due to the new way shapes work in loaders, maki…
CaptainTrojan Jun 1, 2024
951b981
[Automated] Updated coverage badge
actions-user Jun 1, 2024
1cdcb9f
Moved 'images_name' setting from loader implementation to config due …
CaptainTrojan Jun 9, 2024
6c69875
Merge branch 'dev' into feature/multi-input
kozlov721 Jun 12, 2024
6b07cd8
removed macos tests
kozlov721 Jun 12, 2024
63c6ad0
removed images_name from BaseNode
kozlov721 Jun 13, 2024
a4906e7
renamed get_shape_pocket to to_shape_pocket
kozlov721 Jun 13, 2024
25d1758
simplified input source handling
kozlov721 Jun 13, 2024
9987d0d
renamed images_name to image_source
kozlov721 Jun 13, 2024
1f47e12
Merge branch 'feature/multi-input' into fix/tuner
kozlov721 Jun 13, 2024
eef4203
Merge branch 'fix/tuner' of github.com:luxonis/luxonis-train into fix…
kozlov721 Jun 13, 2024
590bd31
docformat
kozlov721 Jun 13, 2024
6a9bceb
Merge branch 'dev' into fix/tuner
kozlov721 Jun 13, 2024
15 changes: 8 additions & 7 deletions configs/README.md
@@ -241,13 +241,14 @@ Option specific for ONNX export.

Here you can specify options for tuning.

| Key | Type | Default value | Description |
| ---------- | ----------------- | ------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| study_name | str | "test-study" | Name of the study. |
| use_pruner | bool | True | Whether to use the MedianPruner. |
| n_trials | int \| None | 15 | Number of trials for each process. `None` represents no limit in terms of numbner of trials. |
| timeout | int \| None | None | Stop study after the given number of seconds. |
| params | dict\[str, list\] | {} | Which parameters to tune. The keys should be in the format `key1.key2.key3_<type>`. Type can be one of `[categorical, float, int, longuniform, uniform]`. For more information about the types, visit [Optuna documentation](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html). |
| Key | Type | Default value | Description |
| ----------------------- | ----------------- | ------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| study_name | str | "test-study" | Name of the study. |
| continue_existing_study | bool              | True          | Whether to continue an existing study if `study_name` already exists.                                                                                                                                                                                                                                               |
| use_pruner | bool | True | Whether to use the MedianPruner. |
| n_trials                | int \| None       | 15            | Number of trials for each process. `None` represents no limit on the number of trials.                                                                                                                                                                                                                              |
| timeout                 | int \| None       | None          | Stop the study after the given number of seconds.                                                                                                                                                                                                                                                                   |
| params                  | dict\[str, list\] | {}            | Which parameters to tune. The keys should be in the format `key1.key2.key3_<type>`. Type can be one of `[categorical, float, int, loguniform, uniform]`. For more information about the types, visit [Optuna documentation](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html). |

Example of params for tuner block:
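
A minimal sketch of such a block, following the `key1.key2.key3_<type>` convention from the table above; the parameter names and value ranges below are illustrative assumptions, not taken from the collapsed part of the README:

```yaml
tuner:
  study_name: test-study
  continue_existing_study: True
  n_trials: 10
  params:
    # categorical choice between two optimizers
    trainer.optimizer.name_categorical: ["Adam", "SGD"]
    # float range [low, high] for the learning rate
    trainer.optimizer.params.lr_float: [0.0001, 0.001]
    # int range [low, high, step] for the batch size
    trainer.batch_size_int: [4, 16, 4]
```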

2 changes: 1 addition & 1 deletion configs/example_tuning.yaml
@@ -23,7 +23,7 @@ trainer:
active: True

batch_size: 4
epochs: &epochs 1
epochs: &epochs 10
validation_interval: 1
num_log_images: 8

73 changes: 58 additions & 15 deletions luxonis_train/core/tuner.py
@@ -1,4 +1,5 @@
import os.path as osp
from logging import getLogger
from typing import Any

import lightning.pytorch as pl
@@ -13,6 +14,8 @@

from .core import Core

logger = getLogger(__name__)


class Tuner(Core):
def __init__(self, cfg: str | dict, args: list[str] | tuple[str, ...] | None):
@@ -30,8 +33,26 @@ def __init__(self, cfg: str | dict, args: list[str] | tuple[str, ...] | None):
raise ValueError("You have to specify the `tuner` section in config.")
self.tune_cfg = self.cfg.tuner

# Parent tracker that only logs the best study parameters at the end
rank = rank_zero_only.rank
cfg_tracker = self.cfg.tracker
tracker_params = cfg_tracker.model_dump()
tracker_params[
"is_wandb"
] = False # wandb doesn't allow multiple concurrent runs, handle this separately
self.parent_tracker = LuxonisTrackerPL(
rank=rank,
mlflow_tracking_uri=self.cfg.ENVIRON.MLFLOW_TRACKING_URI,
is_sweep=False,
**tracker_params,
)
if self.parent_tracker.is_mlflow:
# Experiment needs to be interacted with to create actual MLFlow run
self.parent_tracker.experiment["mlflow"].active_run()

def tune(self) -> None:
"""Runs Optuna tunning of hyperparameters."""
logger.info("Starting tuning...")

pruner = (
optuna.pruners.MedianPruner()
@@ -57,7 +78,7 @@ def tune(self) -> None:
storage=storage,
direction="minimize",
pruner=pruner,
load_if_exists=True,
load_if_exists=self.tune_cfg.continue_existing_study,
)

study.optimize(
@@ -66,25 +87,44 @@
timeout=self.tune_cfg.timeout,
)

best_study_params = study.best_params
logger.info(f"Best study parameters: {best_study_params}")

self.parent_tracker.log_hyperparams(best_study_params)

if self.cfg.tracker.is_wandb:
# If wandb is used, init the wandb parent tracker separately at the end
wandb_parent_tracker = LuxonisTrackerPL(
project_name=self.cfg.tracker.project_name,
project_id=self.cfg.tracker.project_id,
run_name=self.parent_tracker.run_name,
save_directory=self.cfg.tracker.save_directory,
is_wandb=True,
wandb_entity=self.cfg.tracker.wandb_entity,
rank=rank_zero_only.rank,
)
wandb_parent_tracker.log_hyperparams(best_study_params)

def _objective(self, trial: optuna.trial.Trial) -> float:
"""Objective function used to optimize Optuna study."""
rank = rank_zero_only.rank
cfg_tracker = self.cfg.tracker
tracker_params = cfg_tracker.model_dump()
tracker = LuxonisTrackerPL(
child_tracker = LuxonisTrackerPL(
rank=rank,
mlflow_tracking_uri=self.cfg.ENVIRON.MLFLOW_TRACKING_URI,
is_sweep=True,
**tracker_params,
)
run_save_dir = osp.join(cfg_tracker.save_directory, tracker.run_name)

run_save_dir = osp.join(cfg_tracker.save_directory, child_tracker.run_name)

curr_params = self._get_trial_params(trial)
curr_params["model.predefined_model"] = None
Config.clear_instance()
cfg = Config.get_config(self.cfg.model_dump(), curr_params)

tracker.log_hyperparams(curr_params)
child_tracker.log_hyperparams(curr_params)

cfg.save_data(osp.join(run_save_dir, "config.yaml"))

@@ -95,14 +135,11 @@ def _objective(self, trial: optuna.trial.Trial) -> float:
input_shape=self.loaders["train"].input_shape,
)
lightning_module._core = self
pruner_callback = PyTorchLightningPruningCallback(
trial, monitor="val_loss/loss"
)
callbacks: list[pl.Callback] = (
[LuxonisProgressBar()] if self.cfg.use_rich_text else []
)
pruner_callback = PyTorchLightningPruningCallback(trial, monitor="val/loss")
callbacks.append(pruner_callback)

deterministic = False
if self.cfg.trainer.seed:
pl.seed_everything(cfg.trainer.seed, workers=True)
@@ -112,7 +149,7 @@ def _objective(self, trial: optuna.trial.Trial) -> float:
accelerator=cfg.trainer.accelerator,
devices=cfg.trainer.devices,
strategy=cfg.trainer.strategy,
logger=tracker, # type: ignore
logger=child_tracker, # type: ignore
max_epochs=cfg.trainer.epochs,
accumulate_grad_batches=cfg.trainer.accumulate_grad_batches,
check_val_every_n_epoch=cfg.trainer.validation_interval,
Expand All @@ -122,12 +159,18 @@ def _objective(self, trial: optuna.trial.Trial) -> float:
deterministic=deterministic,
)

pl_trainer.fit(
lightning_module, # type: ignore
self.pytorch_loaders["train"],
self.pytorch_loaders["val"],
)
pruner_callback.check_pruned()
try:
pl_trainer.fit(
lightning_module, # type: ignore
self.pytorch_loaders["val"],
self.pytorch_loaders["train"],
)

pruner_callback.check_pruned()

except optuna.TrialPruned as e:
# Pruning is done by raising an error
logger.info(e)

if "val/loss" not in pl_trainer.callback_metrics:
raise ValueError(
1 change: 1 addition & 0 deletions luxonis_train/utils/config.py
@@ -286,6 +286,7 @@ class StorageConfig(CustomBaseModel):

class TunerConfig(CustomBaseModel):
study_name: str = "test-study"
continue_existing_study: bool = True
use_pruner: bool = True
n_trials: int | None = 15
timeout: int | None = None
22 changes: 21 additions & 1 deletion luxonis_train/utils/tracker.py
@@ -1,8 +1,28 @@
from lightning.pytorch.loggers.logger import Logger
from lightning.pytorch.utilities import rank_zero_only # type: ignore
from luxonis_ml.tracker import LuxonisTracker


class LuxonisTrackerPL(LuxonisTracker, Logger):
"""Implementation of LuxonisTracker that is compatible with PytorchLightning."""

...
@rank_zero_only
def finalize(self, status: str = "success") -> None:
"""Finalizes current run."""
if self.is_tensorboard:
self.experiment["tensorboard"].flush()
self.experiment["tensorboard"].close()
if self.is_mlflow:
if status == "success":
mlflow_status = "FINISHED"
elif status == "failed":
mlflow_status = "FAILED"
elif status == "finished":
mlflow_status = "FINISHED"
self.experiment["mlflow"].end_run(mlflow_status)
if self.is_wandb:
if status == "success":
wandb_status = 0
else:
wandb_status = 1
self.experiment["wandb"].finish(wandb_status)
4 changes: 2 additions & 2 deletions requirements.txt
@@ -5,8 +5,8 @@ luxonis-ml[all]@git+https://github.com/luxonis/luxonis-ml.git@dev
onnx>=1.12.0
onnxruntime>=1.13.1
onnxsim>=0.4.10
optuna>=3.2.0
optuna_integration>=3.6.0
optuna>=3.6.0
optuna-integration>=3.6.0
parameterized>=0.9.0
psycopg2-binary>=2.9.1
pycocotools>=2.0.7