From ca570637eefae0912dae338cf4b25871b3bba52f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Kozlovsk=C3=BD?= Date: Wed, 24 Apr 2024 02:06:57 +0200 Subject: [PATCH] Task Label Groups Support (#22) * handling SIGTERM signal * resume argument takes path * basic task group labels support * updated requirements * fixed tests * fixed loader test * Update luxonis_train/models/luxonis_model.py Co-authored-by: conorsim <60359299+conorsim@users.noreply.github.com> --------- Co-authored-by: conorsim <60359299+conorsim@users.noreply.github.com> --- luxonis_train/models/luxonis_model.py | 12 ++- luxonis_train/utils/boxutils.py | 4 +- luxonis_train/utils/config.py | 1 + luxonis_train/utils/loaders/base_loader.py | 81 ++++++++++--------- .../utils/loaders/luxonis_loader_torch.py | 10 ++- luxonis_train/utils/types.py | 1 + requirements.txt | 3 +- tests/integration/conftest.py | 4 +- tests/unittests/test_core/test_archiver.py | 2 +- .../test_loaders/test_base_loader.py | 6 +- 10 files changed, 71 insertions(+), 53 deletions(-) diff --git a/luxonis_train/models/luxonis_model.py b/luxonis_train/models/luxonis_model.py index 7cd396f9..58aeccd1 100644 --- a/luxonis_train/models/luxonis_model.py +++ b/luxonis_train/models/luxonis_model.py @@ -35,7 +35,7 @@ ) from luxonis_train.utils.registry import CALLBACKS, OPTIMIZERS, SCHEDULERS, Registry from luxonis_train.utils.tracker import LuxonisTrackerPL -from luxonis_train.utils.types import Kwargs, Labels, Packet +from luxonis_train.utils.types import Kwargs, Labels, Packet, TaskLabels from .luxonis_output import LuxonisOutput @@ -139,10 +139,13 @@ def __init__( frozen_nodes: list[tuple[str, int]] = [] nodes: dict[str, tuple[type[BaseNode], Kwargs]] = {} + self.node_tasks: dict[str, str] = {} + for node_cfg in self.cfg.model.nodes: node_name = node_cfg.name Node = BaseNode.REGISTRY.get(node_name) node_name = node_cfg.alias or node_name + self.node_tasks[node_name] = node_cfg.task_group if node_cfg.freezing.active: epochs = self.cfg.trainer.epochs if node_cfg.freezing.unfreeze_after is None: @@ -244,7 +247,7 @@ def _initiate_nodes( def forward( self, inputs: Tensor, - labels: Labels | None = None, + task_labels: TaskLabels | None = None, images: Tensor | None = None, *, compute_loss: bool = True, @@ -259,8 +262,8 @@ def forward( @type inputs: L{Tensor} @param inputs: Input tensor. - @type labels: L{Labels} | None - @param labels: Labels dictionary. Defaults to C{None}. + @type task_labels: L{TaskLabels} | None + @param task_labels: Labels dictionary. Defaults to C{None}. @type images: L{Tensor} | None @param images: Canvas tensor for visualizers. Defaults to C{None}. @type compute_loss: bool @@ -296,6 +299,7 @@ def forward( node_inputs = [computed[pred] for pred in input_names] outputs = node.run(node_inputs) computed[node_name] = outputs + labels = task_labels[self.node_tasks[node_name]] if task_labels else None if compute_loss and node_name in self.losses and labels is not None: for loss_name, loss in self.losses[node_name].items(): diff --git a/luxonis_train/utils/boxutils.py b/luxonis_train/utils/boxutils.py index 0d708f79..a59f4cd0 100644 --- a/luxonis_train/utils/boxutils.py +++ b/luxonis_train/utils/boxutils.py @@ -404,6 +404,7 @@ def anchors_from_dataset( n_anchors: int = 9, n_generations: int = 1000, ratio_threshold: float = 4.0, + task_group: str = "default", ) -> tuple[Tensor, float]: """Generates anchors based on bounding box annotations present in provided data loader. It uses K-Means for initial proposals which are then refined with genetic @@ -425,7 +426,8 @@ def anchors_from_dataset( widths = [] inputs = None - for inp, labels in loader: + for inp, task_labels in loader: + labels = next(iter(task_labels.values())) # TODO: handle multiple tasks boxes = labels[LabelType.BOUNDINGBOX] curr_wh = boxes[:, 4:] widths.append(curr_wh) diff --git a/luxonis_train/utils/config.py b/luxonis_train/utils/config.py index a2d4f332..45dde192 100644 --- a/luxonis_train/utils/config.py +++ b/luxonis_train/utils/config.py @@ -43,6 +43,7 @@ class ModelNodeConfig(CustomBaseModel): inputs: list[str] = [] params: dict[str, Any] = {} freezing: FreezingConfig = FreezingConfig() + task_group: str = "default" class PredefinedModelConfig(CustomBaseModel): diff --git a/luxonis_train/utils/loaders/base_loader.py b/luxonis_train/utils/loaders/base_loader.py index 93f3fd0c..be12b439 100644 --- a/luxonis_train/utils/loaders/base_loader.py +++ b/luxonis_train/utils/loaders/base_loader.py @@ -8,7 +8,7 @@ from luxonis_train.utils.registry import LOADERS from luxonis_train.utils.types import Labels, LabelType -LuxonisLoaderTorchOutput = tuple[Tensor, Labels] +LuxonisLoaderTorchOutput = tuple[Tensor, dict[str, Labels]] """LuxonisLoaderTorchOutput is a tuple of images and corresponding labels.""" @@ -46,7 +46,7 @@ def __getitem__(self, idx: int) -> LuxonisLoaderTorchOutput: def collate_fn( batch: list[LuxonisLoaderTorchOutput], -) -> tuple[Tensor, dict[LabelType, Tensor]]: +) -> tuple[Tensor, dict[str, dict[LabelType, Tensor]]]: """Default collate function used for training. @type batch: list[LuxonisLoaderTorchOutput] @@ -55,41 +55,46 @@ def collate_fn( @rtype: tuple[Tensor, dict[LabelType, Tensor]] @return: Tuple of images and annotations in the format expected by the model. """ - zipped = zip(*batch) - imgs, anno_dicts = zipped + imgs, group_dicts = zip(*batch) + out_group_dicts = {task: {} for task in group_dicts[0].keys()} imgs = torch.stack(imgs, 0) - present_annotations = anno_dicts[0].keys() - out_annotations: dict[LabelType, Tensor] = { - anno: torch.empty(0) for anno in present_annotations - } - - if LabelType.CLASSIFICATION in present_annotations: - class_annos = [anno[LabelType.CLASSIFICATION] for anno in anno_dicts] - out_annotations[LabelType.CLASSIFICATION] = torch.stack(class_annos, 0) - - if LabelType.SEGMENTATION in present_annotations: - seg_annos = [anno[LabelType.SEGMENTATION] for anno in anno_dicts] - out_annotations[LabelType.SEGMENTATION] = torch.stack(seg_annos, 0) - - if LabelType.BOUNDINGBOX in present_annotations: - bbox_annos = [anno[LabelType.BOUNDINGBOX] for anno in anno_dicts] - label_box: list[Tensor] = [] - for i, box in enumerate(bbox_annos): - l_box = torch.zeros((box.shape[0], 6)) - l_box[:, 0] = i # add target image index for build_targets() - l_box[:, 1:] = box - label_box.append(l_box) - out_annotations[LabelType.BOUNDINGBOX] = torch.cat(label_box, 0) - - if LabelType.KEYPOINT in present_annotations: - keypoint_annos = [anno[LabelType.KEYPOINT] for anno in anno_dicts] - label_keypoints: list[Tensor] = [] - for i, points in enumerate(keypoint_annos): - l_kps = torch.zeros((points.shape[0], points.shape[1] + 1)) - l_kps[:, 0] = i # add target image index for build_targets() - l_kps[:, 1:] = points - label_keypoints.append(l_kps) - out_annotations[LabelType.KEYPOINT] = torch.cat(label_keypoints, 0) - - return imgs, out_annotations + for task in list(group_dicts[0].keys()): + anno_dicts = [group[task] for group in group_dicts] + + present_annotations = anno_dicts[0].keys() + out_annotations: dict[LabelType, Tensor] = { + anno: torch.empty(0) for anno in present_annotations + } + + if LabelType.CLASSIFICATION in present_annotations: + class_annos = [anno[LabelType.CLASSIFICATION] for anno in anno_dicts] + out_annotations[LabelType.CLASSIFICATION] = torch.stack(class_annos, 0) + + if LabelType.SEGMENTATION in present_annotations: + seg_annos = [anno[LabelType.SEGMENTATION] for anno in anno_dicts] + out_annotations[LabelType.SEGMENTATION] = torch.stack(seg_annos, 0) + + if LabelType.BOUNDINGBOX in present_annotations: + bbox_annos = [anno[LabelType.BOUNDINGBOX] for anno in anno_dicts] + label_box: list[Tensor] = [] + for i, box in enumerate(bbox_annos): + l_box = torch.zeros((box.shape[0], 6)) + l_box[:, 0] = i # add target image index for build_targets() + l_box[:, 1:] = box + label_box.append(l_box) + out_annotations[LabelType.BOUNDINGBOX] = torch.cat(label_box, 0) + + if LabelType.KEYPOINT in present_annotations: + keypoint_annos = [anno[LabelType.KEYPOINT] for anno in anno_dicts] + label_keypoints: list[Tensor] = [] + for i, points in enumerate(keypoint_annos): + l_kps = torch.zeros((points.shape[0], points.shape[1] + 1)) + l_kps[:, 0] = i # add target image index for build_targets() + l_kps[:, 1:] = points + label_keypoints.append(l_kps) + out_annotations[LabelType.KEYPOINT] = torch.cat(label_keypoints, 0) + + out_group_dicts[task] = out_annotations + + return imgs, out_group_dicts diff --git a/luxonis_train/utils/loaders/luxonis_loader_torch.py b/luxonis_train/utils/loaders/luxonis_loader_torch.py index a0e1f324..dfd4091a 100644 --- a/luxonis_train/utils/loaders/luxonis_loader_torch.py +++ b/luxonis_train/utils/loaders/luxonis_loader_torch.py @@ -29,11 +29,13 @@ def input_shape(self) -> Size: return Size([1, *img.shape]) def __getitem__(self, idx: int) -> LuxonisLoaderTorchOutput: - img, annotations = self.base_loader[idx] + img, group_annotations = self.base_loader[idx] img = np.transpose(img, (2, 0, 1)) # HWC to CHW tensor_img = Tensor(img) - for key in annotations: - annotations[key] = Tensor(annotations[key]) # type: ignore + for task in group_annotations: + annotations = group_annotations[task] + for key in annotations: + annotations[key] = Tensor(annotations[key]) # type: ignore - return tensor_img, annotations + return tensor_img, group_annotations diff --git a/luxonis_train/utils/types.py b/luxonis_train/utils/types.py index dbbf471e..3fb724c3 100644 --- a/luxonis_train/utils/types.py +++ b/luxonis_train/utils/types.py @@ -7,6 +7,7 @@ Kwargs = dict[str, Any] OutputTypes = Literal["boxes", "class", "keypoints", "segmentation", "features"] Labels = dict[LabelType, Tensor] +TaskLabels = dict[str, Labels] AttachIndexType = Literal["all"] | int | tuple[int, int] | tuple[int, int, int] """AttachIndexType is used to specify to which output of the prevoius node does the diff --git a/requirements.txt b/requirements.txt index 03081b48..7f7e996a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ blobconverter>=1.4.2 lightning>=2.0.0 -luxonis-ml[all]>=0.1.0 +#luxonis-ml[all]>=0.1.0 +luxonis-ml[all]@git+https://github.com/luxonis/luxonis-ml.git@dev onnx>=1.12.0 onnxruntime>=1.13.1 onnxsim>=0.4.10 diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 35c893d4..815a4bd5 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -120,7 +120,7 @@ def COCO_people_subset_generator(): } } ) - dataset.add(COCO_people_subset_generator) # type: ignore + dataset.add(COCO_people_subset_generator()) dataset.make_splits() @@ -161,5 +161,5 @@ def CIFAR10_subset_generator(): dataset.set_classes(classes) - dataset.add(CIFAR10_subset_generator) # type: ignore + dataset.add(CIFAR10_subset_generator()) dataset.make_splits() diff --git a/tests/unittests/test_core/test_archiver.py b/tests/unittests/test_core/test_archiver.py index a044be52..fe10a46e 100644 --- a/tests/unittests/test_core/test_archiver.py +++ b/tests/unittests/test_core/test_archiver.py @@ -226,7 +226,7 @@ def dataset_generator(): for label in labels } ) - dataset.add(dataset_generator) + dataset.add(dataset_generator()) dataset.make_splits(ratios=split_ratios) def _make_dummy_cfg_dict(head_name: str, ldf_name: str, save_path: str) -> dict: diff --git a/tests/unittests/test_utils/test_loaders/test_base_loader.py b/tests/unittests/test_utils/test_loaders/test_base_loader.py index e48f81ad..b5c8b299 100644 --- a/tests/unittests/test_utils/test_loaders/test_base_loader.py +++ b/tests/unittests/test_utils/test_loaders/test_base_loader.py @@ -12,11 +12,11 @@ def test_collate_fn(): batch = [ ( torch.rand(3, 224, 224, dtype=torch.float32), - {LabelType.CLASSIFICATION: torch.tensor([1, 0])}, + {"default": {LabelType.CLASSIFICATION: torch.tensor([1, 0])}}, ), ( torch.rand(3, 224, 224, dtype=torch.float32), - {LabelType.CLASSIFICATION: torch.tensor([0, 1])}, + {"default": {LabelType.CLASSIFICATION: torch.tensor([0, 1])}}, ), ] @@ -28,6 +28,8 @@ def test_collate_fn(): assert imgs.dtype == torch.float32 # Check annotations + assert "default" in annotations + annotations = annotations["default"] assert LabelType.CLASSIFICATION in annotations assert annotations[LabelType.CLASSIFICATION].shape == (2, 2) assert annotations[LabelType.CLASSIFICATION].dtype == torch.int64