Skip to content

Commit

Permalink
Task Label Groups Support (#22)
Browse files Browse the repository at this point in the history
* handling SIGTERM signal

* resume argument takes path

* basic task group labels support

* updated requirements

* fixed tests

* fixed loader test

* Update luxonis_train/models/luxonis_model.py

Co-authored-by: conorsim <[email protected]>

---------

Co-authored-by: conorsim <[email protected]>
  • Loading branch information
kozlov721 and conorsim committed Oct 9, 2024
1 parent d0740d0 commit 732ca0a
Show file tree
Hide file tree
Showing 10 changed files with 71 additions and 53 deletions.
12 changes: 8 additions & 4 deletions luxonis_train/models/luxonis_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
)
from luxonis_train.utils.registry import CALLBACKS, OPTIMIZERS, SCHEDULERS, Registry
from luxonis_train.utils.tracker import LuxonisTrackerPL
from luxonis_train.utils.types import Kwargs, Labels, Packet
from luxonis_train.utils.types import Kwargs, Labels, Packet, TaskLabels

from .luxonis_output import LuxonisOutput

Expand Down Expand Up @@ -139,10 +139,13 @@ def __init__(
frozen_nodes: list[tuple[str, int]] = []
nodes: dict[str, tuple[type[BaseNode], Kwargs]] = {}

self.node_tasks: dict[str, str] = {}

for node_cfg in self.cfg.model.nodes:
node_name = node_cfg.name
Node = BaseNode.REGISTRY.get(node_name)
node_name = node_cfg.alias or node_name
self.node_tasks[node_name] = node_cfg.task_group
if node_cfg.freezing.active:
epochs = self.cfg.trainer.epochs
if node_cfg.freezing.unfreeze_after is None:
Expand Down Expand Up @@ -244,7 +247,7 @@ def _initiate_nodes(
def forward(
self,
inputs: Tensor,
labels: Labels | None = None,
task_labels: TaskLabels | None = None,
images: Tensor | None = None,
*,
compute_loss: bool = True,
Expand All @@ -259,8 +262,8 @@ def forward(
@type inputs: L{Tensor}
@param inputs: Input tensor.
@type labels: L{Labels} | None
@param labels: Labels dictionary. Defaults to C{None}.
@type task_labels: L{TaskLabels} | None
@param task_labels: Dictionary mapping each task group name to its labels dictionary. Defaults to C{None}.
@type images: L{Tensor} | None
@param images: Canvas tensor for visualizers. Defaults to C{None}.
@type compute_loss: bool
Expand Down Expand Up @@ -296,6 +299,7 @@ def forward(
node_inputs = [computed[pred] for pred in input_names]
outputs = node.run(node_inputs)
computed[node_name] = outputs
labels = task_labels[self.node_tasks[node_name]] if task_labels else None

if compute_loss and node_name in self.losses and labels is not None:
for loss_name, loss in self.losses[node_name].items():
Expand Down
4 changes: 3 additions & 1 deletion luxonis_train/utils/boxutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ def anchors_from_dataset(
n_anchors: int = 9,
n_generations: int = 1000,
ratio_threshold: float = 4.0,
task_group: str = "default",
) -> tuple[Tensor, float]:
"""Generates anchors based on bounding box annotations present in provided data
loader. It uses K-Means for initial proposals which are then refined with genetic
Expand All @@ -425,7 +426,8 @@ def anchors_from_dataset(

widths = []
inputs = None
for inp, labels in loader:
for inp, task_labels in loader:
labels = next(iter(task_labels.values())) # TODO: handle multiple tasks
boxes = labels[LabelType.BOUNDINGBOX]
curr_wh = boxes[:, 4:]
widths.append(curr_wh)
Expand Down
1 change: 1 addition & 0 deletions luxonis_train/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class ModelNodeConfig(CustomBaseModel):
inputs: list[str] = []
params: dict[str, Any] = {}
freezing: FreezingConfig = FreezingConfig()
task_group: str = "default"


class PredefinedModelConfig(CustomBaseModel):
Expand Down
81 changes: 43 additions & 38 deletions luxonis_train/utils/loaders/base_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from luxonis_train.utils.registry import LOADERS
from luxonis_train.utils.types import Labels, LabelType

LuxonisLoaderTorchOutput = tuple[Tensor, Labels]
LuxonisLoaderTorchOutput = tuple[Tensor, dict[str, Labels]]
"""LuxonisLoaderTorchOutput is a tuple of images and corresponding labels, keyed by task group name."""


Expand Down Expand Up @@ -46,7 +46,7 @@ def __getitem__(self, idx: int) -> LuxonisLoaderTorchOutput:

def collate_fn(
batch: list[LuxonisLoaderTorchOutput],
) -> tuple[Tensor, dict[LabelType, Tensor]]:
) -> tuple[Tensor, dict[str, dict[LabelType, Tensor]]]:
"""Default collate function used for training.
@type batch: list[LuxonisLoaderTorchOutput]
Expand All @@ -55,41 +55,46 @@ def collate_fn(
@rtype: tuple[Tensor, dict[str, dict[LabelType, Tensor]]]
@return: Tuple of images and annotations, grouped by task, in the format expected by the model.
"""
zipped = zip(*batch)
imgs, anno_dicts = zipped
imgs, group_dicts = zip(*batch)
out_group_dicts = {task: {} for task in group_dicts[0].keys()}
imgs = torch.stack(imgs, 0)

present_annotations = anno_dicts[0].keys()
out_annotations: dict[LabelType, Tensor] = {
anno: torch.empty(0) for anno in present_annotations
}

if LabelType.CLASSIFICATION in present_annotations:
class_annos = [anno[LabelType.CLASSIFICATION] for anno in anno_dicts]
out_annotations[LabelType.CLASSIFICATION] = torch.stack(class_annos, 0)

if LabelType.SEGMENTATION in present_annotations:
seg_annos = [anno[LabelType.SEGMENTATION] for anno in anno_dicts]
out_annotations[LabelType.SEGMENTATION] = torch.stack(seg_annos, 0)

if LabelType.BOUNDINGBOX in present_annotations:
bbox_annos = [anno[LabelType.BOUNDINGBOX] for anno in anno_dicts]
label_box: list[Tensor] = []
for i, box in enumerate(bbox_annos):
l_box = torch.zeros((box.shape[0], 6))
l_box[:, 0] = i # add target image index for build_targets()
l_box[:, 1:] = box
label_box.append(l_box)
out_annotations[LabelType.BOUNDINGBOX] = torch.cat(label_box, 0)

if LabelType.KEYPOINT in present_annotations:
keypoint_annos = [anno[LabelType.KEYPOINT] for anno in anno_dicts]
label_keypoints: list[Tensor] = []
for i, points in enumerate(keypoint_annos):
l_kps = torch.zeros((points.shape[0], points.shape[1] + 1))
l_kps[:, 0] = i # add target image index for build_targets()
l_kps[:, 1:] = points
label_keypoints.append(l_kps)
out_annotations[LabelType.KEYPOINT] = torch.cat(label_keypoints, 0)

return imgs, out_annotations
for task in list(group_dicts[0].keys()):
anno_dicts = [group[task] for group in group_dicts]

present_annotations = anno_dicts[0].keys()
out_annotations: dict[LabelType, Tensor] = {
anno: torch.empty(0) for anno in present_annotations
}

if LabelType.CLASSIFICATION in present_annotations:
class_annos = [anno[LabelType.CLASSIFICATION] for anno in anno_dicts]
out_annotations[LabelType.CLASSIFICATION] = torch.stack(class_annos, 0)

if LabelType.SEGMENTATION in present_annotations:
seg_annos = [anno[LabelType.SEGMENTATION] for anno in anno_dicts]
out_annotations[LabelType.SEGMENTATION] = torch.stack(seg_annos, 0)

if LabelType.BOUNDINGBOX in present_annotations:
bbox_annos = [anno[LabelType.BOUNDINGBOX] for anno in anno_dicts]
label_box: list[Tensor] = []
for i, box in enumerate(bbox_annos):
l_box = torch.zeros((box.shape[0], 6))
l_box[:, 0] = i # add target image index for build_targets()
l_box[:, 1:] = box
label_box.append(l_box)
out_annotations[LabelType.BOUNDINGBOX] = torch.cat(label_box, 0)

if LabelType.KEYPOINT in present_annotations:
keypoint_annos = [anno[LabelType.KEYPOINT] for anno in anno_dicts]
label_keypoints: list[Tensor] = []
for i, points in enumerate(keypoint_annos):
l_kps = torch.zeros((points.shape[0], points.shape[1] + 1))
l_kps[:, 0] = i # add target image index for build_targets()
l_kps[:, 1:] = points
label_keypoints.append(l_kps)
out_annotations[LabelType.KEYPOINT] = torch.cat(label_keypoints, 0)

out_group_dicts[task] = out_annotations

return imgs, out_group_dicts
10 changes: 6 additions & 4 deletions luxonis_train/utils/loaders/luxonis_loader_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,13 @@ def input_shape(self) -> Size:
return Size([1, *img.shape])

def __getitem__(self, idx: int) -> LuxonisLoaderTorchOutput:
img, annotations = self.base_loader[idx]
img, group_annotations = self.base_loader[idx]

img = np.transpose(img, (2, 0, 1)) # HWC to CHW
tensor_img = Tensor(img)
for key in annotations:
annotations[key] = Tensor(annotations[key]) # type: ignore
for task in group_annotations:
annotations = group_annotations[task]
for key in annotations:
annotations[key] = Tensor(annotations[key]) # type: ignore

return tensor_img, annotations
return tensor_img, group_annotations
1 change: 1 addition & 0 deletions luxonis_train/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Kwargs = dict[str, Any]
OutputTypes = Literal["boxes", "class", "keypoints", "segmentation", "features"]
Labels = dict[LabelType, Tensor]
TaskLabels = dict[str, Labels]

AttachIndexType = Literal["all"] | int | tuple[int, int] | tuple[int, int, int]
"""AttachIndexType is used to specify to which output of the previous node does the
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
blobconverter>=1.4.2
lightning>=2.0.0
luxonis-ml[all]>=0.1.0
#luxonis-ml[all]>=0.1.0
luxonis-ml[all]@git+https://github.com/luxonis/luxonis-ml.git@dev
onnx>=1.12.0
onnxruntime>=1.13.1
onnxsim>=0.4.10
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def COCO_people_subset_generator():
}
}
)
dataset.add(COCO_people_subset_generator) # type: ignore
dataset.add(COCO_people_subset_generator())
dataset.make_splits()


Expand Down Expand Up @@ -161,5 +161,5 @@ def CIFAR10_subset_generator():

dataset.set_classes(classes)

dataset.add(CIFAR10_subset_generator) # type: ignore
dataset.add(CIFAR10_subset_generator())
dataset.make_splits()
2 changes: 1 addition & 1 deletion tests/unittests/test_core/test_archiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def dataset_generator():
for label in labels
}
)
dataset.add(dataset_generator)
dataset.add(dataset_generator())
dataset.make_splits(ratios=split_ratios)

def _make_dummy_cfg_dict(head_name: str, ldf_name: str, save_path: str) -> dict:
Expand Down
6 changes: 4 additions & 2 deletions tests/unittests/test_utils/test_loaders/test_base_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ def test_collate_fn():
batch = [
(
torch.rand(3, 224, 224, dtype=torch.float32),
{LabelType.CLASSIFICATION: torch.tensor([1, 0])},
{"default": {LabelType.CLASSIFICATION: torch.tensor([1, 0])}},
),
(
torch.rand(3, 224, 224, dtype=torch.float32),
{LabelType.CLASSIFICATION: torch.tensor([0, 1])},
{"default": {LabelType.CLASSIFICATION: torch.tensor([0, 1])}},
),
]

Expand All @@ -28,6 +28,8 @@ def test_collate_fn():
assert imgs.dtype == torch.float32

# Check annotations
assert "default" in annotations
annotations = annotations["default"]
assert LabelType.CLASSIFICATION in annotations
assert annotations[LabelType.CLASSIFICATION].shape == (2, 2)
assert annotations[LabelType.CLASSIFICATION].dtype == torch.int64
Expand Down

0 comments on commit 732ca0a

Please sign in to comment.