Skip to content

Commit

Permalink
Merge branch 'dev' into feature/simplify-config
Browse files Browse the repository at this point in the history
  • Loading branch information
kozlov721 committed Sep 25, 2024
2 parents f8e11df + be983ba commit 9c84111
Show file tree
Hide file tree
Showing 44 changed files with 377 additions and 267 deletions.
13 changes: 6 additions & 7 deletions luxonis_train/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
__version__ = "0.0.1"

import logging

logger = logging.getLogger(__name__)

import warnings

try:
from .attached_modules import *
Expand All @@ -14,8 +11,10 @@
from .optimizers import *
from .schedulers import *
from .utils import *
except ImportError:
logger.warning(
except ImportError as e:
warnings.warn(
"Failed to import submodules. "
"Some functionality of `luxonis-train` may be unavailable."
"Some functionality of `luxonis-train` may be unavailable. "
f"Error: `{e}`",
stacklevel=2,
)
102 changes: 51 additions & 51 deletions luxonis_train/attached_modules/base_attached_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
from contextlib import suppress
from typing import Generic

from luxonis_ml.data import LabelType
from luxonis_ml.utils.registry import AutoRegisterMeta
from torch import Size, Tensor, nn
from typing_extensions import TypeVarTuple, Unpack

from luxonis_train.enums import TaskType
from luxonis_train.nodes import BaseNode
from luxonis_train.utils import IncompatibleException, Labels, Packet

Expand Down Expand Up @@ -36,41 +36,41 @@ class BaseAttachedModule(
Override this method if the default implementation is not sufficient.
Additionally, the following attributes can be overridden:
- L{supported_labels}: List of label types that the module supports.
- L{supported_tasks}: List of task types that the module supports.
Used to determine which labels to extract from the dataset and to validate
compatibility with the node based on the node's tasks.
@type node: BaseNode
@param node: Reference to the node that this module is attached to.
@type supported_labels: list[LabelType | tuple[LabelType, ...]] | None
@ivar supported_labels: List of label types that the module supports.
Elements of the list can be either a single label type or a tuple of
label types. In case of the latter, the module requires all of the
@type supported_tasks: list[TaskType | tuple[TaskType, ...]] | None
@ivar supported_tasks: List of task types that the module supports.
Elements of the list can be either a single task type or a tuple of
task types. In case of the latter, the module requires all of the
specified labels in the tuple to be present.
Example:
- C{[LabelType.CLASSIFICATION, LabelType.SEGMENTATION]} means that the
- C{[TaskType.CLASSIFICATION, TaskType.SEGMENTATION]} means that the
module requires either classification or segmentation labels.
- C{[(LabelType.BOUNDINGBOX, LabelType.KEYPOINTS), LabelType.SEGMENTATION]}
- C{[(TaskType.BOUNDINGBOX, TaskType.KEYPOINTS), TaskType.SEGMENTATION]}
means that the module requires either both bounding box I{and} keypoint
labels I{or} segmentation labels.
"""

supported_labels: list[LabelType | tuple[LabelType, ...]] | None = None
supported_tasks: list[TaskType | tuple[TaskType, ...]] | None = None

def __init__(self, *, node: BaseNode | None = None):
super().__init__()
self._node = node
self._epoch = 0

self.required_labels: list[LabelType] = []
if self._node and self.supported_labels:
self.required_labels: list[TaskType] = []
if self._node and self.supported_tasks:
module_supported = [
label.value
if isinstance(label, LabelType)
if isinstance(label, TaskType)
else f"({' + '.join(label)})"
for label in self.supported_labels
for label in self.supported_tasks
]
module_supported = f"[{', '.join(module_supported)}]"
if not self.node.tasks:
Expand All @@ -80,8 +80,8 @@ def __init__(self, *, node: BaseNode | None = None):
f"but is connected to node '{self.node.name}' which does not specify any tasks."
)
node_tasks = set(self.node.tasks)
for required_labels in self.supported_labels:
if isinstance(required_labels, LabelType):
for required_labels in self.supported_tasks:
if isinstance(required_labels, TaskType):
required_labels = [required_labels]
else:
required_labels = list(required_labels)
Expand Down Expand Up @@ -159,10 +159,10 @@ def class_names(self) -> list[str]:
return self.node.class_names

@property
def node_tasks(self) -> dict[LabelType, str]:
def node_tasks(self) -> dict[TaskType, str]:
"""Getter for the tasks of the attached node.
@type: dict[LabelType, str]
@type: dict[TaskType, str]
@raises RuntimeError: If the node does not have the `tasks` attribute set.
"""
if self.node._tasks is None:
Expand All @@ -172,75 +172,75 @@ def node_tasks(self) -> dict[LabelType, str]:
return self.node._tasks

def get_label(
self, labels: Labels, label_type: LabelType | None = None
self, labels: Labels, task_type: TaskType | None = None
) -> Tensor:
"""Extracts a specific label from the labels dictionary.
If the label type is not provided, the first label that matches the
required label type is returned.
If the task type is not provided, the first label that matches the
required task type is returned.
Example::
>>> # supported_labels = [LabelType.SEGMENTATION]
>>> # supported_tasks = [TaskType.SEGMENTATION]
>>> labels = {"segmentation": seg_tensor, "boundingbox": bbox_tensor}
>>> get_label(labels)
seg_tensor # returns the first matching label
>>> get_label(labels, LabelType.BOUNDINGBOX)
>>> get_label(labels, TaskType.BOUNDINGBOX)
bbox_tensor # returns the bounding box label
>>> get_label(labels, LabelType.CLASSIFICATION)
>>> get_label(labels, TaskType.CLASSIFICATION)
IncompatibleException: Label 'classification' is missing from the dataset.
@type labels: L{Labels}
@param labels: Labels from the dataset.
@type label_type: LabelType | None
@param label_type: Type of the label to extract.
@type task_type: TaskType | None
@param task_type: Type of the label to extract.
@rtype: Tensor
@return: Extracted label
@raises ValueError: If the module requires multiple labels and the C{label_type} is not provided.
@raises ValueError: If the module requires multiple labels and the C{task_type} is not provided.
@raises IncompatibleException: If the label is not found in the labels dictionary.
"""
return self._get_label(labels, label_type)[0]
return self._get_label(labels, task_type)[0]

def _get_label(
self, labels: Labels, label_type: LabelType | None = None
) -> tuple[Tensor, LabelType]:
if label_type is None:
self, labels: Labels, task_type: TaskType | None = None
) -> tuple[Tensor, TaskType]:
if task_type is None:
if len(self.required_labels) == 1:
label_type = self.required_labels[0]
task_type = self.required_labels[0]

if label_type is not None:
task_name = self.node.get_task_name(label_type)
if task_type is not None:
task_name = self.node.get_task_name(task_type)
if task_name not in labels:
raise IncompatibleException.from_missing_task(
label_type.value, list(labels.keys()), self.name
task_type.value, list(labels.keys()), self.name
)
return labels[task_name]

raise ValueError(
f"{self.name} requires multiple labels. You must provide the "
"`label_type` argument to extract the desired label."
"`task_type` argument to extract the desired label."
)

def get_input_tensors(
self, inputs: Packet[Tensor], task_type: LabelType | str | None = None
self, inputs: Packet[Tensor], task_type: TaskType | str | None = None
) -> list[Tensor]:
"""Extracts the input tensors from the packet.
Example::
>>> # supported_labels = [LabelType.SEGMENTATION]
>>> # node.tasks = {LabelType.SEGMENTATION: "segmentation-task"}
>>> # supported_tasks = [TaskType.SEGMENTATION]
>>> # node.tasks = {TaskType.SEGMENTATION: "segmentation-task"}
>>> inputs = [{"segmentation-task": [seg_tensor]}, {"features": [feat_tensor]}]
>>> get_input_tensors(inputs) # matches supported labels to node's tasks
[seg_tensor]
>>> get_input_tensors(inputs, "features")
[feat_tensor]
>>> get_input_tensors(inputs, LabelType.CLASSIFICATION)
>>> get_input_tensors(inputs, TaskType.CLASSIFICATION)
ValueError: Task 'classification' is not supported by the node.
@type inputs: L{Packet}[Tensor]
@param inputs: Output from the node this module is attached to.
@type task_type: LabelType | str | None
@type task_type: TaskType | str | None
@param task_type: Type of the task to extract. Must be provided when the node
supports multiple tasks or if the module doesn't require any tasks.
@rtype: list[Tensor]
Expand All @@ -253,7 +253,7 @@ def get_input_tensors(
For such cases, the `prepare` method should be overridden.
"""
if task_type is not None:
if isinstance(task_type, LabelType):
if isinstance(task_type, TaskType):
if task_type not in self.node_tasks:
raise IncompatibleException(
f"Task {task_type.value} is not supported by the node "
Expand All @@ -280,7 +280,7 @@ def prepare(
"""Prepares node outputs for the forward pass of the module.
This default implementation selects the output and label based on
C{supported_labels} attribute. If not set, then it returns the first
C{supported_tasks} attribute. If not set, then it returns the first
matching output and label.
That is the first pair of outputs and labels that have the same type.
For more complex modules this method should be overridden.
Expand All @@ -301,32 +301,32 @@ def prepare(
implementation cannot be used and the C{prepare} method should be overridden.
@raises RuntimeError: If the C{tasks} attribute is not set on the node.
@raises RuntimeError: If the C{supported_labels} attribute is not set on the module.
@raises RuntimeError: If the C{supported_tasks} attribute is not set on the module.
"""
if self.node._tasks is None:
raise RuntimeError(
f"{self.node.name} must have the `tasks` attribute specified "
f"for {self.name} to make use of the default `prepare` method."
)
if self.supported_labels is None:
if self.supported_tasks is None:
raise RuntimeError(
f"{self.name} must have the `supported_labels` attribute "
f"{self.name} must have the `supported_tasks` attribute "
"specified in order to use the default `prepare` method."
)
if len(self.supported_labels) > 1:
if len(self.supported_tasks) > 1:
if len(self.node_tasks) > 1:
raise RuntimeError(
f"{self.name} supports more than one label type"
f"{self.name} supports more than one task type"
f"and is connected to {self.node.name} node "
"which is a multi-task node. The default `prepare` "
"implementation cannot be used in this case."
)
self.supported_labels = list(
set(self.supported_labels) & set(self.node_tasks)
self.supported_tasks = list(
set(self.supported_tasks) & set(self.node_tasks)
)
x = self.get_input_tensors(inputs)
label, label_type = self._get_label(labels)
if label_type in [LabelType.CLASSIFICATION, LabelType.SEGMENTATION]:
label, task_type = self._get_label(labels)
if task_type in [TaskType.CLASSIFICATION, TaskType.SEGMENTATION]:
if len(x) == 1:
x = x[0]
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@

import torch
import torch.nn.functional as F
from luxonis_ml.data import LabelType
from torch import Tensor, nn
from torchvision.ops import box_convert

from luxonis_train.assigners import ATSSAssigner, TaskAlignedAssigner
from luxonis_train.enums import TaskType
from luxonis_train.nodes import EfficientBBoxHead
from luxonis_train.utils import (
Labels,
Expand All @@ -27,7 +27,7 @@ class AdaptiveDetectionLoss(
BaseLoss[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]
):
node: EfficientBBoxHead
supported_labels = [LabelType.BOUNDINGBOX]
supported_tasks: list[TaskType] = [TaskType.BOUNDINGBOX]

anchors: Tensor
anchor_points: Tensor
Expand Down
8 changes: 6 additions & 2 deletions luxonis_train/attached_modules/losses/bce_with_logits.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
from typing import Any, Literal

import torch
from luxonis_ml.data import LabelType
from torch import Tensor, nn

from luxonis_train.enums import TaskType

from .base_loss import BaseLoss


class BCEWithLogitsLoss(BaseLoss[Tensor, Tensor]):
supported_labels = [LabelType.SEGMENTATION, LabelType.CLASSIFICATION]
supported_tasks: list[TaskType] = [
TaskType.SEGMENTATION,
TaskType.CLASSIFICATION,
]

def __init__(
self,
Expand Down
8 changes: 6 additions & 2 deletions luxonis_train/attached_modules/losses/cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@

import torch
import torch.nn as nn
from luxonis_ml.data import LabelType
from torch import Tensor

from luxonis_train.enums import TaskType

from .base_loss import BaseLoss

logger = getLogger(__name__)
Expand All @@ -15,7 +16,10 @@ class CrossEntropyLoss(BaseLoss[Tensor, Tensor]):
"""This criterion computes the cross entropy loss between input
logits and target."""

supported_labels = [LabelType.SEGMENTATION, LabelType.CLASSIFICATION]
supported_tasks: list[TaskType] = [
TaskType.SEGMENTATION,
TaskType.CLASSIFICATION,
]

def __init__(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

import torch
import torch.nn.functional as F
from luxonis_ml.data import LabelType
from torch import Tensor

from luxonis_train.attached_modules.losses import AdaptiveDetectionLoss
from luxonis_train.enums import TaskType
from luxonis_train.nodes import EfficientKeypointBBoxHead
from luxonis_train.utils import (
Labels,
Expand All @@ -22,7 +22,9 @@

class EfficientKeypointBBoxLoss(AdaptiveDetectionLoss):
node: EfficientKeypointBBoxHead
supported_labels = [(LabelType.BOUNDINGBOX, LabelType.KEYPOINTS)]
supported_tasks: list[tuple[TaskType, ...]] = [
(TaskType.BOUNDINGBOX, TaskType.KEYPOINTS)
]

gt_kpts_scale: Tensor

Expand Down Expand Up @@ -97,8 +99,8 @@ def prepare(
pred_distri = self.get_input_tensors(inputs, "distributions")[0]
pred_kpts = self.get_input_tensors(inputs, "keypoints_raw")[0]

target_kpts = self.get_label(labels, LabelType.KEYPOINTS)
target_bbox = self.get_label(labels, LabelType.BOUNDINGBOX)
target_kpts = self.get_label(labels, TaskType.KEYPOINTS)
target_bbox = self.get_label(labels, TaskType.BOUNDINGBOX)

batch_size = pred_scores.shape[0]
n_kpts = (target_kpts.shape[1] - 2) // 3
Expand Down
Loading

0 comments on commit 9c84111

Please sign in to comment.