Skip to content

Commit

Permalink
cleaned up config, renamed some fields
Browse files Browse the repository at this point in the history
  • Loading branch information
kozlov721 committed Jan 17, 2024
1 parent 99e8493 commit 2e82131
Show file tree
Hide file tree
Showing 23 changed files with 105 additions and 120 deletions.
18 changes: 9 additions & 9 deletions configs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ For list of all nodes, see [nodes](../luxonis_train/nodes/README.md).
| Key | Type | Default value | Description |
| ----------------------- | -------------------- | ------------- | ------------------------------------------------------------------------------------------------------------------------------------------- |
| name | str | | name of the node |
| override_name | str | None | custom name for the node |
| alias | str | None | custom name for the node |
| params | dict | {} | parameters for the node |
| inputs | list | \[\] | list of input nodes for this node, if empty, the node is understood to be an input node of the model |
| freezing.active | bool | False | whether to freeze the modules so the weights are not updated |
Expand All @@ -73,12 +73,12 @@ For list of all nodes, see [nodes](../luxonis_train/nodes/README.md).

Modules that are attached to a node. This include losses, metrics and visualziers.

| Key | Type | Default value | Description |
| ------------- | ---- | ------------- | ------------------------------------------- |
| name | str | | name of the module |
| attached_to | str | | Name of the node the module is attached to. |
| override_name | str | None | custom name for the module |
| params | dict | {} | parameters of the module |
| Key | Type | Default value | Description |
| ----------- | ---- | ------------- | ------------------------------------------- |
| name | str | | name of the module |
| attached_to | str | | Name of the node the module is attached to. |
| alias | str | None | custom name for the module |
| params | dict | {} | parameters of the module |

#### Losses

Expand Down Expand Up @@ -128,9 +128,9 @@ To store and load the data we use LuxonisDataset and LuxonisLoader. For specific

| Key | Type | Default value | Description |
| -------------- | ---------------------------------------- | ------------------- | ---------------------------------------------- |
| dataset_name | str \| None | None | name of the dataset |
| name | str \| None | None | name of the dataset |
| id | str \| None | None | id of the dataset |
| team_id | str \| None | None | team under which you can find all datasets |
| dataset_id | str \| None | None | id of the dataset |
| bucket_type | Literal\["intenal", "external"\] | internal | type of underlying storage |
| bucket_storage | Literal\["local", "s3", "gcc", "azure"\] | BucketStorage.LOCAL | underlying object storage for a bucket |
| train_view | str | train | view to use for training |
Expand Down
2 changes: 1 addition & 1 deletion configs/classification_model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ model:
include_plot: True

dataset:
dataset_name: cifar10_test
name: cifar10_test

trainer:
preprocessing:
Expand Down
2 changes: 1 addition & 1 deletion configs/coco_model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ tracker:
is_mlflow: False

dataset:
dataset_name: coco_test
name: coco_test
train_view: train
val_view: val
test_view: test
Expand Down
2 changes: 1 addition & 1 deletion configs/detection_model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ model:
use_neck: True

dataset:
dataset_name: coco_test
name: coco_test

trainer:
preprocessing:
Expand Down
2 changes: 1 addition & 1 deletion configs/example_export.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ model:
task: binary

dataset:
dataset_name: coco_test
name: coco_test

trainer:
preprocessing:
Expand Down
2 changes: 1 addition & 1 deletion configs/example_tuning.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ model:
task: binary

dataset:
dataset_name: coco_test
name: coco_test

trainer:
preprocessing:
Expand Down
2 changes: 1 addition & 1 deletion configs/keypoint_bbox_model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ model:
name: KeypointDetectionModel

dataset:
dataset_name: coco_test
name: coco_test

trainer:
preprocessing:
Expand Down
2 changes: 1 addition & 1 deletion configs/segmentation_model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ model:
task: binary

dataset:
dataset_name: coco_test
name: coco_test

