From 7f6aa1f76aa1a9f6c6886f75e55da22ce3530f43 Mon Sep 17 00:00:00 2001
From: klemen1999
Date: Tue, 7 May 2024 09:16:15 +0200
Subject: [PATCH 01/23] updated optuna to newer version, added
 continue_existing_study parameter to config

---
 configs/README.md             | 1 +
 luxonis_train/core/tuner.py   | 6 ++----
 luxonis_train/utils/config.py | 1 +
 requirements.txt              | 3 ++-
 4 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/configs/README.md b/configs/README.md
index 27e2fb6e..8422cc7d 100644
--- a/configs/README.md
+++ b/configs/README.md
@@ -242,6 +242,7 @@ Here you can specify options for tuning.
 | Key        | Type              | Default value | Description |
 | ---------- | ----------------- | ------------- | ----------- |
 | study_name | str               | "test-study"  | Name of the study. |
+| continue_existing_study | bool | True | Whether to continue an existing study if `study_name` already exists.
 | use_pruner | bool              | True          | Whether to use the MedianPruner. |
 | n_trials   | int \| None       | 15            | Number of trials for each process. `None` represents no limit in terms of number of trials. |
 | timeout    | int \| None       | None          | Stop study after the given number of seconds. |
diff --git a/luxonis_train/core/tuner.py b/luxonis_train/core/tuner.py
index c9f8e151..b5e61632 100644
--- a/luxonis_train/core/tuner.py
+++ b/luxonis_train/core/tuner.py
@@ -57,7 +57,7 @@ def tune(self) -> None:
             storage=storage,
             direction="minimize",
             pruner=pruner,
-            load_if_exists=True,
+            load_if_exists=self.tune_cfg.continue_existing_study,
         )
 
         study.optimize(
@@ -94,9 +94,7 @@ def _objective(self, trial: optuna.trial.Trial) -> float:
             save_dir=run_save_dir,
             input_shape=self.loader_train.input_shape,
         )
-        pruner_callback = PyTorchLightningPruningCallback(
-            trial, monitor="val_loss/loss"
-        )
+        pruner_callback = PyTorchLightningPruningCallback(trial, monitor="val/loss")
         callbacks: list[pl.Callback] = (
             [LuxonisProgressBar()] if self.cfg.use_rich_text else []
         )
diff --git a/luxonis_train/utils/config.py b/luxonis_train/utils/config.py
index 48661f7d..b15f407a 100644
--- a/luxonis_train/utils/config.py
+++ b/luxonis_train/utils/config.py
@@ -272,6 +272,7 @@ class StorageConfig(BaseModel):
 
 class TunerConfig(BaseModel):
     study_name: str = "test-study"
+    continue_existing_study: bool = True
     use_pruner: bool = True
     n_trials: int | None = 15
     timeout: int | None = None
diff --git a/requirements.txt b/requirements.txt
index eecf828e..6e2a8714 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,7 +4,8 @@ luxonis-ml[all]>=0.0.1
 onnx>=1.12.0
 onnxruntime>=1.13.1
 onnxsim>=0.4.10
-optuna>=3.2.0
+optuna>=3.6.0
+optuna-integration>=3.6.0
 psycopg2-binary>=2.9.1
 pycocotools>=2.0.7
 rich>=13.0.0
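The new `continue_existing_study` flag is forwarded directly to Optuna's `load_if_exists` argument. A minimal sketch of the behaviour it toggles, using an SQLite storage URL purely for illustration:

```python
import optuna

continue_existing_study = False  # mirrors the new tuner.continue_existing_study option

study = optuna.create_study(
    study_name="test-study",
    storage="sqlite:///study.db",
    direction="minimize",
    pruner=optuna.pruners.MedianPruner(),
    # With load_if_exists=False, Optuna raises DuplicatedStudyError when
    # "test-study" already exists in the storage; with True it resumes it.
    load_if_exists=continue_existing_study,
)
```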
From e66d325de840a46b8d8cc4421a9387bc56fa7552 Mon Sep 17 00:00:00 2001
From: klemen1999
Date: Mon, 13 May 2024 22:33:44 +0200
Subject: [PATCH 02/23] added logging to Tuner

---
 luxonis_train/core/tuner.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/luxonis_train/core/tuner.py b/luxonis_train/core/tuner.py
index b5e61632..cdad5cc2 100644
--- a/luxonis_train/core/tuner.py
+++ b/luxonis_train/core/tuner.py
@@ -1,5 +1,6 @@
 import os.path as osp
 from typing import Any
+from logging import getLogger
 
 import lightning.pytorch as pl
 import optuna
@@ -13,6 +14,8 @@
 
 from .core import Core
 
+logger = getLogger(__name__)
+
 
 class Tuner(Core):
     def __init__(self, cfg: str | dict, args: list[str] | tuple[str, ...] | None):
@@ -32,6 +35,7 @@ def __init__(self, cfg: str | dict, args: list[str] | tuple[str, ...] | None):
 
     def tune(self) -> None:
         """Runs Optuna tuning of hyperparameters."""
+        logger.info("Starting tuning...")
 
         pruner = (
             optuna.pruners.MedianPruner()
@@ -66,6 +70,8 @@ def tune(self) -> None:
             timeout=self.tune_cfg.timeout,
         )
 
+        logger.info(f"Best study parameters: {study.best_params}")
+
     def _objective(self, trial: optuna.trial.Trial) -> float:
         """Objective function used to optimize Optuna study."""
         rank = rank_zero_only.rank
From 074214ac444d223afe79481a624cbb5455896f35 Mon Sep 17 00:00:00 2001
From: Martin Kozlovsky
Date: Wed, 15 May 2024 20:56:37 +0200
Subject: [PATCH 03/23] formatting

---
 configs/README.md           | 16 ++++++++--------
 luxonis_train/core/tuner.py |  2 +-
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/configs/README.md b/configs/README.md
index 627aa173..76dba3e8 100644
--- a/configs/README.md
+++ b/configs/README.md
@@ -240,14 +240,14 @@ Option specific for ONNX export.
 
 Here you can specify options for tuning.
 
-| Key        | Type              | Default value | Description |
-| ---------- | ----------------- | ------------- | ----------- |
-| study_name | str               | "test-study"  | Name of the study. |
-| continue_existing_study | bool | True | Whether to continue an existing study if `study_name` already exists.
-| use_pruner | bool              | True          | Whether to use the MedianPruner. |
-| n_trials   | int \| None       | 15            | Number of trials for each process. `None` represents no limit in terms of number of trials. |
-| timeout    | int \| None       | None          | Stop study after the given number of seconds. |
-| params     | dict\[str, list\] | {}            | Which parameters to tune. The keys should be in the format `key1.key2.key3_<type>`. Type can be one of `[categorical, float, int, longuniform, uniform]`. For more information about the types, visit [Optuna documentation](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html). |
+| Key                     | Type              | Default value | Description |
+| ----------------------- | ----------------- | ------------- | ----------- |
+| study_name              | str               | "test-study"  | Name of the study. |
+| continue_existing_study | bool              | True          | Whether to continue an existing study if `study_name` already exists. |
+| use_pruner              | bool              | True          | Whether to use the MedianPruner. |
+| n_trials                | int \| None       | 15            | Number of trials for each process. `None` represents no limit in terms of number of trials. |
+| timeout                 | int \| None       | None          | Stop study after the given number of seconds. |
+| params                  | dict\[str, list\] | {}            | Which parameters to tune. The keys should be in the format `key1.key2.key3_<type>`. Type can be one of `[categorical, float, int, longuniform, uniform]`. For more information about the types, visit [Optuna documentation](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html). |
 
 Example of params for tuner block:
 
diff --git a/luxonis_train/core/tuner.py b/luxonis_train/core/tuner.py
index cdad5cc2..4fd074f6 100644
--- a/luxonis_train/core/tuner.py
+++ b/luxonis_train/core/tuner.py
@@ -1,6 +1,6 @@
 import os.path as osp
-from typing import Any
 from logging import getLogger
+from typing import Any
 
 import lightning.pytorch as pl
 import optuna
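For context on the `params` table above: the trailing `_<type>` suffix of each key selects which Optuna suggestion method is used for that hyperparameter. The actual resolution lives in `Tuner._get_trial_params` (referenced later in this series); the helper below is only an illustrative sketch of the mapping, with hypothetical names:

```python
import optuna


def suggest_from_key(trial: optuna.trial.Trial, key: str, values: list):
    # Hypothetical helper: "trainer.optimizer.params.lr_float" -> suggest_float.
    param_name, _, type_suffix = key.rpartition("_")
    if type_suffix == "categorical":
        return trial.suggest_categorical(param_name, values)
    if type_suffix == "float":
        return trial.suggest_float(param_name, values[0], values[1])
    if type_suffix == "int":
        return trial.suggest_int(param_name, values[0], values[1])
    raise ValueError(f"Unknown type suffix: {type_suffix}")
```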
From b167cab114ce33247b1af65a0cc688f4055619d9 Mon Sep 17 00:00:00 2001
From: klemen1999
Date: Thu, 16 May 2024 14:33:07 +0200
Subject: [PATCH 04/23] fixed nested view for mlflow

---
 configs/example_tuning.yaml |  2 +-
 luxonis_train/core/tuner.py | 74 +++++++++++++++++++++++++++++++------
 2 files changed, 63 insertions(+), 13 deletions(-)

diff --git a/configs/example_tuning.yaml b/configs/example_tuning.yaml
index 980036ae..422ad728 100755
--- a/configs/example_tuning.yaml
+++ b/configs/example_tuning.yaml
@@ -22,7 +22,7 @@ trainer:
   active: True
 
   batch_size: 4
-  epochs: &epochs 1
+  epochs: &epochs 10
   validation_interval: 1
   num_log_images: 8
 
diff --git a/luxonis_train/core/tuner.py b/luxonis_train/core/tuner.py
index cdad5cc2..36fcc863 100644
--- a/luxonis_train/core/tuner.py
+++ b/luxonis_train/core/tuner.py
@@ -33,6 +33,20 @@ def __init__(self, cfg: str | dict, args: list[str] | tuple[str, ...] | None):
             raise ValueError("You have to specify the `tuner` section in config.")
         self.tune_cfg = self.cfg.tuner
 
+        # Parent tracker that only logs the best study parameters at the end
+        rank = rank_zero_only.rank
+        cfg_tracker = self.cfg.tracker
+        tracker_params = cfg_tracker.model_dump()
+        self.parent_tracker = LuxonisTrackerPL(
+            rank=rank,
+            mlflow_tracking_uri=self.cfg.ENVIRON.MLFLOW_TRACKING_URI,
+            is_sweep=False,
+            **tracker_params,
+        )
+        if self.parent_tracker.is_mlflow:
+            run = self.parent_tracker.experiment["mlflow"].active_run()
+            self.parent_run_id = run.info.run_id
+
     def tune(self) -> None:
         """Runs Optuna tuning of hyperparameters."""
         logger.info("Starting tuning...")
@@ -70,27 +84,30 @@ def tune(self) -> None:
             timeout=self.tune_cfg.timeout,
         )
 
-        logger.info(f"Best study parameters: {study.best_params}")
+        best_study_params = study.best_params
+        logger.info(f"Best study parameters: {best_study_params}")
+        self.parent_tracker.log_hyperparams(best_study_params)
 
     def _objective(self, trial: optuna.trial.Trial) -> float:
         """Objective function used to optimize Optuna study."""
         rank = rank_zero_only.rank
         cfg_tracker = self.cfg.tracker
         tracker_params = cfg_tracker.model_dump()
-        tracker = LuxonisTrackerPL(
+        child_tracker = LuxonisTrackerPL(
             rank=rank,
             mlflow_tracking_uri=self.cfg.ENVIRON.MLFLOW_TRACKING_URI,
             is_sweep=True,
             **tracker_params,
         )
-        run_save_dir = osp.join(cfg_tracker.save_directory, tracker.run_name)
+
+        run_save_dir = osp.join(cfg_tracker.save_directory, child_tracker.run_name)
 
         curr_params = self._get_trial_params(trial)
         curr_params["model.predefined_model"] = None
 
         Config.clear_instance()
         cfg = Config.get_config(self.cfg.model_dump(), curr_params)
 
-        tracker.log_hyperparams(curr_params)
+        child_tracker.log_hyperparams(curr_params)
 
         cfg.save_data(osp.join(run_save_dir, "config.yaml"))
 
@@ -100,16 +117,20 @@ def _objective(self, trial: optuna.trial.Trial) -> float:
             save_dir=run_save_dir,
             input_shape=self.loader_train.input_shape,
         )
-        pruner_callback = PyTorchLightningPruningCallback(trial, monitor="val/loss")
         callbacks: list[pl.Callback] = (
             [LuxonisProgressBar()] if self.cfg.use_rich_text else []
        )
+        pruner_callback = PyTorchLightningPruningCallback(trial, monitor="val/loss")
         callbacks.append(pruner_callback)
+
+        tracker_end_run = TrackerEndRun()
+        callbacks.append(tracker_end_run)
+
         pl_trainer = pl.Trainer(
             accelerator=cfg.trainer.accelerator,
             devices=cfg.trainer.devices,
             strategy=cfg.trainer.strategy,
-            logger=tracker,  # type: ignore
+            logger=child_tracker,  # type: ignore
             max_epochs=cfg.trainer.epochs,
             accumulate_grad_batches=cfg.trainer.accumulate_grad_batches,
             check_val_every_n_epoch=cfg.trainer.validation_interval,
@@ -118,12 +139,20 @@ def _objective(self, trial: optuna.trial.Trial) -> float:
             callbacks=callbacks,
         )
 
-        pl_trainer.fit(
-            lightning_module,  # type: ignore
-            self.pytorch_loader_train,
-            self.pytorch_loader_val,
-        )
-        pruner_callback.check_pruned()
+        try:
+            pl_trainer.fit(
+                lightning_module,  # type: ignore
+                self.pytorch_loader_train,
+                self.pytorch_loader_val,
+            )
+
+            pruner_callback.check_pruned()
+
+        except optuna.TrialPruned as e:
+            # Pruning is done by raising an error
+            # When .fit() errors out we have to gracefully also end the trackers
+            tracker_end_run.end_trackers(child_tracker)
+            logger.info(e)
 
         if "val/loss" not in pl_trainer.callback_metrics:
             raise ValueError(
@@ -174,3 +203,24 @@ def _get_trial_params(self, trial: optuna.trial.Trial) -> dict[str, Any]:
                 "No parameters to tune. Specify them under `tuner.params`."
             )
         return new_params
+
+
+class TrackerEndRun(pl.Callback):
+    """Callback that ends trackers of child processes during tuning study"""
+
+    def teardown(
+        self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str
+    ) -> None:
+        self.end_trackers(trainer.logger)  # type: ignore
+        return super().teardown(trainer, pl_module, stage)
+
+    def end_trackers(self, tracker: LuxonisTrackerPL) -> None:
+        """Ends WandB and MLFlow trackers
+
+        Args:
+            tracker (LuxonisTrackerPL): Currently active tracker
+        """
+        if tracker.is_wandb:
+            tracker.experiment["wandb"].finish()
+        if tracker.is_mlflow:
+            tracker.experiment["mlflow"].end_run()
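The try/except added above exists because Optuna prunes a trial by raising `optuna.TrialPruned` from inside `pl_trainer.fit()` (via the pruning callback), and `check_pruned()` re-raises it when a pruned trial otherwise finished normally. A condensed, Lightning-free sketch of the same report-and-prune protocol:

```python
import optuna


def objective(trial: optuna.trial.Trial) -> float:
    loss = 1.0
    for epoch in range(10):
        loss *= 0.9  # stand-in for one training epoch
        trial.report(loss, step=epoch)  # what the pruning callback does internally
        if trial.should_prune():
            # Equivalent of the TrialPruned raised inside pl_trainer.fit()
            raise optuna.TrialPruned()
    return loss


study = optuna.create_study(direction="minimize", pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=5)
```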
From 75221eefeea7cf091b58c3965d0556a096b9889d Mon Sep 17 00:00:00 2001
From: klemen1999
Date: Thu, 16 May 2024 20:01:13 +0200
Subject: [PATCH 05/23] removed unused code

---
 luxonis_train/core/tuner.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/luxonis_train/core/tuner.py b/luxonis_train/core/tuner.py
index dc4411f5..eb072611 100644
--- a/luxonis_train/core/tuner.py
+++ b/luxonis_train/core/tuner.py
@@ -43,9 +43,6 @@ def __init__(self, cfg: str | dict, args: list[str] | tuple[str, ...] | None):
             is_sweep=False,
             **tracker_params,
         )
-        if self.parent_tracker.is_mlflow:
-            run = self.parent_tracker.experiment["mlflow"].active_run()
-            self.parent_run_id = run.info.run_id
 
     def tune(self) -> None:
         """Runs Optuna tuning of hyperparameters."""
From be26b897cf8643952f6b8b75d4c806f516997650 Mon Sep 17 00:00:00 2001
From: klemen1999
Date: Sun, 19 May 2024 22:25:28 +0200
Subject: [PATCH 06/23] removed unused code, added note

---
 luxonis_train/core/tuner.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/luxonis_train/core/tuner.py b/luxonis_train/core/tuner.py
index eb072611..48dcbd82 100644
--- a/luxonis_train/core/tuner.py
+++ b/luxonis_train/core/tuner.py
@@ -43,6 +43,9 @@ def __init__(self, cfg: str | dict, args: list[str] | tuple[str, ...] | None):
             is_sweep=False,
             **tracker_params,
         )
+        if self.parent_tracker.is_mlflow:
+            # Experiment needs to be interacted with to create actual MLFlow run
+            run = self.parent_tracker.experiment["mlflow"].active_run()
 
     def tune(self) -> None:
         """Runs Optuna tuning of hyperparameters."""
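The note added in this patch points at a subtlety of the tracker: its MLflow run is created lazily on first access of the `experiment` property, so the parent run would never exist server-side unless something touches it. A sketch of that lazy pattern with illustrative names (not the actual LuxonisTracker internals):

```python
import mlflow


class LazyRun:
    """Illustrative stand-in: no MLflow run exists until `experiment` is touched."""

    def __init__(self, experiment_name: str):
        self._experiment_name = experiment_name
        self._run = None

    @property
    def experiment(self):
        if self._run is None:
            mlflow.set_experiment(self._experiment_name)
            self._run = mlflow.start_run()  # first access creates the run
        return self._run
```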
From 2cf7898582bc78b5351c8e7daf09edbe200f9038 Mon Sep 17 00:00:00 2001
From: klemen1999
Date: Thu, 23 May 2024 09:33:23 +0200
Subject: [PATCH 07/23] added finalize() to LuxonisTrackerPL for graceful
 tracker exit

---
 luxonis_train/core/tuner.py    | 44 +++++++++++++---------------------
 luxonis_train/utils/tracker.py | 22 ++++++++++++++++-
 2 files changed, 38 insertions(+), 28 deletions(-)

diff --git a/luxonis_train/core/tuner.py b/luxonis_train/core/tuner.py
index 48dcbd82..c0dbd63e 100644
--- a/luxonis_train/core/tuner.py
+++ b/luxonis_train/core/tuner.py
@@ -37,6 +37,9 @@ def __init__(self, cfg: str | dict, args: list[str] | tuple[str, ...] | None):
         rank = rank_zero_only.rank
         cfg_tracker = self.cfg.tracker
         tracker_params = cfg_tracker.model_dump()
+        tracker_params["is_wandb"] = (
+            False  # wandb doesn't allow multiple concurrent runs, handle this separately
+        )
         self.parent_tracker = LuxonisTrackerPL(
             rank=rank,
             mlflow_tracking_uri=self.cfg.ENVIRON.MLFLOW_TRACKING_URI,
@@ -86,8 +89,22 @@ def tune(self) -> None:
 
         best_study_params = study.best_params
         logger.info(f"Best study parameters: {best_study_params}")
+
         self.parent_tracker.log_hyperparams(best_study_params)
 
+        if self.cfg.tracker.is_wandb:
+            # If wandb used then init parent tracker separately at the end
+            wandb_parent_tracker = LuxonisTrackerPL(
+                project_name=self.cfg.tracker.project_name,
+                project_id=self.cfg.tracker.project_id,
+                run_name=self.parent_tracker.run_name,
+                save_directory=self.cfg.tracker.save_directory,
+                is_wandb=True,
+                wandb_entity=self.cfg.tracker.wandb_entity,
+                rank=rank_zero_only.rank,
+            )
+            wandb_parent_tracker.log_hyperparams(best_study_params)
+
     def _objective(self, trial: optuna.trial.Trial) -> float:
         """Objective function used to optimize Optuna study."""
         rank = rank_zero_only.rank
@@ -122,10 +139,6 @@ def _objective(self, trial: optuna.trial.Trial) -> float:
         )
         pruner_callback = PyTorchLightningPruningCallback(trial, monitor="val/loss")
         callbacks.append(pruner_callback)
-
-        tracker_end_run = TrackerEndRun()
-        callbacks.append(tracker_end_run)
-
         deterministic = False
         if self.cfg.trainer.seed:
             pl.seed_everything(cfg.trainer.seed, workers=True)
@@ -156,8 +169,6 @@ def _objective(self, trial: optuna.trial.Trial) -> float:
 
             pruner_callback.check_pruned()
 
         except optuna.TrialPruned as e:
             # Pruning is done by raising an error
-            # When .fit() errors out we have to gracefully also end the trackers
-            tracker_end_run.end_trackers(child_tracker)
             logger.info(e)
 
         if "val/loss" not in pl_trainer.callback_metrics:
@@ -209,24 +220,3 @@ def _get_trial_params(self, trial: optuna.trial.Trial) -> dict[str, Any]:
                 "No parameters to tune. Specify them under `tuner.params`."
             )
         return new_params
-
-
-class TrackerEndRun(pl.Callback):
-    """Callback that ends trackers of child processes during tuning study"""
-
-    def teardown(
-        self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str
-    ) -> None:
-        self.end_trackers(trainer.logger)  # type: ignore
-        return super().teardown(trainer, pl_module, stage)
-
-    def end_trackers(self, tracker: LuxonisTrackerPL) -> None:
-        """Ends WandB and MLFlow trackers
-
-        Args:
-            tracker (LuxonisTrackerPL): Currently active tracker
-        """
-        if tracker.is_wandb:
-            tracker.experiment["wandb"].finish()
-        if tracker.is_mlflow:
-            tracker.experiment["mlflow"].end_run()
diff --git a/luxonis_train/utils/tracker.py b/luxonis_train/utils/tracker.py
index 13c77cb2..65fea368 100644
--- a/luxonis_train/utils/tracker.py
+++ b/luxonis_train/utils/tracker.py
@@ -1,8 +1,28 @@
 from lightning.pytorch.loggers.logger import Logger
 from luxonis_ml.tracker import LuxonisTracker
+from lightning.pytorch.utilities import rank_zero_only  # type: ignore
 
 
 class LuxonisTrackerPL(LuxonisTracker, Logger):
     """Implementation of LuxonisTracker that is compatible with PytorchLightning."""
 
-    ...
+    @rank_zero_only
+    def finalize(self, status: str = "success") -> None:
+        """Finalizes current run"""
+        if self.is_tensorboard:
+            self.experiment["tensorboard"].flush()
+            self.experiment["tensorboard"].close()
+        if self.is_mlflow:
+            if status == "success":
+                mlflow_status = "FINISHED"
+            elif status == "failed":
+                mlflow_status = "FAILED"
+            elif status == "finished":
+                mlflow_status = "FINISHED"
+            self.experiment["mlflow"].end_run(mlflow_status)
+        if self.is_wandb:
+            if status == "success":
+                wandb_status = 0
+            else:
+                wandb_status = 1
+            self.experiment["wandb"].finish(wandb_status)
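Overriding `finalize()` works because PyTorch Lightning calls `Logger.finalize(status)` on every attached logger when a run ends, including after failures, which is why the dedicated `TrackerEndRun` callback could be removed. The status mapping above can be condensed into a small sketch; note that the `elif` chain in the diff assumes `status` is one of "success", "failed", or "finished":

```python
def map_status(status: str) -> tuple[str, int]:
    # Lightning status -> (MLflow run status, wandb exit code), per finalize() above.
    mlflow_status = "FINISHED" if status in ("success", "finished") else "FAILED"
    wandb_exit_code = 0 if status == "success" else 1
    return mlflow_status, wandb_exit_code


assert map_status("success") == ("FINISHED", 0)
assert map_status("failed") == ("FAILED", 1)
```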
From 4041c9901ccff94e70f5fb9349eee5558036c160 Mon Sep 17 00:00:00 2001
From: Michal Sejak
Date: Sun, 2 Jun 2024 00:27:04 +0200
Subject: [PATCH 08/23] Multi-input test case - building a complex multi-input
 POSET model with export

---
 tests/integration/test_multi_input.py | 195 ++++++++++++++++++++++++++
 1 file changed, 195 insertions(+)
 create mode 100644 tests/integration/test_multi_input.py

diff --git a/tests/integration/test_multi_input.py b/tests/integration/test_multi_input.py
new file mode 100644
index 00000000..e7f5ba26
--- /dev/null
+++ b/tests/integration/test_multi_input.py
@@ -0,0 +1,195 @@
+import os
+import shutil
+from typing import Annotated
+
+import pytest
+import torch
+from pydantic import Field
+from torch import Tensor
+from torch.nn.parameter import Parameter
+
+from luxonis_train.core import Trainer
+from luxonis_train.nodes import BaseNode
+from luxonis_train.utils.loaders import BaseLoaderTorch
+from luxonis_train.utils.registry import LOADERS
+from luxonis_train.utils.types import BaseProtocol, FeaturesProtocol, LabelType
+
+LOADERS.register_module()
+
+
+class CustomMultiInputLoader(BaseLoaderTorch):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    @property
+    def input_shape(self):
+        return {
+            "left": torch.Size([3, 224, 224]),
+            "right": torch.Size([3, 224, 224]),
+            "disparity": torch.Size([1, 224, 224]),
+            "pointcloud": torch.Size([1000, 3]),
+        }
+
+    @property
+    def images_name(self):
+        return "left"
+
+    def __getitem__(self, idx):
+        # Fake data
+        left = torch.rand(3, 224, 224, dtype=torch.float32)
+        right = torch.rand(3, 224, 224, dtype=torch.float32)
+        disparity = torch.rand(1, 224, 224, dtype=torch.float32)
+        pointcloud = torch.rand(1000, 3, dtype=torch.float32)
+        inputs = {
+            "left": left,
+            "right": right,
+            "disparity": disparity,
+            "pointcloud": pointcloud,
+        }
+
+        # Fake labels
+        segmap = torch.zeros(1, 224, 224, dtype=torch.float32)
+        labels = {
+            "default": {
+                LabelType.SEGMENTATION: segmap,
+            }
+        }
+
+        return inputs, labels
+
+    def __len__(self):
+        return 10
+
+    def get_classes(self) -> dict[LabelType, list[str]]:
+        return {LabelType.SEGMENTATION: ["square"]}
+
+
+class FullCustomMultiInputProtocol(BaseProtocol):
+    left: Annotated[list[Tensor], Field(min_length=1)]
+    right: Annotated[list[Tensor], Field(min_length=1)]
+    disparity: Annotated[list[Tensor], Field(min_length=1)]
+    pointcloud: Annotated[list[Tensor], Field(min_length=1)]
+
+
+class RGBDCustomMultiInputProtocol(BaseProtocol):
+    left: Annotated[list[Tensor], Field(min_length=1)]
+    right: Annotated[list[Tensor], Field(min_length=1)]
+    disparity: Annotated[list[Tensor], Field(min_length=1)]
+
+
+class PointcloudCustomMultiInputProtocol(BaseProtocol):
+    pointcloud: Annotated[list[Tensor], Field(min_length=1)]
+
+
+class DisparityCustomMultiInputProtocol(BaseProtocol):
+    disparity: Annotated[list[Tensor], Field(min_length=1)]
+
+
+class MultiInputTestBaseNode(BaseNode):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.scalar = Parameter(torch.tensor(1.0), requires_grad=True)
+
+    def forward(self, inputs):
+        return [self.scalar * inp for inp in inputs]
+
+    def unwrap(self, inputs: list[dict[str, list[Tensor]]]):
+        return [item for inp in inputs for key in inp for item in inp[key]]
+
+
+class FullBackbone(MultiInputTestBaseNode):
+    def __init__(self, **kwargs):
+        in_protocols = [FullCustomMultiInputProtocol]
+        super().__init__(**kwargs)
+        self.in_protocols = in_protocols
+
+
+class RGBDBackbone(MultiInputTestBaseNode):
+    def __init__(self, **kwargs):
+        in_protocols = [RGBDCustomMultiInputProtocol]
+        super().__init__(**kwargs)
+        self.in_protocols = in_protocols
+
+
+class PointcloudBackbone(MultiInputTestBaseNode):
+    def __init__(self, **kwargs):
+        in_protocols = [PointcloudCustomMultiInputProtocol]
+        super().__init__(**kwargs)
+        self.in_protocols = in_protocols
+
+
+class FusionNeck(MultiInputTestBaseNode):
+    def __init__(self, **kwargs):
+        in_protocols = [
+            DisparityCustomMultiInputProtocol,
+            FeaturesProtocol,
+            FeaturesProtocol,
+        ]
+        super().__init__(**kwargs)
+        self.in_protocols = in_protocols
+
+
+class FusionNeck2(MultiInputTestBaseNode):
+    def __init__(self, **kwargs):
+        in_protocols = [FeaturesProtocol, FeaturesProtocol, FeaturesProtocol]
+        super().__init__(**kwargs)
+        self.in_protocols = in_protocols
+
+
+class CustomSegHead1(MultiInputTestBaseNode):
+    def __init__(self, **kwargs):
+        in_protocols = [FeaturesProtocol]
+        super().__init__(**kwargs)
+        self.in_protocols = in_protocols
+
+    def wrap(self, outputs: list[Tensor]):
+        return {"segmentation": outputs}
+
+
+class CustomSegHead2(MultiInputTestBaseNode):
+    def __init__(self, **kwargs):
+        in_protocols = [
+            DisparityCustomMultiInputProtocol,
+            FeaturesProtocol,
+            FeaturesProtocol,
+        ]
+        super().__init__(**kwargs)
+        self.in_protocols = in_protocols
+
+    def wrap(self, outputs: list[Tensor]):
+        return {"segmentation": outputs}
+
+
+@pytest.fixture(scope="function", autouse=True)
+def clear_output():
+    shutil.rmtree("output", ignore_errors=True)
+
+
+@pytest.mark.parametrize(
+    "config_file", [path for path in os.listdir("configs") if "multi_input" in path]
+)
+def test_sanity(config_file):
+    # opts = [
+    #     "trainer.epochs",
+    #     "3",
+    #     "trainer.validation_interval",
+    #     "3",
+    # ]
+
+    Trainer(f"configs/{config_file}").train()
+
+    # TODO add export and eval tests
+    # opts += ["model.weights", str(list(Path("output").rglob("*.ckpt"))[0])]
+    # opts += ["exporter.onnx.opset_version", "11"]
+
+    # result = subprocess.run(
+    #     ["luxonis_train", "export", "--config", f"configs/{config_file}", *opts],
+    # )
+
+    # assert result.returncode == 0
+
+    # result = subprocess.run(
+    #     ["luxonis_train", "eval", "--config", f"configs/{config_file}", *opts],
+    # )
+
+    # assert result.returncode == 0
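The test above registers `CustomMultiInputLoader` with the `LOADERS` registry so that a config can refer to it by name. Registries of this style are usually applied as decorators; a hedged sketch of that usage (the exact `register_module` signature in `luxonis-ml` may differ):

```python
from luxonis_train.utils.loaders import BaseLoaderTorch
from luxonis_train.utils.registry import LOADERS


@LOADERS.register_module()  # assumption: decorator-style registration
class MyLoader(BaseLoaderTorch):
    ...  # then reference it via `loader.name: MyLoader` in the config
```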
From 6809f98e4f1b39b865175d2efed2c2cee87b44ea Mon Sep 17 00:00:00 2001
From: Michal Sejak
Date: Sun, 2 Jun 2024 00:28:05 +0200
Subject: [PATCH 09/23] Added a test for new collate_fn, which collates dicts
 of tensors instead of bare tensors

---
 .../test_loaders/test_base_loader.py | 63 ++++++++++++++-----
 1 file changed, 47 insertions(+), 16 deletions(-)

diff --git a/tests/unittests/test_utils/test_loaders/test_base_loader.py b/tests/unittests/test_utils/test_loaders/test_base_loader.py
index b5c8b299..47af2192 100644
--- a/tests/unittests/test_utils/test_loaders/test_base_loader.py
+++ b/tests/unittests/test_utils/test_loaders/test_base_loader.py
@@ -7,34 +7,65 @@
 from luxonis_train.utils.types import LabelType
 
 
-def test_collate_fn():
+@pytest.mark.parametrize(
+    "input_names_and_shapes",
+    [
+        [("features", torch.Size([3, 224, 224]))],
+        [
+            ("features", torch.Size([3, 224, 224])),
+            ("segmentation", torch.Size([1, 224, 224])),
+        ],
+        [
+            ("features", torch.Size([3, 224, 224])),
+            ("segmentation", torch.Size([1, 224, 224])),
+            ("disparity", torch.Size([1, 224, 224])),
+        ],
+        [
+            ("features", torch.Size([3, 224, 224])),
+            ("pointcloud", torch.Size([1000, 3])),
+        ],
+        [
+            ("features", torch.Size([3, 224, 224])),
+            ("pointcloud", torch.Size([1000, 3])),
+            ("foobar", torch.Size([2, 3, 4, 5, 6])),
+        ],
+    ],
+)
+@pytest.mark.parametrize("batch_size", [1, 2])
+def test_collate_fn(input_names_and_shapes, batch_size):
     # Mock batch data
-    batch = [
-        (
-            torch.rand(3, 224, 224, dtype=torch.float32),
-            {"default": {LabelType.CLASSIFICATION: torch.tensor([1, 0])}},
-        ),
-        (
-            torch.rand(3, 224, 224, dtype=torch.float32),
-            {"default": {LabelType.CLASSIFICATION: torch.tensor([0, 1])}},
-        ),
-    ]
+
+    def build_batch_element():
+        inputs = {}
+        for name, shape in input_names_and_shapes:
+            inputs[name] = torch.rand(shape, dtype=torch.float32)
+
+        labels = {
+            "default": {
+                LabelType.CLASSIFICATION: torch.randint(0, 2, (2,), dtype=torch.int64),
+            }
+        }
+
+        return inputs, labels
+
+    batch = [build_batch_element() for _ in range(batch_size)]
 
     # Call collate_fn
-    imgs, annotations = collate_fn(batch)
+    inputs, annotations = collate_fn(batch)
 
     # Check images tensor
-    assert imgs.shape == (2, 3, 224, 224)
-    assert imgs.dtype == torch.float32
+    assert inputs["features"].shape == (batch_size, 3, 224, 224)
+    assert inputs["features"].dtype == torch.float32
 
     # Check annotations
     assert "default" in annotations
     annotations = annotations["default"]
     assert LabelType.CLASSIFICATION in annotations
-    assert annotations[LabelType.CLASSIFICATION].shape == (2, 2)
+    assert annotations[LabelType.CLASSIFICATION].shape == (batch_size, 2)
     assert annotations[LabelType.CLASSIFICATION].dtype == torch.int64
 
-    # TODO: test also segmentation, boundingbox and keypoint
+
+# TODO: test also segmentation, boundingbox and keypoint
 
 
 if __name__ == "__main__":
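The behaviour this test pins down: `collate_fn` now receives one dict of tensors per sample and stacks each key across the batch, giving every loader sub-element its own leading batch dimension. The core of that stacking can be sketched as:

```python
import torch
from torch import Tensor


def stack_inputs(batch: list[dict[str, Tensor]]) -> dict[str, Tensor]:
    # One stacked tensor per loader sub-element.
    return {key: torch.stack([sample[key] for sample in batch], 0) for key in batch[0]}


batch = [
    {"features": torch.rand(3, 224, 224), "pointcloud": torch.rand(1000, 3)}
    for _ in range(2)
]
stacked = stack_inputs(batch)
assert stacked["features"].shape == (2, 3, 224, 224)
assert stacked["pointcloud"].shape == (2, 1000, 3)
```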
From f54660811c558ad5152240dc36106d99cbf52833 Mon Sep 17 00:00:00 2001
From: Michal Sejak
Date: Sun, 2 Jun 2024 00:29:22 +0200
Subject: [PATCH 10/23] Added a config for a multi-input model

---
 configs/example_multi_input.yaml | 100 +++++++++++++++++++++++++++++++
 1 file changed, 100 insertions(+)
 create mode 100644 configs/example_multi_input.yaml

diff --git a/configs/example_multi_input.yaml b/configs/example_multi_input.yaml
new file mode 100644
index 00000000..d9e0384b
--- /dev/null
+++ b/configs/example_multi_input.yaml
@@ -0,0 +1,100 @@
+loader:
+  name: CustomMultiInputLoader # Yields "left", "right", "disparity", and "pointcloud" inputs
+
+use_rich_text: True
+
+model:
+  name: example_multi_input
+  nodes:
+    - name: FullBackbone
+      alias: full_backbone
+
+    - name: RGBDBackbone
+      alias: rgbd_backbone
+      input_sources:
+        - left
+        - right
+        - disparity
+
+    - name: PointcloudBackbone
+      alias: pointcloud_backbone
+      input_sources:
+        - pointcloud
+
+    - name: FusionNeck
+      alias: fusion_neck
+      inputs:
+        - rgbd_backbone
+        - pointcloud_backbone
+      input_sources:
+        - disparity
+
+    - name: FusionNeck2
+      alias: fusion_neck_2
+      inputs:
+        - rgbd_backbone
+        - pointcloud_backbone
+        - full_backbone
+
+    - name: CustomSegHead1
+      alias: head_1
+      inputs:
+        - fusion_neck
+
+    - name: CustomSegHead2
+      alias: head_2
+      inputs:
+        - fusion_neck
+        - fusion_neck_2
+      input_sources:
+        - disparity
+
+  losses:
+    - name: BCEWithLogitsLoss
+      alias: loss_1
+      attached_to: head_1
+
+    - name: BCEWithLogitsLoss
+      alias: loss_2
+      attached_to: head_2
+
+  metrics:
+    - name: JaccardIndex
+      alias: jaccard_index_1
+      attached_to: head_1
+      is_main_metric: True
+      params:
+        task: binary
+
+    - name: JaccardIndex
+      alias: jaccard_index_2
+      attached_to: head_2
+      params:
+        task: binary
+
+trainer:
+  batch_size: 8
+  epochs: &epochs 3
+  num_workers: 4
+  validation_interval: 3
+  num_log_images: -1
+
+  callbacks:
+    - name: ExportOnTrainEnd
+
+  optimizer:
+    name: Adam
+    params:
+      lr: 0.01
+
+exporter:
+  onnx:
+    opset_version: 11
+
+# tracker:
+#   project_name: ace_reid_spot_segmentation
+#   save_directory: output
+#   is_tensorboard: True
+#   is_wandb: False
+#   wandb_entity: luxonis
+#   is_mlflow: False
From 769aaab1c44f6b7b951ca65b773ef982be041cf7 Mon Sep 17 00:00:00 2001
From: Michal Sejak
Date: Sun, 2 Jun 2024 00:31:59 +0200
Subject: [PATCH 11/23] Added necessary changes to config: ModelNodeConfig has
 a parameter input_sources, which tells the LuxonisModel which loader
 sub-elements it wants to load, and LoaderConfig has an images_name parameter
 which identifies the image-like input among the sub-elements for
 compatibility with visualizers etc.

---
 luxonis_train/utils/config.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/luxonis_train/utils/config.py b/luxonis_train/utils/config.py
index dc2f737d..1a590ec4 100644
--- a/luxonis_train/utils/config.py
+++ b/luxonis_train/utils/config.py
@@ -38,7 +38,8 @@ class FreezingConfig(CustomBaseModel):
 class ModelNodeConfig(CustomBaseModel):
     name: str
     alias: str | None = None
-    inputs: list[str] = []
+    inputs: list[str] = []  # From preceding nodes
+    input_sources: list[str] = []  # From data loader
     params: dict[str, Any] = {}
     freezing: FreezingConfig = FreezingConfig()
     task_group: str = "default"
@@ -131,6 +132,7 @@ class TrackerConfig(CustomBaseModel):
 
 class LoaderConfig(CustomBaseModel):
     name: str = "LuxonisLoaderTorch"
+    images_name: str = "features"
     train_view: str = "train"
     val_view: str = "val"
     test_view: str = "test"
@@ -169,7 +171,8 @@ def check_normalize(self):
         return self
 
     def get_active_augmentations(self) -> list[AugmentationConfig]:
-        """Returns list of augmentations that are active
+        """Returns list of augmentations that are active.
+
        @rtype: list[AugmentationConfig]
        @return: Filtered list of active augmentation configs
        """
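To make the two routing fields concrete: `inputs` names preceding nodes in the model graph, while `input_sources` names loader sub-elements, and a node may use both. A small sketch using the Pydantic model from the diff, mirroring the `FusionNeck` entry of `example_multi_input.yaml`:

```python
from luxonis_train.utils.config import ModelNodeConfig

node = ModelNodeConfig(
    name="FusionNeck",
    alias="fusion_neck",
    inputs=["rgbd_backbone", "pointcloud_backbone"],  # outputs of preceding nodes
    input_sources=["disparity"],  # raw sub-element taken straight from the loader
)
```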
" ) + @property + def images_name(self) -> str: + """Getter for the images name (name of the key from the loader which contains + images).""" + return self._images_name + @property def n_classes(self) -> int: """Getter for the number of classes.""" @@ -180,13 +188,13 @@ def in_sizes(self) -> Size | list[Size]: Example: - >>> input_shapes = [{"features": [Size(1, 64, 128, 128), Size(1, 3, 224, 224)]}] + >>> input_shapes = [{"features": [Size(64, 128, 128), Size(3, 224, 224)]}] >>> attach_index = -1 - >>> in_sizes = Size(1, 3, 224, 224) + >>> in_sizes = Size(3, 224, 224) - >>> input_shapes = [{"features": [Size(1, 64, 128, 128), Size(1, 3, 224, 224)]}] + >>> input_shapes = [{"features": [Size(64, 128, 128), Size(3, 224, 224)]}] >>> attach_index = "all" - >>> in_sizes = [Size(1, 64, 128, 128), Size(1, 3, 224, 224)] + >>> in_sizes = [Size(64, 128, 128), Size(3, 224, 224)] @type: Size | list[Size] @raises IncompatibleException: If the C{input_shapes} are too complicated for @@ -195,13 +203,13 @@ def in_sizes(self) -> Size | list[Size]: if self._in_sizes is not None: return self._in_sizes - features = self.input_shapes[0].get("features") + features = self.input_shapes[0].get(self._images_name) if features is None: raise IncompatibleException( - f"Feature field is missing in {self.__class__.__name__}. " + f"Images field '{self._images_name}' is missing in {self.__class__.__name__}. " "The default implementation of `in_sizes` cannot be used." ) - shapes = self.get_attached(self.input_shapes[0]["features"]) + shapes = self.get_attached(self.input_shapes[0][self._images_name]) if isinstance(shapes, list) and len(shapes) == 1: return shapes[0] return shapes @@ -219,7 +227,7 @@ def in_channels(self) -> int | list[int]: @raises IncompatibleException: If the C{input_shapes} are too complicated for the default implementation. """ - return self._get_nth_size(1) + return self._get_nth_size(-3) @property def in_height(self) -> int | list[int]: @@ -232,7 +240,7 @@ def in_height(self) -> int | list[int]: @raises IncompatibleException: If the C{input_shapes} are too complicated for the default implementation. """ - return self._get_nth_size(2) + return self._get_nth_size(-2) @property def in_width(self) -> int | list[int]: @@ -245,7 +253,7 @@ def in_width(self) -> int | list[int]: @raises IncompatibleException: If the C{input_shapes} are too complicated for the default implementation. """ - return self._get_nth_size(3) + return self._get_nth_size(-1) @property def export(self) -> bool: From 4284fbae27cbe8f0f3837343da531ab4a15a6127 Mon Sep 17 00:00:00 2001 From: Michal Sejak Date: Sun, 2 Jun 2024 00:37:19 +0200 Subject: [PATCH 14/23] Implemented multi-input support in node building and model export --- luxonis_train/models/luxonis_model.py | 154 ++++++++++++++++++-------- 1 file changed, 108 insertions(+), 46 deletions(-) diff --git a/luxonis_train/models/luxonis_model.py b/luxonis_train/models/luxonis_model.py index d3ed26a2..cc6cc152 100644 --- a/luxonis_train/models/luxonis_model.py +++ b/luxonis_train/models/luxonis_model.py @@ -100,7 +100,7 @@ def __init__( self, cfg: Config, save_dir: str, - input_shape: list[int] | Size, + input_shape: dict[str, Size], dataset_metadata: DatasetMetadata | None = None, **kwargs, ): @@ -110,8 +110,9 @@ def __init__( @param cfg: Config object. @type save_dir: str @param save_dir: Directory to save checkpoints. - @type input_shape: list[int] | L{Size} - @param input_shape: Shape of the input tensor. 
From 9cf7ebbd1055daf9372d59e40e0b3fa4e6d923b2 Mon Sep 17 00:00:00 2001
From: Michal Sejak
Date: Sun, 2 Jun 2024 00:34:58 +0200
Subject: [PATCH 13/23] Added images_name property to BaseNode class, removed
 redundant batch_dim from shapes

---
 luxonis_train/nodes/base_node.py | 28 ++++++++++++++++----------
 1 file changed, 18 insertions(+), 10 deletions(-)

diff --git a/luxonis_train/nodes/base_node.py b/luxonis_train/nodes/base_node.py
index c3124f82..09772d43 100644
--- a/luxonis_train/nodes/base_node.py
+++ b/luxonis_train/nodes/base_node.py
@@ -86,6 +86,7 @@ def __init__(
         *,
         input_shapes: list[Packet[Size]] | None = None,
         original_in_shape: Size | None = None,
+        images_name: str = "features",
         dataset_metadata: DatasetMetadata | None = None,
         attach_index: AttachIndexType | None = None,
         in_protocols: list[type[BaseModel]] | None = None,
@@ -115,6 +116,7 @@ def __init__(
 
         self._input_shapes = input_shapes
         self._original_in_shape = original_in_shape
+        self._images_name = images_name
         if n_classes is not None:
             if dataset_metadata is not None:
                 raise ValueError("Cannot set both `dataset_metadata` and `n_classes`.")
@@ -130,6 +132,12 @@ def _non_set_error(self, name: str) -> ValueError:
             "but it was not set during initialization. "
         )
 
+    @property
+    def images_name(self) -> str:
+        """Getter for the images name (name of the key from the loader which contains
+        images)."""
+        return self._images_name
+
     @property
     def n_classes(self) -> int:
         """Getter for the number of classes."""
@@ -180,13 +188,13 @@ def in_sizes(self) -> Size | list[Size]:
 
         Example:
 
-            >>> input_shapes = [{"features": [Size(1, 64, 128, 128), Size(1, 3, 224, 224)]}]
+            >>> input_shapes = [{"features": [Size(64, 128, 128), Size(3, 224, 224)]}]
             >>> attach_index = -1
-            >>> in_sizes = Size(1, 3, 224, 224)
+            >>> in_sizes = Size(3, 224, 224)
 
-            >>> input_shapes = [{"features": [Size(1, 64, 128, 128), Size(1, 3, 224, 224)]}]
+            >>> input_shapes = [{"features": [Size(64, 128, 128), Size(3, 224, 224)]}]
             >>> attach_index = "all"
-            >>> in_sizes = [Size(1, 64, 128, 128), Size(1, 3, 224, 224)]
+            >>> in_sizes = [Size(64, 128, 128), Size(3, 224, 224)]
 
         @type: Size | list[Size]
         @raises IncompatibleException: If the C{input_shapes} are too complicated for
@@ -195,13 +203,13 @@ def in_sizes(self) -> Size | list[Size]:
         if self._in_sizes is not None:
             return self._in_sizes
 
-        features = self.input_shapes[0].get("features")
+        features = self.input_shapes[0].get(self._images_name)
         if features is None:
             raise IncompatibleException(
-                f"Feature field is missing in {self.__class__.__name__}. "
+                f"Images field '{self._images_name}' is missing in {self.__class__.__name__}. "
                 "The default implementation of `in_sizes` cannot be used."
             )
-        shapes = self.get_attached(self.input_shapes[0]["features"])
+        shapes = self.get_attached(self.input_shapes[0][self._images_name])
         if isinstance(shapes, list) and len(shapes) == 1:
             return shapes[0]
         return shapes
@@ -219,7 +227,7 @@ def in_channels(self) -> int | list[int]:
         @raises IncompatibleException: If the C{input_shapes} are too complicated for
             the default implementation.
         """
-        return self._get_nth_size(1)
+        return self._get_nth_size(-3)
 
     @property
     def in_height(self) -> int | list[int]:
@@ -232,7 +240,7 @@ def in_height(self) -> int | list[int]:
         @raises IncompatibleException: If the C{input_shapes} are too complicated for
             the default implementation.
         """
-        return self._get_nth_size(2)
+        return self._get_nth_size(-2)
 
     @property
     def in_width(self) -> int | list[int]:
@@ -245,7 +253,7 @@ def in_width(self) -> int | list[int]:
         @raises IncompatibleException: If the C{input_shapes} are too complicated for
             the default implementation.
         """
-        return self._get_nth_size(3)
+        return self._get_nth_size(-1)
 
     @property
     def export(self) -> bool:
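The switch from indices 1/2/3 to -3/-2/-1 follows from loader shapes no longer carrying a batch dimension; negative indexing reads the same C/H/W positions from either layout:

```python
import torch

chw = torch.Size([3, 224, 224])  # new loader convention, no batch dimension
nchw = torch.Size([1, 3, 224, 224])  # old [N, C, H, W] convention

assert chw[-3] == nchw[-3] == 3  # channels
assert chw[-2] == nchw[-2] == 224  # height
assert chw[-1] == nchw[-1] == 224  # width
```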
From 4284fbae27cbe8f0f3837343da531ab4a15a6127 Mon Sep 17 00:00:00 2001
From: Michal Sejak
Date: Sun, 2 Jun 2024 00:37:19 +0200
Subject: [PATCH 14/23] Implemented multi-input support in node building and
 model export

---
 luxonis_train/models/luxonis_model.py | 154 ++++++++++++++++++--------
 1 file changed, 108 insertions(+), 46 deletions(-)

diff --git a/luxonis_train/models/luxonis_model.py b/luxonis_train/models/luxonis_model.py
index d3ed26a2..cc6cc152 100644
--- a/luxonis_train/models/luxonis_model.py
+++ b/luxonis_train/models/luxonis_model.py
@@ -100,7 +100,7 @@ def __init__(
         self,
         cfg: Config,
         save_dir: str,
-        input_shape: list[int] | Size,
+        input_shape: dict[str, Size],
         dataset_metadata: DatasetMetadata | None = None,
         **kwargs,
     ):
@@ -110,8 +110,9 @@ def __init__(
         @param cfg: Config object.
         @type save_dir: str
         @param save_dir: Directory to save checkpoints.
-        @type input_shape: list[int] | L{Size}
-        @param input_shape: Shape of the input tensor.
+        @type input_shape: dict[str, Size]
+        @param input_shape: Dictionary of input shapes. Keys are input names, values are
+            shapes.
         @type dataset_metadata: L{DatasetMetadata} | None
         @param dataset_metadata: Dataset metadata.
         @type kwargs: Any
@@ -123,11 +124,12 @@ def __init__(
         self._export: bool = False
 
         self.cfg = cfg
-        self.original_in_shape = Size(input_shape)
+        self.original_in_shape = input_shape
+        self.images_name = cfg.loader.images_name
         self.dataset_metadata = dataset_metadata or DatasetMetadata()
         self.frozen_nodes: list[tuple[nn.Module, int]] = []
         self.graph: dict[str, list[str]] = {}
-        self.input_shapes: dict[str, list[Size]] = {}
+        self.loader_input_shapes: dict[str, Size] = {}
         self.loss_weights: dict[str, float] = {}
         self.main_metric: str | None = None
         self.save_dir = save_dir
@@ -160,8 +162,33 @@ def __init__(
                 unfreeze_after = int(node_cfg.freezing.unfreeze_after * epochs)
                 frozen_nodes.append((node_name, unfreeze_after))
             nodes[node_name] = (Node, node_cfg.params)
-            if not node_cfg.inputs:
-                self.input_shapes[node_name] = [Size(input_shape)]
+
+            # Handle inputs for this node
+
+            if not node_cfg.inputs and not node_cfg.input_sources:
+                # If neither inputs (= preceding nodes) nor input_sources (= loader
+                # outputs) are specified, assume the node is the starting node and
+                # takes all inputs from the loader.
+
+                self.loader_input_shapes[node_name] = {
+                    k: Size(v) for k, v in input_shape.items()
+                }
+            else:
+                # For each input_source, check if the loader provides the required output.
+                # If yes, add the shape to the input_shapes dict. If not, raise an error.
+                self.loader_input_shapes[node_name] = {}
+                for input_source in node_cfg.input_sources:
+                    if input_source not in input_shape:
+                        raise ValueError(
+                            f"Node {node_name} requires input source {input_source}, "
+                            "which is not provided by the loader."
+                        )
+
+                    self.loader_input_shapes[node_name][input_source] = Size(
+                        input_shape[input_source]
+                    )
+
+                # Inputs (= preceding nodes) are handled in the _initiate_nodes method.
+
             self.graph[node_name] = node_cfg.inputs
 
         self.nodes = self._initiate_nodes(nodes)
@@ -213,44 +240,64 @@ def _initiate_nodes(
         """
         initiated_nodes: dict[str, BaseNode] = {}
 
-        dummy_outputs: dict[str, Packet[Tensor]] = {
-            f"__{node_name}_input__": {
-                "features": [torch.zeros(2, *shape[1:]) for shape in shapes]
-            }
-            for node_name, shapes in self.input_shapes.items()
+        dummy_inputs: dict[str, Packet[Tensor]] = {
+            input_name: [torch.zeros(2, *shape)]
+            for node_name, shapes in self.loader_input_shapes.items()
+            for input_name, shape in shapes.items()
         }
 
         for node_name, (Node, node_kwargs), node_input_names, _ in traverse_graph(
             self.graph, nodes
         ):
-            node_input_shapes: list[Packet[Size]] = []
             node_dummy_inputs: list[Packet[Tensor]] = []
+            """List of dummy input packets for the node.
+
+            The first one is always from the loader.
+            """
+            node_input_shapes: list[Packet[Size]] = []
+            """Corresponding list of input shapes."""
+
+            # Add inputs from the loader (if the node has at least one input_source)
+            if len(self.loader_input_shapes[node_name]) > 0:
+                loader_packet = {}
+                for input_name in self.loader_input_shapes[node_name]:
+                    loader_packet[input_name] = dummy_inputs[input_name]
+
+                # Add the loader packet to the list of inputs
+                node_dummy_inputs.append(loader_packet)
+
+                # Add its shape to the list of input shapes
+                loader_shape_packet = get_shape_packet(loader_packet)
+                node_input_shapes.append(loader_shape_packet)
 
-            if not node_input_names:
-                node_input_names = [f"__{node_name}_input__"]
+            # Add inputs from preceding Nodes
             for node_input_name in node_input_names:
-                dummy_output = dummy_outputs[node_input_name]
-                shape_packet = get_shape_packet(dummy_output)
+                dummy_input = dummy_inputs[node_input_name]
+
+                # Add the input to the list of inputs
+                node_dummy_inputs.append(dummy_input)
+
+                # Add its shape to the list of input shapes
+                shape_packet = get_shape_packet(dummy_input)
                 node_input_shapes.append(shape_packet)
-                node_dummy_inputs.append(dummy_output)
 
             node = Node(
                 input_shapes=node_input_shapes,
                 original_in_shape=self.original_in_shape,
+                images_name=self.images_name,
                 dataset_metadata=self.dataset_metadata,
                 **node_kwargs,
             )
             node_outputs = node.run(node_dummy_inputs)
 
-            dummy_outputs[node_name] = node_outputs
+            dummy_inputs[node_name] = node_outputs
             initiated_nodes[node_name] = node
 
         return nn.ModuleDict(initiated_nodes)
 
     def forward(
         self,
-        inputs: Tensor,
+        inputs: dict[str, Tensor],
         task_labels: TaskLabels | None = None,
         images: Tensor | None = None,
         *,
@@ -327,27 +374,32 @@ def forward(
         @rtype: L{LuxonisOutput}
        @return: Output of the model.
        """
-        input_node_name = list(self.input_shapes.keys())[0]
-        input_dict = {input_node_name: [inputs]}
-
         losses: dict[
             str, dict[str, Tensor | tuple[Tensor, dict[str, Tensor]]]
         ] = defaultdict(dict)
         visualizations: dict[str, dict[str, Tensor]] = defaultdict(dict)
 
-        computed: dict[str, Packet[Tensor]] = {
-            f"__{node_name}_input__": {"features": input_tensors}
-            for node_name, input_tensors in input_dict.items()
-        }
+        packetized_inputs = {name: [inputs[name]] for name in inputs}
+
+        computed = {}
+
         for node_name, node, input_names, unprocessed in traverse_graph(
             self.graph, cast(dict[str, BaseNode], self.nodes)
         ):
-            # Special input for the first node. Will be changed when
-            # multiple inputs will be supported in `luxonis-ml.data`.
-            if not input_names:
-                input_names = [f"__{node_name}_input__"]
+            # Build node inputs from loader and preceding nodes
+            node_inputs = []
+
+            # Add inputs from the loader (if the node has at least one input_source)
+            if len(self.loader_input_shapes[node_name]) > 0:
+                loader_packet = {}
+                for input_name in self.loader_input_shapes[node_name]:
+                    loader_packet[input_name] = packetized_inputs[input_name]
+
+                # Add the loader packet to the list of inputs
+                node_inputs.append(loader_packet)
 
-            node_inputs = [computed[pred] for pred in input_names]
+            # Add inputs from preceding Nodes
+            node_inputs += [computed[pred] for pred in input_names]
             outputs = node.run(node_inputs)
             computed[node_name] = outputs
             labels = task_labels[self.node_tasks[node_name]] if task_labels else None
@@ -386,7 +438,10 @@ def forward(
             if computed_name in self.outputs:
                 continue
             for node_name in unprocessed:
-                if computed_name in self.graph[node_name]:
+                if (
+                    computed_name in self.graph[node_name]
+                    or computed_name in self.loader_input_shapes[node_name]
+                ):
                     break
             else:
                 del computed[computed_name]
@@ -448,18 +503,22 @@ def export_onnx(self, save_path: str, **kwargs) -> list[str]:
         """
 
         inputs = {
-            name: [torch.zeros(shape).to(self.device) for shape in shapes]
-            for name, shapes in self.input_shapes.items()
+            input_name: torch.zeros([1, *shape]).to(self.device)
+            for name, shapes in self.loader_input_shapes.items()
+            for input_name, shape in shapes.items()
         }
 
-        # TODO: multiple inputs
-        inp = list(inputs.values())[0][0]
+        inputs_deep_clone = {
+            k: torch.zeros(elem.shape).to(self.device) for k, elem in inputs.items()
+        }
+
+        inputs_for_onnx = {"inputs": inputs_deep_clone}
 
         for module in self.modules():
             if isinstance(module, BaseNode):
                 module.set_export_mode()
 
-        outputs = self.forward(inp.clone()).outputs
+        outputs = self.forward(inputs_deep_clone).outputs
         output_order = sorted(
             [
                 (node_name, output_name, i)
@@ -503,10 +562,13 @@ def export_forward(inputs) -> tuple[Tensor, ...]:
         )
 
         self.forward = export_forward  # type: ignore
+
+        if "input_names" not in kwargs:
+            kwargs["input_names"] = list(inputs.keys())
         if "output_names" not in kwargs:
             kwargs["output_names"] = output_names
 
-        self.to_onnx(save_path, inp, **kwargs)
+        self.to_onnx(save_path, inputs_for_onnx, **kwargs)
 
         self.forward = old_forward  # type: ignore
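Both `_initiate_nodes` and `forward` now assemble a node's inputs in the same order: one packet holding the loader sources first (if any), then one packet per preceding node. A self-contained sketch of that gathering policy, with simplified names (the real code also tracks shape packets and frees consumed outputs):

```python
import torch

graph = {"backbone": [], "head": ["backbone"]}  # node -> preceding nodes
loader_sources = {"backbone": ["left", "right"], "head": []}  # node -> loader inputs
inputs = {"left": torch.rand(1, 3, 224, 224), "right": torch.rand(1, 3, 224, 224)}

packetized = {name: [tensor] for name, tensor in inputs.items()}
computed: dict[str, dict] = {}  # outputs of already-executed nodes


def gather_node_inputs(node_name: str) -> list[dict]:
    packets = []
    if loader_sources[node_name]:  # loader packet always comes first
        packets.append({src: packetized[src] for src in loader_sources[node_name]})
    packets += [computed[pred] for pred in graph[node_name]]
    return packets
```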
From 0e1ceffb5ea35bfc9a201da063d1b322e894bc78 Mon Sep 17 00:00:00 2001
From: Michal Sejak
Date: Sun, 2 Jun 2024 00:40:14 +0200
Subject: [PATCH 15/23] Compatibility changes due to the new way shapes work
 in loaders, making use of images_name.

---
 .../losses/adaptive_detection_loss.py               | 2 +-
 .../metrics/mean_average_precision.py               | 2 +-
 .../metrics/mean_average_precision_keypoints.py     | 2 +-
 .../metrics/object_keypoint_similarity.py           | 2 +-
 luxonis_train/attached_modules/visualizers/utils.py | 5 ++++-
 luxonis_train/core/core.py                          | 1 +
 luxonis_train/nodes/bisenet_head.py                 | 2 +-
 luxonis_train/nodes/efficient_bbox_head.py          | 2 +-
 luxonis_train/nodes/implicit_keypoint_bbox_head.py  | 2 +-
 luxonis_train/nodes/segmentation_head.py            | 2 +-
 luxonis_train/utils/boxutils.py                     | 4 +++-
 11 files changed, 16 insertions(+), 10 deletions(-)

diff --git a/luxonis_train/attached_modules/losses/adaptive_detection_loss.py b/luxonis_train/attached_modules/losses/adaptive_detection_loss.py
index af1a7e6a..3766dac4 100644
--- a/luxonis_train/attached_modules/losses/adaptive_detection_loss.py
+++ b/luxonis_train/attached_modules/losses/adaptive_detection_loss.py
@@ -82,7 +82,7 @@ def __init__(
         self.stride = self.node.stride
         self.grid_cell_size = self.node.grid_cell_size
         self.grid_cell_offset = self.node.grid_cell_offset
-        self.original_img_size = self.node.original_in_shape[2:]
+        self.original_img_size = self.node.original_in_shape[self.node.images_name][1:]
 
         self.n_warmup_epochs = n_warmup_epochs
         self.atts_assigner = ATSSAssigner(topk=9, n_classes=self.n_classes)
diff --git a/luxonis_train/attached_modules/metrics/mean_average_precision.py b/luxonis_train/attached_modules/metrics/mean_average_precision.py
index 0a58d061..5aace3ba 100644
--- a/luxonis_train/attached_modules/metrics/mean_average_precision.py
+++ b/luxonis_train/attached_modules/metrics/mean_average_precision.py
@@ -41,7 +41,7 @@ def prepare(
         label = labels[LabelType.BOUNDINGBOX]
         output_nms = outputs["boxes"]
 
-        image_size = self.node.original_in_shape[2:]
+        image_size = self.node.original_in_shape[self.node.images_name][1:]
 
         output_list: list[dict[str, Tensor]] = []
         label_list: list[dict[str, Tensor]] = []
diff --git a/luxonis_train/attached_modules/metrics/mean_average_precision_keypoints.py b/luxonis_train/attached_modules/metrics/mean_average_precision_keypoints.py
index 3740f58e..7ffdef26 100644
--- a/luxonis_train/attached_modules/metrics/mean_average_precision_keypoints.py
+++ b/luxonis_train/attached_modules/metrics/mean_average_precision_keypoints.py
@@ -109,7 +109,7 @@ def prepare(self, outputs: Packet[Tensor], labels: Labels):
         output_list_kpt_map = []
         label_list_kpt_map = []
 
-        image_size = self.node.original_in_shape[2:]
+        image_size = self.node.original_in_shape[self.node.images_name][1:]
 
         output_kpts: list[Tensor] = outputs["keypoints"]
         output_bboxes: list[Tensor] = outputs["boxes"]
diff --git a/luxonis_train/attached_modules/metrics/object_keypoint_similarity.py b/luxonis_train/attached_modules/metrics/object_keypoint_similarity.py
index c5e4a19b..92be3cb2 100644
--- a/luxonis_train/attached_modules/metrics/object_keypoint_similarity.py
+++ b/luxonis_train/attached_modules/metrics/object_keypoint_similarity.py
@@ -79,7 +79,7 @@ def prepare(
         output_list_oks = []
         label_list_oks = []
 
-        image_size = self.node.original_in_shape[2:]
+        image_size = self.node.original_in_shape[self.node.images_name][1:]
 
         for i, pred_kpt in enumerate(outputs["keypoints"]):
             output_list_oks.append({"keypoints": pred_kpt})
diff --git a/luxonis_train/attached_modules/visualizers/utils.py b/luxonis_train/attached_modules/visualizers/utils.py
index aa1a90d3..72ba642a 100644
--- a/luxonis_train/attached_modules/visualizers/utils.py
+++ b/luxonis_train/attached_modules/visualizers/utils.py
@@ -220,7 +220,10 @@ def unnormalize(
     return out_img
 
 
-def get_unnormalized_images(cfg: Config, images: Tensor) -> Tensor:
+def get_unnormalized_images(cfg: Config, inputs: dict[str, Tensor]) -> Tensor:
+    # Get images from inputs according to config
+    images = inputs[cfg.loader.images_name]
+
     normalize_params = cfg.trainer.preprocessing.normalize.params
     mean = std = None
     if cfg.trainer.preprocessing.normalize.active:
diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py
index d23787fc..37f46cc4 100644
--- a/luxonis_train/core/core.py
+++ b/luxonis_train/core/core.py
@@ -146,6 +146,7 @@ def __init__(
                     if view == "train"
                     else self.cfg.loader.val_view
                 ),
+                images_name=self.cfg.loader.images_name,
                 **self.cfg.loader.params,
             )
             for view in ["train", "val", "test"]
diff --git a/luxonis_train/nodes/bisenet_head.py b/luxonis_train/nodes/bisenet_head.py
index a3b11df6..f98b9c2b 100644
--- a/luxonis_train/nodes/bisenet_head.py
+++ b/luxonis_train/nodes/bisenet_head.py
@@ -32,7 +32,7 @@ def __init__(
         """
         super().__init__(task_type=LabelType.SEGMENTATION, **kwargs)
 
-        original_height = self.original_in_shape[2]
+        original_height = self.original_in_shape[self.images_name][1]
         upscale_factor = 2 ** infer_upscale_factor(self.in_height, original_height)
         out_channels = self.n_classes * upscale_factor * upscale_factor
diff --git a/luxonis_train/nodes/efficient_bbox_head.py b/luxonis_train/nodes/efficient_bbox_head.py
index a4f3bc93..fbfc0ae9 100644
--- a/luxonis_train/nodes/efficient_bbox_head.py
+++ b/luxonis_train/nodes/efficient_bbox_head.py
@@ -126,7 +126,7 @@ def _fit_stride_to_num_heads(self):
         """Returns correct stride for number of heads and attach index."""
         stride = torch.tensor(
             [
-                self.original_in_shape[2] / x[2]  # type: ignore
+                self.original_in_shape[self.images_name][1] / x[1]  # type: ignore
                 for x in self.in_sizes[: self.n_heads]
             ],
             dtype=torch.int,
diff --git a/luxonis_train/nodes/implicit_keypoint_bbox_head.py b/luxonis_train/nodes/implicit_keypoint_bbox_head.py
index 76a66eb6..15745803 100644
--- a/luxonis_train/nodes/implicit_keypoint_bbox_head.py
+++ b/luxonis_train/nodes/implicit_keypoint_bbox_head.py
@@ -218,7 +218,7 @@ def _fit_to_num_heads(self, channel_list: list):
         out_channel_list = channel_list[: self.num_heads]
         stride = torch.tensor(
             [
-                self.original_in_shape[2] / h
+                self.original_in_shape[self.images_name][1] / h
                 for h in cast(list[int], self.in_height)[: self.num_heads]
             ],
             dtype=torch.int,
diff --git a/luxonis_train/nodes/segmentation_head.py b/luxonis_train/nodes/segmentation_head.py
index a3420491..057e96d7 100644
--- a/luxonis_train/nodes/segmentation_head.py
+++ b/luxonis_train/nodes/segmentation_head.py
@@ -29,7 +29,7 @@ def __init__(self, **kwargs):
         """
         super().__init__(task_type=LabelType.SEGMENTATION, **kwargs)
 
-        original_height = self.original_in_shape[2]
+        original_height = self.original_in_shape[self.images_name][1]
         num_up = infer_upscale_factor(self.in_height, original_height, strict=False)
 
         modules = []
diff --git a/luxonis_train/utils/boxutils.py b/luxonis_train/utils/boxutils.py
index a59f4cd0..6b97cc5a 100644
--- a/luxonis_train/utils/boxutils.py
+++ b/luxonis_train/utils/boxutils.py
@@ -433,7 +433,9 @@ def anchors_from_dataset(
             widths.append(curr_wh)
         inputs = inp
     assert inputs is not None, "No inputs found in data loader"
-    _, _, h, w = inputs.shape  # assuming all images are same size
+    _, _, h, w = inputs[
+        loader.dataset.images_name
+    ].shape  # assuming all images are same size
     img_size = torch.tensor([w, h])
     wh = torch.vstack(widths) * img_size
From 951b981a3b42d4195003ec727d78e114f3e73d98 Mon Sep 17 00:00:00 2001
From: GitHub Actions
Date: Sat, 1 Jun 2024 23:17:29 +0000
Subject: [PATCH 16/23] [Automated] Updated coverage badge

---
 media/coverage_badge.svg | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/media/coverage_badge.svg b/media/coverage_badge.svg
index 90299371..b750dd9c 100644
--- a/media/coverage_badge.svg
+++ b/media/coverage_badge.svg
@@ -15,7 +15,7 @@
     coverage
     coverage
-    76%
-    76%
+    77%
+    77%
 
From 1cdcb9f29a365558a74df3c410b1bcf6ce82556d Mon Sep 17 00:00:00 2001
From: Michal Sejak
Date: Sun, 9 Jun 2024 18:00:11 +0200
Subject: [PATCH 17/23] Moved 'images_name' setting from loader implementation
 to config due to code structure requirements. Added missing tests for
 evaluation, export, and inference.

---
 configs/example_multi_input.yaml      |  3 +-
 tests/integration/test_multi_input.py | 53 ++++++++++++---------------
 2 files changed, 26 insertions(+), 30 deletions(-)

diff --git a/configs/example_multi_input.yaml b/configs/example_multi_input.yaml
index d9e0384b..ba779c0f 100644
--- a/configs/example_multi_input.yaml
+++ b/configs/example_multi_input.yaml
@@ -1,5 +1,6 @@
 loader:
-  name: CustomMultiInputLoader # Yields "left", "right", "disparity", and "pointcloud" inputs
+  name: CustomMultiInputLoader # Yields "left", "right", "disparity", and "pointcloud" inputs. See implementation in `tests/integration/test_multi_input.py`.
+  images_name: left # Name of the key in the batch that contains image-like data. Needs to be set for visualizers and evaluators to work.
 
 use_rich_text: True
 
diff --git a/tests/integration/test_multi_input.py b/tests/integration/test_multi_input.py
index e7f5ba26..6dced03c 100644
--- a/tests/integration/test_multi_input.py
+++ b/tests/integration/test_multi_input.py
@@ -1,5 +1,6 @@
 import os
 import shutil
+from pathlib import Path
 from typing import Annotated
 
 import pytest
@@ -8,7 +9,7 @@
 from torch import Tensor
 from torch.nn.parameter import Parameter
 
-from luxonis_train.core import Trainer
+from luxonis_train.core import Exporter, Inferer, Trainer
 from luxonis_train.nodes import BaseNode
 from luxonis_train.utils.loaders import BaseLoaderTorch
 from luxonis_train.utils.registry import LOADERS
@@ -30,10 +31,6 @@ def input_shape(self):
             "pointcloud": torch.Size([1000, 3]),
         }
 
-    @property
-    def images_name(self):
-        return "left"
-
     def __getitem__(self, idx):
         # Fake data
         left = torch.rand(3, 224, 224, dtype=torch.float32)
@@ -169,27 +166,25 @@ def clear_output():
     "config_file", [path for path in os.listdir("configs") if "multi_input" in path]
 )
 def test_sanity(config_file):
-    # opts = [
-    #     "trainer.epochs",
-    #     "3",
-    #     "trainer.validation_interval",
-    #     "3",
-    # ]
-
-    Trainer(f"configs/{config_file}").train()
-
-    # TODO add export and eval tests
-    # opts += ["model.weights", str(list(Path("output").rglob("*.ckpt"))[0])]
-    # opts += ["exporter.onnx.opset_version", "11"]
-
-    # result = subprocess.run(
-    #     ["luxonis_train", "export", "--config", f"configs/{config_file}", *opts],
-    # )
-
-    # assert result.returncode == 0
-
-    # result = subprocess.run(
-    #     ["luxonis_train", "eval", "--config", f"configs/{config_file}", *opts],
-    # )
-
-    # assert result.returncode == 0
+    # Test training
+    trainer = Trainer(f"configs/{config_file}")
+    trainer.train()
+    # Test evaluation
+    trainer.test(view="val")
+
+    # Test export
+    Exporter(f"configs/{config_file}").export("test_export_multi_input.onnx")
+    # Cleanup after exporter
+    assert os.path.exists("test_export_multi_input.onnx")
+    os.remove("test_export_multi_input.onnx")
+
+    # Test inference
+    Inferer(
+        f"configs/{config_file}",
+        opts=None,
+        view="train",
+        save_dir=Path("infer_save_dir"),
+    ).infer()
+    # Cleanup after inferer
+    assert os.path.exists("infer_save_dir")
+    shutil.rmtree("infer_save_dir")
From 6b07cd82fb8562cc9bf9eae6753ac04994012a50 Mon Sep 17 00:00:00 2001
From: Martin Kozlovsky
Date: Thu, 13 Jun 2024 01:20:42 +0200
Subject: [PATCH 18/23] removed macos tests

---
 .github/workflows/tests.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml
index 0b4f51da..af77c60f 100644
--- a/.github/workflows/tests.yaml
+++ b/.github/workflows/tests.yaml
@@ -13,7 +13,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        os: [ubuntu-latest, windows-latest, macOS-latest]
+        os: [ubuntu-latest, windows-latest]
         version: ['3.10', '3.11']
 
     runs-on: ${{ matrix.os }}
os.path.exists("test_export_multi_input.onnx") + os.remove("test_export_multi_input.onnx") + + # Test inference + Inferer( + f"configs/{config_file}", + opts=None, + view="train", + save_dir=Path("infer_save_dir"), + ).infer() + # Cleanup after inferer + assert os.path.exists("infer_save_dir") + shutil.rmtree("infer_save_dir") From 6b07cd82fb8562cc9bf9eae6753ac04994012a50 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Thu, 13 Jun 2024 01:20:42 +0200 Subject: [PATCH 18/23] removed macos tests --- .github/workflows/tests.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 0b4f51da..af77c60f 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -13,7 +13,7 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-latest, windows-latest, macOS-latest] + os: [ubuntu-latest, windows-latest] version: ['3.10', '3.11'] runs-on: ${{ matrix.os }} From 63c6ad0c7d9bbb3a69598e85c6c703e496414bb0 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Thu, 13 Jun 2024 04:06:40 +0200 Subject: [PATCH 19/23] removed images_name from BaseNode --- .../losses/adaptive_detection_loss.py | 2 +- .../metrics/mean_average_precision.py | 2 +- .../metrics/mean_average_precision_keypoints.py | 2 +- .../metrics/object_keypoint_similarity.py | 2 +- luxonis_train/models/luxonis_model.py | 3 +-- luxonis_train/nodes/base_node.py | 16 +++------------- luxonis_train/nodes/bisenet_head.py | 2 +- luxonis_train/nodes/efficient_bbox_head.py | 2 +- .../nodes/implicit_keypoint_bbox_head.py | 2 +- luxonis_train/nodes/segmentation_head.py | 3 +-- 10 files changed, 12 insertions(+), 24 deletions(-) diff --git a/luxonis_train/attached_modules/losses/adaptive_detection_loss.py b/luxonis_train/attached_modules/losses/adaptive_detection_loss.py index 6b8bf2f0..21291bfa 100644 --- a/luxonis_train/attached_modules/losses/adaptive_detection_loss.py +++ b/luxonis_train/attached_modules/losses/adaptive_detection_loss.py @@ -82,7 +82,7 @@ def __init__( self.stride = self.node.stride self.grid_cell_size = self.node.grid_cell_size self.grid_cell_offset = self.node.grid_cell_offset - self.original_img_size = self.node.original_in_shape[self.node.images_name][1:] + self.original_img_size = self.node.original_in_shape[1:] self.n_warmup_epochs = n_warmup_epochs self.atts_assigner = ATSSAssigner(topk=9, n_classes=self.n_classes) diff --git a/luxonis_train/attached_modules/metrics/mean_average_precision.py b/luxonis_train/attached_modules/metrics/mean_average_precision.py index 7c71374b..67c010ec 100644 --- a/luxonis_train/attached_modules/metrics/mean_average_precision.py +++ b/luxonis_train/attached_modules/metrics/mean_average_precision.py @@ -41,7 +41,7 @@ def prepare( label = labels[self.task][0] output_nms = self.get_input_tensors(outputs) - image_size = self.node.original_in_shape[self.node.images_name][1:] + image_size = self.node.original_in_shape[1:] output_list: list[dict[str, Tensor]] = [] label_list: list[dict[str, Tensor]] = [] diff --git a/luxonis_train/attached_modules/metrics/mean_average_precision_keypoints.py b/luxonis_train/attached_modules/metrics/mean_average_precision_keypoints.py index 931e886d..31bc7557 100644 --- a/luxonis_train/attached_modules/metrics/mean_average_precision_keypoints.py +++ b/luxonis_train/attached_modules/metrics/mean_average_precision_keypoints.py @@ -109,7 +109,7 @@ def prepare(self, outputs: Packet[Tensor], labels: Labels): output_list_kpt_map = [] label_list_kpt_map = [] - image_size = 
self.node.original_in_shape[self.node.images_name][1:] + image_size = self.node.original_in_shape[1:] output_kpts: list[Tensor] = outputs["keypoints"] output_bboxes: list[Tensor] = outputs["boundingbox"] diff --git a/luxonis_train/attached_modules/metrics/object_keypoint_similarity.py b/luxonis_train/attached_modules/metrics/object_keypoint_similarity.py index 3ca7f6fa..c1768012 100644 --- a/luxonis_train/attached_modules/metrics/object_keypoint_similarity.py +++ b/luxonis_train/attached_modules/metrics/object_keypoint_similarity.py @@ -79,7 +79,7 @@ def prepare( output_list_oks = [] label_list_oks = [] - image_size = self.node.original_in_shape[self.node.images_name][1:] + image_size = self.node.original_in_shape[1:] for i, pred_kpt in enumerate(outputs["keypoints"]): output_list_oks.append({"keypoints": pred_kpt}) diff --git a/luxonis_train/models/luxonis_model.py b/luxonis_train/models/luxonis_model.py index fc636dc9..66062e96 100644 --- a/luxonis_train/models/luxonis_model.py +++ b/luxonis_train/models/luxonis_model.py @@ -280,8 +280,7 @@ def _initiate_nodes( node = Node( input_shapes=node_input_shapes, - original_in_shape=self.original_in_shape, - images_name=self.images_name, + original_in_shape=self.original_in_shape[self.images_name], dataset_metadata=self.dataset_metadata, **node_kwargs, ) diff --git a/luxonis_train/nodes/base_node.py b/luxonis_train/nodes/base_node.py index 7ce05793..8ee03591 100644 --- a/luxonis_train/nodes/base_node.py +++ b/luxonis_train/nodes/base_node.py @@ -86,7 +86,6 @@ def __init__( *, input_shapes: list[Packet[Size]] | None = None, original_in_shape: Size | None = None, - images_name: str = "features", dataset_metadata: DatasetMetadata | None = None, attach_index: AttachIndexType | None = None, in_protocols: list[type[BaseModel]] | None = None, @@ -120,7 +119,6 @@ def __init__( self._input_shapes = input_shapes self._original_in_shape = original_in_shape - self._images_name = images_name if n_classes is not None: if dataset_metadata is not None: raise ValueError("Cannot set both `dataset_metadata` and `n_classes`.") @@ -136,14 +134,6 @@ def _non_set_error(self, name: str) -> ValueError: "but it was not set during initialization. " ) - @property - def images_name(self) -> str: - """Getter for the images name (name of the key from the loader which contains - images).""" - if self._images_name is None: - raise self._non_set_error("images_name") - return self._images_name - @property def task(self) -> str: """Getter for the task.""" @@ -216,13 +206,13 @@ def in_sizes(self) -> Size | list[Size]: if self._in_sizes is not None: return self._in_sizes - features = self.input_shapes[0].get(self._images_name) + features = self.input_shapes[0].get("features") if features is None: raise IncompatibleException( - f"Images field '{self._images_name}' is missing in {self.__class__.__name__}. " + f"Feature field is missing in {self.__class__.__name__}. " "The default implementation of `in_sizes` cannot be used." 
             )
-        shapes = self.get_attached(self.input_shapes[0][self._images_name])
+        shapes = self.get_attached(self.input_shapes[0]["features"])
         if isinstance(shapes, list) and len(shapes) == 1:
             return shapes[0]
         return shapes
diff --git a/luxonis_train/nodes/bisenet_head.py b/luxonis_train/nodes/bisenet_head.py
index 5980ebbf..8bac3573 100644
--- a/luxonis_train/nodes/bisenet_head.py
+++ b/luxonis_train/nodes/bisenet_head.py
@@ -32,7 +32,7 @@ def __init__(
         """
         super().__init__(task=LabelType.SEGMENTATION, **kwargs)
 
-        original_height = self.original_in_shape[self.images_name][1]
+        original_height = self.original_in_shape[1]
         upscale_factor = 2 ** infer_upscale_factor(self.in_height, original_height)
         out_channels = self.n_classes * upscale_factor * upscale_factor
 
diff --git a/luxonis_train/nodes/efficient_bbox_head.py b/luxonis_train/nodes/efficient_bbox_head.py
index ec9eb446..e7b23288 100644
--- a/luxonis_train/nodes/efficient_bbox_head.py
+++ b/luxonis_train/nodes/efficient_bbox_head.py
@@ -126,7 +126,7 @@ def _fit_stride_to_num_heads(self):
         """Returns correct stride for number of heads and attach index."""
         stride = torch.tensor(
             [
-                self.original_in_shape[self.images_name][1] / x[1]  # type: ignore
+                self.original_in_shape[1] / x[1]  # type: ignore
                 for x in self.in_sizes[: self.n_heads]
             ],
             dtype=torch.int,
diff --git a/luxonis_train/nodes/implicit_keypoint_bbox_head.py b/luxonis_train/nodes/implicit_keypoint_bbox_head.py
index 859b71e3..dde27ed5 100644
--- a/luxonis_train/nodes/implicit_keypoint_bbox_head.py
+++ b/luxonis_train/nodes/implicit_keypoint_bbox_head.py
@@ -218,7 +218,7 @@ def _fit_to_num_heads(self, channel_list: list):
         out_channel_list = channel_list[: self.num_heads]
         stride = torch.tensor(
             [
-                self.original_in_shape[self.images_name][1] / h
+                self.original_in_shape[1] / h
                 for h in cast(list[int], self.in_height)[: self.num_heads]
             ],
             dtype=torch.int,
diff --git a/luxonis_train/nodes/segmentation_head.py b/luxonis_train/nodes/segmentation_head.py
index 934c3090..67461eb0 100644
--- a/luxonis_train/nodes/segmentation_head.py
+++ b/luxonis_train/nodes/segmentation_head.py
@@ -4,7 +4,6 @@
     @license: U{BSD-3 }
 """
-
 import torch.nn as nn
 from torch import Tensor
 
@@ -29,7 +28,7 @@ def __init__(self, **kwargs):
         """
         super().__init__(_task_type=LabelType.SEGMENTATION, **kwargs)
 
-        original_height = self.original_in_shape[self.images_name][1]
+        original_height = self.original_in_shape[1]
         num_up = infer_upscale_factor(self.in_height, original_height, strict=False)
 
         modules = []

From a4906e7dc44548a49ac574cb5b11aa303818b278 Mon Sep 17 00:00:00 2001
From: Martin Kozlovsky
Date: Thu, 13 Jun 2024 04:10:08 +0200
Subject: [PATCH 20/23] renamed get_shape_packet to to_shape_packet

---
 luxonis_train/models/luxonis_model.py  | 5 ++---
 luxonis_train/models/luxonis_output.py | 4 ++--
 luxonis_train/utils/general.py         | 4 +++-
 3 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/luxonis_train/models/luxonis_model.py b/luxonis_train/models/luxonis_model.py
index 66062e96..4a31e602 100644
--- a/luxonis_train/models/luxonis_model.py
+++ b/luxonis_train/models/luxonis_model.py
@@ -33,7 +33,7 @@
 from luxonis_train.utils.config import AttachedModuleConfig, Config
 from luxonis_train.utils.general import (
     DatasetMetadata,
-    get_shape_packet,
+    to_shape_packet,
     traverse_graph,
 )
 from luxonis_train.utils.registry import CALLBACKS, OPTIMIZERS, SCHEDULERS, Registry
@@ -274,8 +274,7 @@ def _initiate_nodes(
             # Add the input to the list of inputs
             node_dummy_inputs.append(dummy_input)
 
-            # Add its shape to the list of input 
shapes - shape_packet = get_shape_packet(dummy_input) + shape_packet = to_shape_packet(dummy_input) node_input_shapes.append(shape_packet) node = Node( diff --git a/luxonis_train/models/luxonis_output.py b/luxonis_train/models/luxonis_output.py index e6b8e16c..d69943fc 100644 --- a/luxonis_train/models/luxonis_output.py +++ b/luxonis_train/models/luxonis_output.py @@ -3,7 +3,7 @@ from torch import Tensor -from luxonis_train.utils.general import get_shape_packet +from luxonis_train.utils.general import to_shape_packet from luxonis_train.utils.types import Packet @@ -16,7 +16,7 @@ class LuxonisOutput: def __str__(self) -> str: outputs = { - node_name: get_shape_packet(packet) + node_name: to_shape_packet(packet) for node_name, packet in self.outputs.items() } viz = { diff --git a/luxonis_train/utils/general.py b/luxonis_train/utils/general.py index 21c35df0..099beb66 100644 --- a/luxonis_train/utils/general.py +++ b/luxonis_train/utils/general.py @@ -1,5 +1,6 @@ import logging import math +from copy import deepcopy from typing import Generator, TypeVar from pydantic import BaseModel @@ -210,7 +211,7 @@ def infer_upscale_factor( ) -def get_shape_packet(packet: Packet[Tensor]) -> Packet[Size]: +def to_shape_packet(packet: Packet[Tensor]) -> Packet[Size]: shape_packet: Packet[Size] = {} for name, value in packet.items(): shape_packet[name] = [x.shape for x in value] @@ -281,6 +282,7 @@ def traverse_graph( ) # sort the set to allow reproducibility processed: set[str] = set() + graph = deepcopy(graph) while unprocessed_nodes: unprocessed_nodes_copy = unprocessed_nodes.copy() for node_name in unprocessed_nodes_copy: From 25d17583c5d0ad59a25beba853bf34a8d7ec44c6 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Thu, 13 Jun 2024 04:10:46 +0200 Subject: [PATCH 21/23] simplified input source handling --- luxonis_train/models/luxonis_model.py | 70 ++++++++++----------------- luxonis_train/utils/config.py | 2 +- tests/integration/test_multi_input.py | 41 +++------------- 3 files changed, 33 insertions(+), 80 deletions(-) diff --git a/luxonis_train/models/luxonis_model.py b/luxonis_train/models/luxonis_model.py index 4a31e602..06212f2e 100644 --- a/luxonis_train/models/luxonis_model.py +++ b/luxonis_train/models/luxonis_model.py @@ -130,6 +130,7 @@ def __init__( self.frozen_nodes: list[tuple[nn.Module, int]] = [] self.graph: dict[str, list[str]] = {} self.loader_input_shapes: dict[str, dict[str, Size]] = {} + self.node_input_sources: dict[str, list[str]] = defaultdict(list) self.loss_weights: dict[str, float] = {} self.main_metric: str | None = None self.save_dir = save_dir @@ -161,6 +162,8 @@ def __init__( nodes[node_name] = (Node, {**node_cfg.params, "task": node_cfg.task}) # Handle inputs for this node + if node_cfg.input_sources: + self.node_input_sources[node_name] = node_cfg.input_sources if not node_cfg.inputs and not node_cfg.input_sources: # If no inputs (= preceding nodes) nor any input_sources (= loader outputs) are specified, @@ -169,6 +172,7 @@ def __init__( self.loader_input_shapes[node_name] = { k: Size(v) for k, v in input_shape.items() } + self.node_input_sources[node_name] = list(input_shape.keys()) else: # For each input_source, check if the loader provides the required output. # If yes, add the shape to the input_shapes dict. If not, raise an error. 
@@ -238,9 +242,9 @@ def _initiate_nodes( initiated_nodes: dict[str, BaseNode] = {} dummy_inputs: dict[str, Packet[Tensor]] = { - input_name: [torch.zeros(2, *shape)] - for node_name, shapes in self.loader_input_shapes.items() - for input_name, shape in shapes.items() + source_name: {"features": [torch.zeros(2, *shape)]} + for shapes in self.loader_input_shapes.values() + for source_name, shape in shapes.items() } for node_name, (Node, node_kwargs), node_input_names, _ in traverse_graph( @@ -254,24 +258,10 @@ def _initiate_nodes( node_input_shapes: list[Packet[Size]] = [] """Corresponding list of input shapes.""" - # Add inputs from the loader (if the node has at least one input_source) - if len(self.loader_input_shapes[node_name]) > 0: - loader_packet = {} - for input_name in self.loader_input_shapes[node_name]: - loader_packet[input_name] = dummy_inputs[input_name] - - # Add the loader packet to the list of inputs - node_dummy_inputs.append(loader_packet) - - # Add its shape to the list of input shapes - loader_shape_packet = get_shape_packet(loader_packet) - node_input_shapes.append(loader_shape_packet) - - # Add inputs from preceding Nodes + node_input_names += self.node_input_sources[node_name] for node_input_name in node_input_names: dummy_input = dummy_inputs[node_input_name] - # Add the input to the list of inputs node_dummy_inputs.append(dummy_input) shape_packet = to_shape_packet(dummy_input) @@ -322,32 +312,23 @@ def forward( @rtype: L{LuxonisOutput} @return: Output of the model. """ - losses: dict[ str, dict[str, Tensor | tuple[Tensor, dict[str, Tensor]]] ] = defaultdict(dict) visualizations: dict[str, dict[str, Tensor]] = defaultdict(dict) - packetized_inputs = {name: [inputs[name]] for name in inputs} - - computed = {} + computed: dict[str, Packet[Tensor]] = {} for node_name, node, input_names, unprocessed in traverse_graph( self.graph, cast(dict[str, BaseNode], self.nodes) ): - # Build node inputs from loader and preceding nodes - node_inputs = [] + input_names += self.node_input_sources[node_name] - # Add inputs from the loader (if the node has at least one input_source) - if len(self.loader_input_shapes[node_name]) > 0: - loader_packet = {} - for input_name in self.loader_input_shapes[node_name]: - loader_packet[input_name] = packetized_inputs[input_name] - - # Add the loader packet to the list of inputs - node_inputs.append(loader_packet) - - # Add inputs from preceding Nodes - node_inputs += [computed[pred] for pred in input_names] + node_inputs: list[Packet[Tensor]] = [] + for pred in input_names: + if pred in computed: + node_inputs.append(computed[pred]) + else: + node_inputs.append({"features": [inputs[pred]]}) outputs = node.run(node_inputs) computed[node_name] = outputs @@ -380,10 +361,7 @@ def forward( if computed_name in self.outputs: continue for node_name in unprocessed: - if ( - computed_name in self.graph[node_name] - or computed_name in self.loader_input_shapes[node_name] - ): + if computed_name in self.graph[node_name]: break else: del computed[computed_name] @@ -443,7 +421,7 @@ def export_onnx(self, save_path: str, **kwargs) -> list[str]: inputs = { input_name: torch.zeros([1, *shape]).to(self.device) - for name, shapes in self.loader_input_shapes.items() + for shapes in self.loader_input_shapes.values() for input_name, shape in shapes.items() } @@ -556,7 +534,7 @@ def process_losses( training_step_output["loss"] = final_loss.detach().cpu() return final_loss, training_step_output - def training_step(self, train_batch: tuple[Tensor, Labels]) -> Tensor: + def 
training_step(self, train_batch: tuple[dict[str, Tensor], Labels]) -> Tensor: """Performs one step of training with provided batch.""" outputs = self.forward(*train_batch) assert outputs.losses, "Losses are empty, check if you have defined any loss" @@ -565,11 +543,15 @@ def training_step(self, train_batch: tuple[Tensor, Labels]) -> Tensor: self.training_step_outputs.append(training_step_output) return loss - def validation_step(self, val_batch: tuple[Tensor, Labels]) -> dict[str, Tensor]: + def validation_step( + self, val_batch: tuple[dict[str, Tensor], Labels] + ) -> dict[str, Tensor]: """Performs one step of validation with provided batch.""" return self._evaluation_step("val", val_batch) - def test_step(self, test_batch: tuple[Tensor, Labels]) -> dict[str, Tensor]: + def test_step( + self, test_batch: tuple[dict[str, Tensor], Labels] + ) -> dict[str, Tensor]: """Performs one step of testing with provided batch.""" return self._evaluation_step("test", test_batch) @@ -609,7 +591,7 @@ def get_status_percentage(self) -> float: return (self.current_epoch / self.cfg.trainer.epochs) * 100 def _evaluation_step( - self, mode: Literal["test", "val"], batch: tuple[Tensor, Labels] + self, mode: Literal["test", "val"], batch: tuple[dict[str, Tensor], Labels] ) -> dict[str, Tensor]: inputs, labels = batch images = None diff --git a/luxonis_train/utils/config.py b/luxonis_train/utils/config.py index e4532831..c3e21bb3 100644 --- a/luxonis_train/utils/config.py +++ b/luxonis_train/utils/config.py @@ -133,7 +133,7 @@ class TrackerConfig(CustomBaseModel): class LoaderConfig(CustomBaseModel): name: str = "LuxonisLoaderTorch" - images_name: str = "features" + images_name: str = "image" train_view: str = "train" val_view: str = "val" test_view: str = "test" diff --git a/tests/integration/test_multi_input.py b/tests/integration/test_multi_input.py index 7bf4cbc0..8f1eef23 100644 --- a/tests/integration/test_multi_input.py +++ b/tests/integration/test_multi_input.py @@ -1,18 +1,16 @@ import os import shutil from pathlib import Path -from typing import Annotated import pytest import torch -from pydantic import Field from torch import Tensor from torch.nn.parameter import Parameter from luxonis_train.core import Exporter, Inferer, Trainer from luxonis_train.nodes import BaseNode from luxonis_train.utils.loaders import BaseLoaderTorch -from luxonis_train.utils.types import BaseProtocol, FeaturesProtocol, LabelType +from luxonis_train.utils.types import FeaturesProtocol, LabelType class CustomMultiInputLoader(BaseLoaderTorch): @@ -56,27 +54,6 @@ def get_classes(self) -> dict[LabelType, list[str]]: return {LabelType.SEGMENTATION: ["square"]} -class FullCustomMultiInputProtocol(BaseProtocol): - left: Annotated[list[Tensor], Field(min_length=1)] - right: Annotated[list[Tensor], Field(min_length=1)] - disparity: Annotated[list[Tensor], Field(min_length=1)] - pointcloud: Annotated[list[Tensor], Field(min_length=1)] - - -class RGBDCustomMultiInputProtocol(BaseProtocol): - left: Annotated[list[Tensor], Field(min_length=1)] - right: Annotated[list[Tensor], Field(min_length=1)] - disparity: Annotated[list[Tensor], Field(min_length=1)] - - -class PointcloudCustomMultiInputProtocol(BaseProtocol): - pointcloud: Annotated[list[Tensor], Field(min_length=1)] - - -class DisparityCustomMultiInputProtocol(BaseProtocol): - disparity: Annotated[list[Tensor], Field(min_length=1)] - - class MultiInputTestBaseNode(BaseNode): def __init__(self, **kwargs): super().__init__(**kwargs) @@ -91,21 +68,21 @@ def unwrap(self, inputs: 
list[dict[str, list[Tensor]]]): class FullBackbone(MultiInputTestBaseNode): def __init__(self, **kwargs): - in_protocols = [FullCustomMultiInputProtocol] + in_protocols = [FeaturesProtocol] * 4 super().__init__(**kwargs) self.in_protocols = in_protocols class RGBDBackbone(MultiInputTestBaseNode): def __init__(self, **kwargs): - in_protocols = [RGBDCustomMultiInputProtocol] + in_protocols = [FeaturesProtocol] * 3 super().__init__(**kwargs) self.in_protocols = in_protocols class PointcloudBackbone(MultiInputTestBaseNode): def __init__(self, **kwargs): - in_protocols = [PointcloudCustomMultiInputProtocol] + in_protocols = [FeaturesProtocol] super().__init__(**kwargs) self.in_protocols = in_protocols @@ -113,7 +90,7 @@ def __init__(self, **kwargs): class FusionNeck(MultiInputTestBaseNode): def __init__(self, **kwargs): in_protocols = [ - DisparityCustomMultiInputProtocol, + FeaturesProtocol, FeaturesProtocol, FeaturesProtocol, ] @@ -134,23 +111,17 @@ def __init__(self, **kwargs): super().__init__(**kwargs, _task_type=LabelType.SEGMENTATION) self.in_protocols = in_protocols - def wrap(self, outputs: list[Tensor]): - return {"segmentation": [outputs[0]]} - class CustomSegHead2(MultiInputTestBaseNode): def __init__(self, **kwargs): in_protocols = [ - DisparityCustomMultiInputProtocol, + FeaturesProtocol, FeaturesProtocol, FeaturesProtocol, ] super().__init__(**kwargs, _task_type=LabelType.SEGMENTATION) self.in_protocols = in_protocols - def wrap(self, outputs: list[Tensor]): - return {"segmentation": [outputs[0]]} - @pytest.fixture(scope="function", autouse=True) def clear_output(): From 9987d0d73c7cbf42530afb29f7c4c6137f80ff18 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Thu, 13 Jun 2024 04:28:13 +0200 Subject: [PATCH 22/23] renamed images_name to image_source --- configs/example_multi_input.yaml | 2 +- .../attached_modules/visualizers/utils.py | 2 +- luxonis_train/core/core.py | 2 +- luxonis_train/models/luxonis_model.py | 4 ++-- luxonis_train/utils/boxutils.py | 2 +- luxonis_train/utils/config.py | 2 +- luxonis_train/utils/loaders/base_loader.py | 14 +++++++------- .../utils/loaders/luxonis_loader_torch.py | 6 +++--- 8 files changed, 17 insertions(+), 17 deletions(-) diff --git a/configs/example_multi_input.yaml b/configs/example_multi_input.yaml index af0edb6a..7d4d252b 100644 --- a/configs/example_multi_input.yaml +++ b/configs/example_multi_input.yaml @@ -6,7 +6,7 @@ loader: # Name of the key in the batch that contains image-like data. # Needs to be set for visualizers and evaluators to work. 
- images_name: left + image_source: left use_rich_text: True diff --git a/luxonis_train/attached_modules/visualizers/utils.py b/luxonis_train/attached_modules/visualizers/utils.py index 72ba642a..c55b12ce 100644 --- a/luxonis_train/attached_modules/visualizers/utils.py +++ b/luxonis_train/attached_modules/visualizers/utils.py @@ -222,7 +222,7 @@ def unnormalize( def get_unnormalized_images(cfg: Config, inputs: dict[str, Tensor]) -> Tensor: # Get images from inputs according to config - images = inputs[cfg.loader.images_name] + images = inputs[cfg.loader.image_source] normalize_params = cfg.trainer.preprocessing.normalize.params mean = std = None diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py index 092656e9..c1b1fb56 100644 --- a/luxonis_train/core/core.py +++ b/luxonis_train/core/core.py @@ -147,7 +147,7 @@ def __init__( if view == "train" else self.cfg.loader.val_view ), - images_name=self.cfg.loader.images_name, + image_source=self.cfg.loader.image_source, **self.cfg.loader.params, ) for view in ["train", "val", "test"] diff --git a/luxonis_train/models/luxonis_model.py b/luxonis_train/models/luxonis_model.py index 06212f2e..2daf61cb 100644 --- a/luxonis_train/models/luxonis_model.py +++ b/luxonis_train/models/luxonis_model.py @@ -125,7 +125,7 @@ def __init__( self.cfg = cfg self.original_in_shape = input_shape - self.images_name = cfg.loader.images_name + self.image_source = cfg.loader.image_source self.dataset_metadata = dataset_metadata or DatasetMetadata() self.frozen_nodes: list[tuple[nn.Module, int]] = [] self.graph: dict[str, list[str]] = {} @@ -269,7 +269,7 @@ def _initiate_nodes( node = Node( input_shapes=node_input_shapes, - original_in_shape=self.original_in_shape[self.images_name], + original_in_shape=self.original_in_shape[self.image_source], dataset_metadata=self.dataset_metadata, **node_kwargs, ) diff --git a/luxonis_train/utils/boxutils.py b/luxonis_train/utils/boxutils.py index e0e3a198..3a26cc4f 100644 --- a/luxonis_train/utils/boxutils.py +++ b/luxonis_train/utils/boxutils.py @@ -434,7 +434,7 @@ def anchors_from_dataset( inputs = inp assert inputs is not None, "No inputs found in data loader" _, _, h, w = inputs[ - loader.dataset.images_name + loader.dataset.image_source # type: ignore ].shape # assuming all images are same size img_size = torch.tensor([w, h]) wh = torch.vstack(widths) * img_size diff --git a/luxonis_train/utils/config.py b/luxonis_train/utils/config.py index c3e21bb3..96d132ab 100644 --- a/luxonis_train/utils/config.py +++ b/luxonis_train/utils/config.py @@ -133,7 +133,7 @@ class TrackerConfig(CustomBaseModel): class LoaderConfig(CustomBaseModel): name: str = "LuxonisLoaderTorch" - images_name: str = "image" + image_source: str = "image" train_view: str = "train" val_view: str = "val" test_view: str = "test" diff --git a/luxonis_train/utils/loaders/base_loader.py b/luxonis_train/utils/loaders/base_loader.py index a46aea07..c4f22428 100644 --- a/luxonis_train/utils/loaders/base_loader.py +++ b/luxonis_train/utils/loaders/base_loader.py @@ -27,21 +27,21 @@ def __init__( self, view: str, augmentations: Augmentations | None = None, - images_name: str | None = None, + image_source: str | None = None, ): self.view = view self.augmentations = augmentations - self._images_name = images_name + self._image_source = image_source @property - def images_name(self) -> str: + def image_source(self) -> str: """Name of the input image group. 
- Example: 'features' + Example: 'image' """ - if self._images_name is None: - raise ValueError("images_name is not set") - return self._images_name + if self._image_source is None: + raise ValueError("image_source is not set") + return self._image_source @property @abstractmethod diff --git a/luxonis_train/utils/loaders/luxonis_loader_torch.py b/luxonis_train/utils/loaders/luxonis_loader_torch.py index 61c85978..094bc96a 100644 --- a/luxonis_train/utils/loaders/luxonis_loader_torch.py +++ b/luxonis_train/utils/loaders/luxonis_loader_torch.py @@ -43,8 +43,8 @@ def __len__(self) -> int: @property def input_shape(self) -> dict[str, Size]: - img = self[0][0][self.images_name] - return {self.images_name: img.shape} + img = self[0][0][self.image_source] + return {self.image_source: img.shape} def __getitem__(self, idx: int) -> LuxonisLoaderTorchOutput: img, labels = self.base_loader[idx] @@ -55,7 +55,7 @@ def __getitem__(self, idx: int) -> LuxonisLoaderTorchOutput: for task, (array, label_type) in labels.items(): tensor_labels[task] = (Tensor(array), label_type) - return {self.images_name: tensor_img}, tensor_labels + return {self.image_source: tensor_img}, tensor_labels def get_classes(self) -> dict[str, list[str]]: _, classes = self.dataset.get_classes() From 590bd317615ec49f27a7e8c4ba49670ddb73e98c Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Thu, 13 Jun 2024 04:56:03 +0200 Subject: [PATCH 23/23] docformat --- luxonis_train/utils/tracker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/luxonis_train/utils/tracker.py b/luxonis_train/utils/tracker.py index 65fea368..df157b3b 100644 --- a/luxonis_train/utils/tracker.py +++ b/luxonis_train/utils/tracker.py @@ -1,6 +1,6 @@ from lightning.pytorch.loggers.logger import Logger -from luxonis_ml.tracker import LuxonisTracker from lightning.pytorch.utilities import rank_zero_only # type: ignore +from luxonis_ml.tracker import LuxonisTracker class LuxonisTrackerPL(LuxonisTracker, Logger): @@ -8,7 +8,7 @@ class LuxonisTrackerPL(LuxonisTracker, Logger): @rank_zero_only def finalize(self, status: str = "success") -> None: - """Finalizes current run""" + """Finalizes current run.""" if self.is_tensorboard: self.experiment["tensorboard"].flush() self.experiment["tensorboard"].close()
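The end state of this series fixes a simple contract between loaders and the model: a loader yields a dict of named sources, and one of them is designated as the image source via `loader.image_source` (default "image"). For reference, a custom loader written against that contract could look roughly like the sketch below. It mirrors the conventions visible in `tests/integration/test_multi_input.py` and `luxonis_loader_torch.py`; the class name, shapes, and dummy data are illustrative assumptions, and the decorator-style registration is assumed to follow the `LOADERS` registry imported by the tests.

    import torch
    from torch import Tensor

    from luxonis_train.utils.loaders import BaseLoaderTorch
    from luxonis_train.utils.registry import LOADERS
    from luxonis_train.utils.types import LabelType


    @LOADERS.register_module()  # assumed registration pattern
    class MyRGBLoader(BaseLoaderTorch):
        """Hypothetical single-source loader; everything here is dummy data."""

        @property
        def input_shape(self) -> dict[str, torch.Size]:
            # Keyed by source name; `self.image_source` is injected from
            # `loader.image_source` in the config (default: "image").
            return {self.image_source: torch.Size([3, 256, 256])}

        def __len__(self) -> int:
            return 8

        def __getitem__(self, idx: int) -> tuple[dict[str, Tensor], dict]:
            img = torch.rand(3, 256, 256, dtype=torch.float32)  # stand-in image
            mask = torch.zeros(1, 256, 256, dtype=torch.float32)  # one-class mask
            labels = {"segmentation": (mask, LabelType.SEGMENTATION)}
            # Inputs are a dict keyed by source name, matching what
            # LuxonisLoaderTorch returns after patch 22.
            return {self.image_source: img}, labels

        def get_classes(self) -> dict[LabelType, list[str]]:
            return {LabelType.SEGMENTATION: ["square"]}

Because `__getitem__` returns `{self.image_source: img}`, both `get_unnormalized_images` and `anchors_from_dataset` can look the image tensor up by name instead of assuming a fixed positional layout.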
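On the config side, nodes can pull from those named sources directly: patch 21 distinguishes `input_sources` (loader outputs) from `inputs` (preceding nodes), and a node may declare both. A hedged YAML sketch using the node names from the multi-input test follows; the exact layout of the `model.nodes` section is an assumption based on these diffs, not copied from a shipped config.

    loader:
      name: CustomMultiInputLoader
      image_source: left          # which source visualizers treat as the image

    model:
      name: multi_input_sketch    # illustrative name
      nodes:
        - name: RGBDBackbone
          input_sources:          # read straight from the loader
            - left
            - right
            - disparity

        - name: FusionNeck
          input_sources:
            - disparity
          inputs:                 # plus the output of a preceding node
            - RGBDBackbone

        - name: CustomSegHead1
          inputs:
            - FusionNeck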
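Inside the model, the resolution rule introduced in patch 21 is small enough to state as a runnable toy: when the graph is traversed, an input name is first looked up among already-computed node outputs, and only otherwise treated as a loader source, which gets wrapped into a single-entry "features" packet. A self-contained sketch of that rule, with made-up names and shapes:

    import torch
    from torch import Tensor

    # Pretend one node has already run, and the loader provided "disparity".
    computed: dict[str, dict[str, list[Tensor]]] = {
        "RGBDBackbone": {"features": [torch.rand(1, 32, 56, 56)]},
    }
    loader_inputs: dict[str, Tensor] = {"disparity": torch.rand(1, 1, 224, 224)}

    def resolve(name: str) -> dict[str, list[Tensor]]:
        # Mirrors the branch in LuxonisModel.forward after patch 21:
        # preceding-node outputs win; anything else is a loader source.
        if name in computed:
            return computed[name]
        return {"features": [loader_inputs[name]]}

    for name in ("RGBDBackbone", "disparity"):
        packet = resolve(name)
        print(name, [t.shape for t in packet["features"]])

This is also why the dummy-input dict in `_initiate_nodes` wraps each loader source as `{"features": [torch.zeros(2, *shape)]}`: shape inference and the real forward pass then share one lookup path.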