Skip to content

Commit

Permalink
Sort classes alphabetically (#175)
Browse files Browse the repository at this point in the history
Co-authored-by: Martin Kozlovsky <[email protected]>
  • Loading branch information
JSabadin and kozlov721 authored Feb 14, 2025
1 parent 6d5031b commit 6808e6e
Show file tree
Hide file tree
Showing 12 changed files with 94 additions and 56 deletions.
9 changes: 5 additions & 4 deletions luxonis_train/attached_modules/base_attached_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from types import UnionType
from typing import Union, get_args, get_origin

from bidict import bidict
from luxonis_ml.data.utils import get_task_type
from luxonis_ml.utils.registry import AutoRegisterMeta
from torch import Size, Tensor, nn
Expand Down Expand Up @@ -162,16 +163,16 @@ def original_in_shape(self) -> Size:
return self.node.original_in_shape

@property
def class_names(self) -> list[str]:
"""Getter for the class names.
def classes(self) -> bidict[str, int]:
"""Getter for the class mapping.
@type: list[str]
@type: dict[str, int]
@raises RuntimeError: If the node doesn't define any task.
@raises ValueError: If the class names are different for
different tasks. In that case, use the L{get_class_names}
method.
"""
return self.node.class_names
return self.node.classes

def pick_labels(self, labels: Labels) -> Labels:
required_labels = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,7 @@ def __init__(
if isinstance(labels, list):
labels = {i: label for i, label in enumerate(labels)}

self.label_dict = labels or {
i: label for i, label in enumerate(self.class_names)
}
self.label_dict = labels or self.classes.inverse

if colors is None:
colors = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,10 @@ def _get_class_name(self, pred: Tensor) -> str:
"""Handles both single-label and multi-label classification."""
if self.multilabel:
idxs = (pred > 0.5).nonzero(as_tuple=True)[0].tolist()
if self.class_names is None:
return ", ".join([str(idx) for idx in idxs])
return ", ".join([self.class_names[idx] for idx in idxs])
return ", ".join([self.classes.inverse[idx] for idx in idxs])
else:
idx = int((pred.argmax()).item())
if self.class_names is None:
return str(idx)
return self.class_names[idx]
return self.classes.inverse[idx]

def _generate_plot(
self, prediction: Tensor, width: int, height: int
Expand All @@ -58,10 +54,7 @@ def _generate_plot(
fig, ax = plt.subplots(figsize=(width / 100, height / 100))
ax.bar(np.arange(len(pred)), pred)
ax.set_xticks(np.arange(len(pred)))
if self.class_names is not None:
ax.set_xticklabels(self.class_names, rotation=90)
else:
ax.set_xticklabels(np.arange(1, len(pred) + 1))
ax.set_xticklabels(self.classes.inverse, rotation=90)
ax.set_ylim(0, 1)
ax.set_xlabel("Class")
ax.set_ylabel("Probability")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Mapping

import torch
from loguru import logger
from torch import Tensor
Expand Down Expand Up @@ -59,9 +61,7 @@ def __init__(
if isinstance(labels, list):
labels = {i: label for i, label in enumerate(labels)}

self.bbox_labels = labels or {
i: label for i, label in enumerate(self.class_names)
}
self.bbox_labels = labels or self.classes.inverse

if colors is None:
colors = {
Expand All @@ -86,7 +86,7 @@ def draw_predictions(
pred_bboxes: list[Tensor],
pred_masks: list[Tensor],
width: int | None,
label_dict: dict[int, str],
label_dict: Mapping[int, str],
color_dict: dict[str, Color],
draw_labels: bool,
alpha: float,
Expand Down Expand Up @@ -142,7 +142,7 @@ def draw_targets(
target_bboxes: Tensor,
target_masks: Tensor,
width: int | None,
label_dict: dict[int, str],
label_dict: Mapping[int, str],
color_dict: dict[str, Color],
draw_labels: bool,
alpha: float,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(
backbone_params: Params | None = None,
loss_params: Params | None = None,
visualizer_params: Params | None = None,
head_params: str | None = None,
head_params: Params | None = None,
task_name: str = "",
):
var_config = get_variant(variant)
Expand Down
7 changes: 4 additions & 3 deletions luxonis_train/loaders/base_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,11 +232,12 @@ def get(self, idx: int) -> tuple[Tensor | dict[str, Tensor], Labels]:
...

@abstractmethod
def get_classes(self) -> dict[str, list[str]]:
def get_classes(self) -> dict[str, dict[str, int]]:
"""Gets classes according to computer vision task.
@rtype: dict[LabelType, list[str]]
@return: A dictionary mapping tasks to their classes.
@rtype: dict[LabelType, dict[str, int]]
@return: A dictionary mapping tasks to their classes as a
mappings from class name to class IDs.
"""
...

Expand Down
2 changes: 1 addition & 1 deletion luxonis_train/loaders/luxonis_loader_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def get(self, idx: int) -> tuple[Tensor, Labels]:
return tensor_img, self.dict_numpy_to_torch(labels)

@override
def get_classes(self) -> dict[str, list[str]]:
def get_classes(self) -> dict[str, dict[str, int]]:
return self.dataset.get_classes()

@override
Expand Down
14 changes: 13 additions & 1 deletion luxonis_train/nodes/base_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import logging
from abc import ABC, abstractmethod
from contextlib import suppress
from operator import itemgetter
from typing import Generic, TypeVar

import torch
from bidict import bidict
from loguru import logger
from luxonis_ml.utils.registry import AutoRegisterMeta
from torch import Size, Tensor, nn
Expand Down Expand Up @@ -244,13 +246,23 @@ def n_classes(self) -> int:

return self.dataset_metadata.n_classes(self.task_name)

@property
def classes(self) -> bidict[str, int]:
"""Getter for the class mappings.
@type: dict[str, int]
"""
return self.dataset_metadata.classes(self.task_name)

@property
def class_names(self) -> list[str]:
"""Getter for the class names.
@type: list[str]
"""
return self.dataset_metadata.classes(self.task_name)
return [
name for name, _ in sorted(self.classes.items(), key=itemgetter(1))
]

@property
def input_shapes(self) -> list[Packet[Size]]:
Expand Down
24 changes: 13 additions & 11 deletions luxonis_train/utils/dataset_metadata.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Set

from bidict import bidict

from luxonis_train.loaders import BaseLoaderTorch


Expand All @@ -9,17 +11,16 @@ class DatasetMetadata:
def __init__(
self,
*,
classes: dict[str, list[str]] | None = None,
classes: dict[str, dict[str, int]] | None = None,
n_keypoints: dict[str, int] | None = None,
loader: BaseLoaderTorch | None = None,
):
"""An object containing metadata about the dataset. Used to
infer the number of classes, number of keypoints, I{etc.}
instead of passing them as arguments to the model.
@type classes: dict[str, list[str]] | None
@param classes: Dictionary mapping tasks to lists of class
names.
@type classes: dict[str, dict[str, int]] | None
@param classes: Dictionary mapping tasks to the classes.
@type n_keypoints: dict[str, int] | None
@param n_keypoints: Dictionary mapping tasks to the number of
keypoints.
Expand Down Expand Up @@ -95,7 +96,7 @@ def n_keypoints(self, task_name: str | None = None) -> int:
)
return n_keypoints