trainer:
preprocessing:
Expand Down
4 changes: 2 additions & 2 deletions luxonis_train/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,9 @@ def inspect(
image_size = cfg.trainer.preprocessing.train_image_size

dataset = LuxonisDataset(
dataset_name=cfg.dataset.dataset_name,
dataset_name=cfg.dataset.name,
team_id=cfg.dataset.team_id,
dataset_id=cfg.dataset.dataset_id,
dataset_id=cfg.dataset.id,
bucket_type=cfg.dataset.bucket_type,
bucket_storage=cfg.dataset.bucket_storage,
)
Expand Down
2 changes: 1 addition & 1 deletion luxonis_train/callbacks/export_on_train_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> No
cfg: Config = pl_module.cfg
cfg.model.weights = best_model_path
if self.upload_to_mlflow:
if pl_module.cfg.tracker.is_mlflow:
if cfg.tracker.is_mlflow:
tracker = cast(LuxonisTrackerPL, trainer.logger)
new_upload_directory = f"mlflow://{tracker.project_id}/{tracker.run_id}"
cfg.exporter.upload_directory = new_upload_directory
Expand Down
3 changes: 2 additions & 1 deletion luxonis_train/callbacks/metadata_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pkg_resources
import yaml

from luxonis_train.utils.config import Config
from luxonis_train.utils.registry import CALLBACKS


Expand All @@ -23,7 +24,7 @@ def __init__(self, hyperparams: list[str]):
self.hyperparams = hyperparams

def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
cfg = pl_module.cfg
cfg: Config = pl_module.cfg

hparams = {key: cfg.get(key) for key in self.hyperparams}

Expand Down
28 changes: 15 additions & 13 deletions luxonis_train/callbacks/test_on_train_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from luxonis_ml.data import LuxonisDataset, ValAugmentations
from torch.utils.data import DataLoader

from luxonis_train.utils.config import Config
from luxonis_train.utils.loaders import LuxonisLoaderTorch, collate_fn
from luxonis_train.utils.registry import CALLBACKS

Expand All @@ -11,31 +12,32 @@ class TestOnTrainEnd(pl.Callback):
"""Callback to perform a test run at the end of the training."""

def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
cfg: Config = pl_module.cfg

dataset = LuxonisDataset(
dataset_name=pl_module.cfg.dataset.dataset_name,
team_id=pl_module.cfg.dataset.team_id,
dataset_id=pl_module.cfg.dataset.dataset_id,
bucket_type=pl_module.cfg.dataset.bucket_type,
bucket_storage=pl_module.cfg.dataset.bucket_storage,
dataset_name=cfg.dataset.name,
team_id=cfg.dataset.team_id,
dataset_id=cfg.dataset.id,
bucket_type=cfg.dataset.bucket_type,
bucket_storage=cfg.dataset.bucket_storage,
)

loader_test = LuxonisLoaderTorch(
dataset,
view=pl_module.cfg.dataset.test_view,
view=cfg.dataset.test_view,
augmentations=ValAugmentations(
image_size=pl_module.cfg.trainer.preprocessing.train_image_size,
image_size=cfg.trainer.preprocessing.train_image_size,
augmentations=[
i.model_dump()
for i in pl_module.cfg.trainer.preprocessing.augmentations
i.model_dump() for i in cfg.trainer.preprocessing.augmentations
],
train_rgb=pl_module.cfg.trainer.preprocessing.train_rgb,
keep_aspect_ratio=pl_module.cfg.trainer.preprocessing.keep_aspect_ratio,
train_rgb=cfg.trainer.preprocessing.train_rgb,
keep_aspect_ratio=cfg.trainer.preprocessing.keep_aspect_ratio,
),
)
pytorch_loader_test = DataLoader(
loader_test,
batch_size=pl_module.cfg.trainer.batch_size,
num_workers=pl_module.cfg.trainer.num_workers,
batch_size=cfg.trainer.batch_size,
num_workers=cfg.trainer.num_workers,
collate_fn=collate_fn,
)
trainer.test(pl_module, pytorch_loader_test)
4 changes: 2 additions & 2 deletions luxonis_train/core/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,9 @@ def __init__(
callbacks=LuxonisProgressBar() if self.cfg.use_rich_text else None,
)
self.dataset = LuxonisDataset(
dataset_name=self.cfg.dataset.dataset_name,
dataset_name=self.cfg.dataset.name,
team_id=self.cfg.dataset.team_id,
dataset_id=self.cfg.dataset.dataset_id,
dataset_id=self.cfg.dataset.id,
bucket_type=self.cfg.dataset.bucket_type,
bucket_storage=self.cfg.dataset.bucket_storage,
)
Expand Down
17 changes: 10 additions & 7 deletions luxonis_train/core/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,22 @@ def __init__(self, cfg: str | dict, args: list[str] | tuple[str, ...] | None):
used for config overriding.
"""
super().__init__(cfg, args)
if self.cfg.tuner is None:
raise ValueError("You have to specify the `tuner` section in config.")
self.tune_cfg = self.cfg.tuner

def tune(self) -> None:
"""Runs Optuna tunning of hyperparameters."""

pruner = (
optuna.pruners.MedianPruner()
if self.cfg.tuner.use_pruner
if self.tune_cfg.use_pruner
else optuna.pruners.NopPruner()
)

storage = None
if self.cfg.tuner.storage.active:
if self.cfg.tuner.storage.storage_type == "local":
if self.tune_cfg.storage.active:
if self.tune_cfg.storage.storage_type == "local":
storage = "sqlite:///study_local.db"
else:
storage = "postgresql://{}:{}@{}:{}/{}".format(
Expand All @@ -50,7 +53,7 @@ def tune(self) -> None:
)

study = optuna.create_study(
study_name=self.cfg.tuner.study_name,
study_name=self.tune_cfg.study_name,
storage=storage,
direction="minimize",
pruner=pruner,
Expand All @@ -59,8 +62,8 @@ def tune(self) -> None:

study.optimize(
self._objective,
n_trials=self.cfg.tuner.n_trials,
timeout=self.cfg.tuner.timeout,
n_trials=self.tune_cfg.n_trials,
timeout=self.tune_cfg.timeout,
)

def _objective(self, trial: optuna.trial.Trial) -> float:
Expand Down Expand Up @@ -128,7 +131,7 @@ def _objective(self, trial: optuna.trial.Trial) -> float:

def _get_trial_params(self, trial: optuna.trial.Trial) -> dict[str, Any]:
"""Get trial params based on specified config."""
cfg_tuner = self.cfg.tuner.params
cfg_tuner = self.tune_cfg.params
new_params = {}
for key, value in cfg_tuner.items():
key_info = key.split("_")
Expand Down
4 changes: 2 additions & 2 deletions luxonis_train/models/luxonis_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def __init__(
for node_cfg in self.cfg.model.nodes:
node_name = node_cfg.name
Node = BaseNode.REGISTRY.get(node_name)
node_name = node_cfg.override_name or node_name
node_name = node_cfg.alias or node_name
if node_cfg.freezing.active:
epochs = self.cfg.trainer.epochs
if node_cfg.freezing.unfreeze_after is None:
Expand Down Expand Up @@ -714,7 +714,7 @@ def _init_attached_module(
storage: Mapping[str, Mapping[str, BaseAttachedModule]],
) -> tuple[str, str]:
Module = registry.get(cfg.name)
module_name = cfg.override_name or cfg.name
module_name = cfg.alias or cfg.name
node_name = cfg.attached_to
module = Module(**cfg.params, node=self.nodes[node_name])
storage[node_name][module_name] = module # type: ignore
Expand Down
12 changes: 6 additions & 6 deletions luxonis_train/models/predefined_models/classification_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@ def nodes(self) -> list[ModelNodeConfig]:
return [
ModelNodeConfig(
name=self.backbone,
override_name="classification_backbone",
alias="classification_backbone",
freezing=self.backbone_params.pop("freezing", {}),
params=self.backbone_params,
),
ModelNodeConfig(
name="ClassificationHead",
override_name="classification_head",
alias="classification_head",
inputs=["classification_backbone"],
freezing=self.head_params.pop("freezing", {}),
params=self.head_params,
Expand All @@ -44,7 +44,7 @@ def losses(self) -> list[LossModuleConfig]:
return [
LossModuleConfig(
name="CrossEntropyLoss",
override_name="classification_loss",
alias="classification_loss",
attached_to="classification_head",
params=self.loss_params,
weight=1.0,
Expand All @@ -56,20 +56,20 @@ def metrics(self) -> list[MetricModuleConfig]:
return [
MetricModuleConfig(
name="F1Score",
override_name="classification_f1_score",
alias="classification_f1_score",
is_main_metric=True,
attached_to="classification_head",
params={"task": self.task},
),
MetricModuleConfig(
name="Accuracy",
override_name="classification_accuracy",
alias="classification_accuracy",
attached_to="classification_head",
params={"task": self.task},
),
MetricModuleConfig(
name="Recall",
override_name="classification_recall",
alias="classification_recall",
attached_to="classification_head",
params={"task": self.task},
),
Expand Down
12 changes: 6 additions & 6 deletions luxonis_train/models/predefined_models/detection_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def nodes(self) -> list[ModelNodeConfig]:
nodes = [
ModelNodeConfig(
name="EfficientRep",
override_name="detection_backbone",
alias="detection_backbone",
freezing=self.backbone_params.pop("freezing", {}),
params=self.backbone_params,
),
Expand All @@ -34,7 +34,7 @@ def nodes(self) -> list[ModelNodeConfig]:
nodes.append(
ModelNodeConfig(
name="RepPANNeck",
override_name="detection_neck",
alias="detection_neck",
inputs=["detection_backbone"],
freezing=self.neck_params.pop("freezing", {}),
params=self.neck_params,
Expand All @@ -44,7 +44,7 @@ def nodes(self) -> list[ModelNodeConfig]:
nodes.append(
ModelNodeConfig(
name="EfficientBBoxHead",
override_name="detection_head",
alias="detection_head",
freezing=self.head_params.pop("freezing", {}),
inputs=["detection_neck"] if self.use_neck else ["detection_backbone"],
params=self.head_params,
Expand All @@ -57,7 +57,7 @@ def losses(self) -> list[LossModuleConfig]:
return [
LossModuleConfig(
name="AdaptiveDetectionLoss",
override_name="detection_loss",
alias="detection_loss",
attached_to="detection_head",
params=self.loss_params,
weight=1.0,
Expand All @@ -69,7 +69,7 @@ def metrics(self) -> list[MetricModuleConfig]:
return [
MetricModuleConfig(
name="MeanAveragePrecision",
override_name="detection_map",
alias="detection_map",
attached_to="detection_head",
is_main_metric=True,
),
Expand All @@ -80,7 +80,7 @@ def visualizers(self) -> list[AttachedModuleConfig]:
return [
AttachedModuleConfig(
name="BBoxVisualizer",
override_name="detection_visualizer",
alias="detection_visualizer",
attached_to="detection_head",
params=self.visualizer_params,
)
Expand Down
Loading

0 comments on commit 2e82131

Please sign in to comment.