From af41af510037ffbd88d4bae03469aea3ad504900 Mon Sep 17 00:00:00 2001 From: Jernej Sabadin <116955183+JSabadin@users.noreply.github.com> Date: Fri, 18 Oct 2024 15:32:31 +0200 Subject: [PATCH] Add support for multi-label visualization (#110) --- .../attached_modules/metrics/README.md | 2 ++ .../attached_modules/visualizers/README.md | 13 ++++++----- .../visualizers/classification_visualizer.py | 22 ++++++++++++++----- 3 files changed, 26 insertions(+), 11 deletions(-) diff --git a/luxonis_train/attached_modules/metrics/README.md b/luxonis_train/attached_modules/metrics/README.md index b61f4843..42f42fcb 100644 --- a/luxonis_train/attached_modules/metrics/README.md +++ b/luxonis_train/attached_modules/metrics/README.md @@ -19,6 +19,8 @@ Metrics from the [`torchmetrics`](https://lightning.ai/docs/torchmetrics/stable/ - [Precision](https://lightning.ai/docs/torchmetrics/stable/classification/precision.html) - [Recall](https://lightning.ai/docs/torchmetrics/stable/classification/recall.html) +> **Note:** For multi-label classification, ensure that you specify the `params.task` as `multilabel` when using these metrics. + ## ObjectKeypointSimilarity For more information, see [object-keypoint-similarity](https://learnopencv.com/object-keypoint-similarity/). diff --git a/luxonis_train/attached_modules/visualizers/README.md b/luxonis_train/attached_modules/visualizers/README.md index 1fca42e2..03daa87f 100644 --- a/luxonis_train/attached_modules/visualizers/README.md +++ b/luxonis_train/attached_modules/visualizers/README.md @@ -60,12 +60,13 @@ Visualizer for bounding boxes. 
**Parameters:** -| Key | Type | Default value | Description | -| -------------- | ---------------------- | ------------- | ------------------------------------------------------------------------- | -| `include_plot` | `bool` | `True` | Whether to include a plot of the class probabilities in the visualization | -| `color` | `tuple[int, int, int]` | `(255, 0, 0)` | Color of the text | -| `font_scale` | `float` | `1.0` | Scale of the font | -| `thickness` | `int` | `1` | Line thickness of the font | +| Key | Type | Default value | Description | +| -------------- | ---------------------- | ------------- | -------------------------------------------------------------------------------- | +| `include_plot` | `bool` | `True` | Whether to include a plot of the class probabilities in the visualization | +| `color` | `tuple[int, int, int]` | `(255, 0, 0)` | Color of the text | +| `font_scale` | `float` | `1.0` | Scale of the font | +| `thickness` | `int` | `1` | Line thickness of the font | +| `multilabel` | `bool` | `False` | Set to `True` for multi-label classification, otherwise `False` for single-label | **Example:** diff --git a/luxonis_train/attached_modules/visualizers/classification_visualizer.py b/luxonis_train/attached_modules/visualizers/classification_visualizer.py index 3ba5ce8c..9a5b0d61 100644 --- a/luxonis_train/attached_modules/visualizers/classification_visualizer.py +++ b/luxonis_train/attached_modules/visualizers/classification_visualizer.py @@ -20,6 +20,7 @@ def __init__( font_scale: float = 1.0, color: tuple[int, int, int] = (255, 0, 0), thickness: int = 1, + multilabel: bool = False, **kwargs, ): """Visualizer for classification tasks. 
@@ -33,17 +34,28 @@ def __init__( self.font_scale = font_scale self.color = color self.thickness = thickness + self.multilabel = multilabel def _get_class_name(self, pred: Tensor) -> str: - idx = int((pred.argmax()).item()) - if self.class_names is None: - return str(idx) - return self.class_names[idx] + """Handles both single-label and multi-label classification.""" + if self.multilabel: + idxs = (pred > 0.5).nonzero(as_tuple=True)[0].tolist() + if self.class_names is None: + return ", ".join([str(idx) for idx in idxs]) + return ", ".join([self.class_names[idx] for idx in idxs]) + else: + idx = int((pred.argmax()).item()) + if self.class_names is None: + return str(idx) + return self.class_names[idx] def _generate_plot( self, prediction: Tensor, width: int, height: int ) -> Tensor: - pred = prediction.softmax(-1).detach().cpu().numpy() + if self.multilabel: + pred = prediction.sigmoid().detach().cpu().numpy() + else: + pred = prediction.softmax(-1).detach().cpu().numpy() fig, ax = plt.subplots(figsize=(width / 100, height / 100)) ax.bar(np.arange(len(pred)), pred) ax.set_xticks(np.arange(len(pred)))