Skip to content

Commit

Permalink
Add support for multi-label visualization (#110)
Browse files Browse the repository at this point in the history
  • Loading branch information
JSabadin authored Oct 18, 2024
1 parent 72217f1 commit af41af5
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 11 deletions.
2 changes: 2 additions & 0 deletions luxonis_train/attached_modules/metrics/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ Metrics from the [`torchmetrics`](https://lightning.ai/docs/torchmetrics/stable/
- [Precision](https://lightning.ai/docs/torchmetrics/stable/classification/precision.html)
- [Recall](https://lightning.ai/docs/torchmetrics/stable/classification/recall.html)

> **Note:** For multi-label classification, ensure that you specify the `params.task` as `multilabel` when using these metrics.
## ObjectKeypointSimilarity

For more information, see [object-keypoint-similarity](https://learnopencv.com/object-keypoint-similarity/).
Expand Down
13 changes: 7 additions & 6 deletions luxonis_train/attached_modules/visualizers/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,13 @@ Visualizer for bounding boxes.

**Parameters:**

| Key | Type | Default value | Description |
| -------------- | ---------------------- | ------------- | ------------------------------------------------------------------------- |
| `include_plot` | `bool` | `True` | Whether to include a plot of the class probabilities in the visualization |
| `color` | `tuple[int, int, int]` | `(255, 0, 0)` | Color of the text |
| `font_scale` | `float` | `1.0` | Scale of the font |
| `thickness` | `int` | `1` | Line thickness of the font |
| Key | Type | Default value | Description |
| -------------- | ---------------------- | ------------- | -------------------------------------------------------------------------------- |
| `include_plot` | `bool` | `True` | Whether to include a plot of the class probabilities in the visualization |
| `color` | `tuple[int, int, int]` | `(255, 0, 0)` | Color of the text |
| `font_scale` | `float` | `1.0` | Scale of the font |
| `thickness` | `int` | `1` | Line thickness of the font |
| `multi_label` | `bool` | `False` | Set to `True` for multi-label classification, otherwise `False` for single-label |

**Example:**

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(
font_scale: float = 1.0,
color: tuple[int, int, int] = (255, 0, 0),
thickness: int = 1,
multilabel: bool = False,
**kwargs,
):
"""Visualizer for classification tasks.
Expand All @@ -33,17 +34,28 @@ def __init__(
self.font_scale = font_scale
self.color = color
self.thickness = thickness
self.multilabel = multilabel

def _get_class_name(self, pred: Tensor) -> str:
idx = int((pred.argmax()).item())
if self.class_names is None:
return str(idx)
return self.class_names[idx]
"""Handles both single-label and multi-label classification."""
if self.multilabel:
idxs = (pred > 0.5).nonzero(as_tuple=True)[0].tolist()
if self.class_names is None:
return ", ".join([str(idx) for idx in idxs])
return ", ".join([self.class_names[idx] for idx in idxs])
else:
idx = int((pred.argmax()).item())
if self.class_names is None:
return str(idx)
return self.class_names[idx]

def _generate_plot(
self, prediction: Tensor, width: int, height: int
) -> Tensor:
pred = prediction.softmax(-1).detach().cpu().numpy()
if self.multilabel:
pred = prediction.sigmoid().detach().cpu().numpy()
else:
pred = prediction.softmax(-1).detach().cpu().numpy()
fig, ax = plt.subplots(figsize=(width / 100, height / 100))
ax.bar(np.arange(len(pred)), pred)
ax.set_xticks(np.arange(len(pred)))
Expand Down

0 comments on commit af41af5

Please sign in to comment.