def classes(self, task_name: str | None = None) -> list[str]:
def classes(self, task_name: str | None = None) -> bidict[str, int]:
"""Gets the class names for the specified task.
@type task_name: str | None
Expand All @@ -113,14 +114,15 @@ def classes(self, task_name: str | None = None) -> list[str]:
raise ValueError(
f"Task type {task_name} is not present in the dataset."
)
return self._classes[task_name]
class_names = list(self._classes.values())[0]
for classes in self._classes.values():
if classes != class_names:
return bidict(self._classes[task_name])
classes = next(iter(self._classes.values()))
for c in self._classes.values():
if c != classes:
raise RuntimeError(
"The dataset contains different class names for different tasks."
"The dataset contains different class "
"definitions for different tasks."
)
return class_names
return bidict(classes)

@classmethod
def from_loader(cls, loader: BaseLoaderTorch) -> "DatasetMetadata":
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/multi_input_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def __len__(self):
return 10

@override
def get_classes(self) -> dict[str, list[str]]:
return {"": ["square"]}
def get_classes(self) -> dict[str, dict[str, int]]:
return {"": {"square": 0}}


class MultiInputTestBaseNode(BaseNode):
Expand Down
47 changes: 39 additions & 8 deletions tests/integration/parking_lot.json
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,29 @@
"metadata": {
"postprocessor_path": null,
"classes": [
"background", "alfa-romeo", "buick", "ducati", "harley",
"ferrari", "infiniti", "jeep", "land-rover", "roll-royce",
"yamaha", "aprilia", "bmw", "dodge", "honda", "moto",
"piaggio", "isuzu", "Kawasaki", "truimph", "pontiac",
"saab", "chrysler"
"background",
"alfa-romeo",
"aprilia",
"bmw",
"buick",
"chrysler",
"dodge",
"ducati",
"ferrari",
"harley",
"honda",
"infiniti",
"isuzu",
"jeep",
"Kawasaki",
"land-rover",
"moto",
"piaggio",
"pontiac",
"roll-royce",
"saab",
"truimph",
"yamaha"
],
"n_classes": 23,
"is_softmax": false
Expand All @@ -113,7 +131,11 @@
"parser": "SegmentationParser",
"metadata": {
"postprocessor_path": null,
"classes": ["motorbike", "car", "background"],
"classes": [
"background",
"car",
"motorbike"
],
"n_classes": 3,
"is_softmax": false
},
Expand All @@ -124,7 +146,11 @@
"parser": "YOLO",
"metadata": {
"postprocessor_path": null,
"classes": ["motorbike", "car", "background"],
"classes": [
"background",
"car",
"motorbike"
],
"n_classes": 3,
"iou_threshold": 0.45,
"conf_threshold": 0.25,
Expand Down Expand Up @@ -166,7 +192,12 @@
"parser": "SegmentationParser",
"metadata": {
"postprocessor_path": null,
"classes": ["background", "blue", "green", "red"],
"classes": [
"background",
"blue",
"green",
"red"
],
"n_classes": 4,
"is_softmax": false
},
Expand Down
14 changes: 7 additions & 7 deletions tests/unittests/test_utils/test_dataset_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
def metadata() -> DatasetMetadata:
return DatasetMetadata(
classes={
"color-segmentation": ["car", "person"],
"detection": ["car", "person"],
"color-segmentation": {"car": 0, "person": 1},
"detection": {"car": 0, "person": 1},
},
n_keypoints={"color-segmentation": 0, "detection": 0},
)
Expand All @@ -20,7 +20,7 @@ def test_n_classes(metadata: DatasetMetadata):
assert metadata.n_classes() == 2
with pytest.raises(ValueError):
metadata.n_classes("segmentation")
metadata._classes["segmentation"] = ["car", "person", "tree"]
metadata._classes["segmentation"] = {"car": 0, "person": 1, "tree": 2}
with pytest.raises(RuntimeError):
metadata.n_classes()

Expand All @@ -37,11 +37,11 @@ def test_n_keypoints(metadata: DatasetMetadata):


def test_class_names(metadata: DatasetMetadata):
assert metadata.classes("color-segmentation") == ["car", "person"]
assert metadata.classes("detection") == ["car", "person"]
assert metadata.classes() == ["car", "person"]
assert metadata.classes("color-segmentation") == {"car": 0, "person": 1}
assert metadata.classes("detection") == {"car": 0, "person": 1}
assert metadata.classes() == {"car": 0, "person": 1}
with pytest.raises(ValueError):
metadata.classes("segmentation")
metadata._classes["segmentation"] = ["car", "person", "tree"]
metadata._classes["segmentation"] = {"car": 0, "person": 1, "tree": 2}
with pytest.raises(RuntimeError):
metadata.classes()

0 comments on commit 6808e6e

Please sign in to comment.