Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/confusion matrix metric #138

Merged
merged 21 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions luxonis_train/attached_modules/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .base_metric import BaseMetric
from .confusion_matrix import ConfusionMatrix
from .mean_average_precision import MeanAveragePrecision
from .mean_average_precision_keypoints import MeanAveragePrecisionKeypoints
from .object_keypoint_similarity import ObjectKeypointSimilarity
Expand All @@ -14,4 +15,5 @@
"ObjectKeypointSimilarity",
"Precision",
"Recall",
"ConfusionMatrix",
]
275 changes: 275 additions & 0 deletions luxonis_train/attached_modules/metrics/confusion_matrix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
from typing import Literal

import torch
from torch import Tensor
from torchmetrics.classification import MulticlassConfusionMatrix
from torchvision.ops import box_convert, box_iou

from luxonis_train.enums import TaskType
from luxonis_train.utils import Labels, Packet

from .base_metric import BaseMetric


class ConfusionMatrix(BaseMetric[Tensor, Tensor]):
klemen1999 marked this conversation as resolved.
Show resolved Hide resolved
def __init__(
self,
box_format: Literal["xyxy", "xywh", "cxcywh"] = "xyxy",
iou_threshold: float = 0.45,
confidence_threshold: float = 0.25,
**kwargs,
):
"""Compute the confusion matrix for classification,
segmentation, and object detection tasks.

@type box_format: Literal["xyxy", "xywh", "cxcywh"]
@param box_format: The format of the bounding boxes. Can be one
of "xyxy", "xywh", or "cxcywh".
@type iou_threshold: float
@param iou_threshold: The IoU threshold for matching predictions
to ground truth.
@type confidence_threshold: float
@param confidence_threshold: The confidence threshold for
filtering predictions.
"""
super().__init__(**kwargs)
allowed_box_formats = ("xyxy", "xywh", "cxcywh")
if box_format not in allowed_box_formats:
raise ValueError(

Check warning on line 38 in luxonis_train/attached_modules/metrics/confusion_matrix.py

View check run for this annotation

Codecov / codecov/patch

luxonis_train/attached_modules/metrics/confusion_matrix.py#L38

Added line #L38 was not covered by tests
f"Expected argument `box_format` to be one of {allowed_box_formats} but got {box_format}"
)
self.box_format = box_format
self.iou_threshold = iou_threshold
self.confidence_threshold = confidence_threshold

self.is_classification = (
self.node.tasks is not None
and TaskType.CLASSIFICATION in self.node.tasks
)
self.is_detection = (
self.node.tasks is not None
and TaskType.BOUNDINGBOX in self.node.tasks
)
self.is_segmentation = (
self.node.tasks is not None
and TaskType.SEGMENTATION in self.node.tasks
)

if (
sum(
[
self.is_classification,
self.is_detection,
self.is_segmentation,
]
)
> 1
):
raise ValueError(

Check warning on line 68 in luxonis_train/attached_modules/metrics/confusion_matrix.py

View check run for this annotation

Codecov / codecov/patch

luxonis_train/attached_modules/metrics/confusion_matrix.py#L68

Added line #L68 was not covered by tests
"Multiple tasks detected in self.node.tasks. Only one task is allowed."
)

self.metric_cm = None
if self.is_classification or self.is_segmentation:
self.metric_cm = MulticlassConfusionMatrix(

Check warning on line 74 in luxonis_train/attached_modules/metrics/confusion_matrix.py

View check run for this annotation

Codecov / codecov/patch

luxonis_train/attached_modules/metrics/confusion_matrix.py#L74

Added line #L74 was not covered by tests
num_classes=self.n_classes
)

if self.is_detection:
self.add_state(
"detection_cm",
default=torch.zeros(
self.n_classes + 1, self.n_classes + 1, dtype=torch.int64
), # +1 for background
JSabadin marked this conversation as resolved.
Show resolved Hide resolved
dist_reduce_fx="sum",
)

def prepare(
self, inputs: Packet[Tensor], labels: Labels
) -> tuple[dict[str, Tensor], dict[str, Tensor]]:
"""Prepare data for classification, segmentation, and detection
tasks.

@type inputs: Packet[Tensor]
@param inputs: The inputs to the model.
@type labels: Labels
@param labels: The ground-truth labels.
@return: A tuple of two dictionaries: one for predictions and
one for targets.
"""
predictions = {}
targets = {}

Check warning on line 101 in luxonis_train/attached_modules/metrics/confusion_matrix.py

View check run for this annotation

Codecov / codecov/patch

luxonis_train/attached_modules/metrics/confusion_matrix.py#L100-L101

Added lines #L100 - L101 were not covered by tests

if self.is_detection:
out_bbox = self.get_input_tensors(inputs, TaskType.BOUNDINGBOX)
bbox = self.get_label(labels, TaskType.BOUNDINGBOX)
bbox = bbox.to(out_bbox[0].device).clone()
bbox[..., 2:6] = box_convert(bbox[..., 2:6], "xywh", "xyxy")
scale_factors = torch.tensor(

Check warning on line 108 in luxonis_train/attached_modules/metrics/confusion_matrix.py

View check run for this annotation

Codecov / codecov/patch

luxonis_train/attached_modules/metrics/confusion_matrix.py#L103-L108

Added lines #L103 - L108 were not covered by tests
[
self.original_in_shape[2],
self.original_in_shape[1],
self.original_in_shape[2],
self.original_in_shape[1],
],
device=bbox.device,
)
bbox[..., 2:6] *= scale_factors
predictions["detection"] = out_bbox
targets["detection"] = bbox

Check warning on line 119 in luxonis_train/attached_modules/metrics/confusion_matrix.py

View check run for this annotation

Codecov / codecov/patch

luxonis_train/attached_modules/metrics/confusion_matrix.py#L117-L119

Added lines #L117 - L119 were not covered by tests

if self.is_classification:
prediction = self.get_input_tensors(

Check warning on line 122 in luxonis_train/attached_modules/metrics/confusion_matrix.py

View check run for this annotation

Codecov / codecov/patch

luxonis_train/attached_modules/metrics/confusion_matrix.py#L121-L122

Added lines #L121 - L122 were not covered by tests
inputs, TaskType.CLASSIFICATION
)
target = self.get_label(labels, TaskType.CLASSIFICATION).to(

Check warning on line 125 in luxonis_train/attached_modules/metrics/confusion_matrix.py

View check run for this annotation

Codecov / codecov/patch

luxonis_train/attached_modules/metrics/confusion_matrix.py#L125

Added line #L125 was not covered by tests
prediction[0].device
)
predictions["classification"] = prediction
targets["classification"] = target

Check warning on line 129 in luxonis_train/attached_modules/metrics/confusion_matrix.py

View check run for this annotation

Codecov / codecov/patch

luxonis_train/attached_modules/metrics/confusion_matrix.py#L128-L129

Added lines #L128 - L129 were not covered by tests

if self.is_segmentation:
prediction = self.get_input_tensors(inputs, TaskType.SEGMENTATION)
target = self.get_label(labels, TaskType.SEGMENTATION).to(

Check warning on line 133 in luxonis_train/attached_modules/metrics/confusion_matrix.py

View check run for this annotation

Codecov / codecov/patch

luxonis_train/attached_modules/metrics/confusion_matrix.py#L131-L133

Added lines #L131 - L133 were not covered by tests
prediction[0].device
)
predictions["segmentation"] = prediction
targets["segmentation"] = target

Check warning on line 137 in luxonis_train/attached_modules/metrics/confusion_matrix.py

View check run for this annotation

Codecov / codecov/patch

luxonis_train/attached_modules/metrics/confusion_matrix.py#L136-L137

Added lines #L136 - L137 were not covered by tests

return predictions, targets

Check warning on line 139 in luxonis_train/attached_modules/metrics/confusion_matrix.py

View check run for this annotation

Codecov / codecov/patch

luxonis_train/attached_modules/metrics/confusion_matrix.py#L139

Added line #L139 was not covered by tests

def update(
self, predictions: dict[str, Tensor], targets: dict[str, Tensor]
) -> None:
"""Update the confusion matrices for all tasks using prepared
data.

@type predictions: dict[str, Tensor]
@param predictions: A dictionary containing predictions for all
tasks.
@type targets: dict[str, Tensor]
@param targets: A dictionary containing targets for all tasks.
"""
if "classification" in predictions and "classification" in targets:
preds = predictions["classification"]
target = targets["classification"]
pred_classes = preds[0].argmax(dim=1) # [B]
target_classes = target.argmax(dim=1) # [B]
if self.metric_cm is not None:
self.metric_cm.update(pred_classes, target_classes)

Check warning on line 159 in luxonis_train/attached_modules/metrics/confusion_matrix.py

View check run for this annotation

Codecov / codecov/patch

luxonis_train/attached_modules/metrics/confusion_matrix.py#L153-L159

Added lines #L153 - L159 were not covered by tests

if "segmentation" in predictions and "segmentation" in targets:
preds = predictions["segmentation"]
target = targets["segmentation"]
pred_masks = preds[0].argmax(dim=1) # [B, H, W]
target_masks = target.argmax(dim=1) # [B, H, W]
if self.metric_cm is not None:
self.metric_cm.update(

Check warning on line 167 in luxonis_train/attached_modules/metrics/confusion_matrix.py

View check run for this annotation

Codecov / codecov/patch

luxonis_train/attached_modules/metrics/confusion_matrix.py#L161-L167

Added lines #L161 - L167 were not covered by tests
pred_masks.view(-1), target_masks.view(-1)
)

if "detection" in predictions and "detection" in targets:
preds = predictions["detection"] # type: ignore
target = targets["detection"]
self.detection_cm += self._compute_detection_confusion_matrix( # type: ignore

Check warning on line 174 in luxonis_train/attached_modules/metrics/confusion_matrix.py

View check run for this annotation

Codecov / codecov/patch

luxonis_train/attached_modules/metrics/confusion_matrix.py#L171-L174

Added lines #L171 - L174 were not covered by tests
preds, # type: ignore
target, # type: ignore
)

def compute(self) -> dict[str, Tensor]:
"""Compute confusion matrices for classification, segmentation,
and detection tasks."""
results = {}
if self.metric_cm:
task_type = (

Check warning on line 184 in luxonis_train/attached_modules/metrics/confusion_matrix.py

View check run for this annotation

Codecov / codecov/patch

luxonis_train/attached_modules/metrics/confusion_matrix.py#L182-L184

Added lines #L182 - L184 were not covered by tests
"classification" if self.is_classification else "segmentation"
)
results[f"{task_type}_confusion_matrix"] = self.metric_cm.compute()
self.metric_cm.reset()
if self.is_detection:
results["detection_confusion_matrix"] = self.detection_cm

Check warning on line 190 in luxonis_train/attached_modules/metrics/confusion_matrix.py

View check run for this annotation

Codecov / codecov/patch

luxonis_train/attached_modules/metrics/confusion_matrix.py#L187-L190

Added lines #L187 - L190 were not covered by tests

klemen1999 marked this conversation as resolved.
Show resolved Hide resolved
return results

Check warning on line 192 in luxonis_train/attached_modules/metrics/confusion_matrix.py

View check run for this annotation

Codecov / codecov/patch

luxonis_train/attached_modules/metrics/confusion_matrix.py#L192

Added line #L192 was not covered by tests

def _compute_detection_confusion_matrix(
self, preds: list[Tensor], targets: Tensor
) -> Tensor:
"""Compute a confusion matrix for object detection tasks.

@type preds: list[Tensor]
@param preds: List of predictions for each image. Each tensor
has shape [N, 6] where 6 is for [x1, y1, x2, y2, score,
class]
@type targets: Tensor
@param targets: Ground truth boxes and classes. Shape [M, 6]
where first column is image index.
"""
cm = torch.zeros(
self.n_classes + 1,
self.n_classes + 1,
dtype=torch.int64,
device=preds[0].device,
)

for img_idx, pred in enumerate(preds):
img_targets = targets[targets[:, 0] == img_idx]

if img_targets.shape[0] == 0:
for pred_class in pred[:, 5].int():
cm[pred_class, self.n_classes] += 1
continue

Check warning on line 220 in luxonis_train/attached_modules/metrics/confusion_matrix.py

View check run for this annotation

Codecov / codecov/patch

luxonis_train/attached_modules/metrics/confusion_matrix.py#L218-L220

Added lines #L218 - L220 were not covered by tests

if pred.shape[0] == 0:
for gt_class in img_targets[:, 1].int():
cm[self.n_classes, gt_class] += 1
continue

pred = pred[pred[:, 4] > self.confidence_threshold]
pred_boxes = pred[:, :4]
pred_classes = pred[:, 5].int()

gt_boxes = img_targets[:, 2:]
gt_classes = img_targets[:, 1].int()

iou = box_iou(gt_boxes, pred_boxes)
iou_thresholded = iou > self.iou_threshold

if iou_thresholded.any():
iou_max, pred_max_idx = torch.max(iou, dim=1)
iou_gt_mask = iou_max > self.iou_threshold
gt_match_idx = torch.arange(
len(gt_boxes), device=gt_boxes.device
)[iou_gt_mask]
pred_match_idx = pred_max_idx[iou_gt_mask]

for gt_idx, pred_idx in zip(gt_match_idx, pred_match_idx):
gt_class = gt_classes[gt_idx]
pred_class = pred_classes[pred_idx]
cm[pred_class, gt_class] += 1

unmatched_gt_mask = ~torch.isin(
torch.arange(len(gt_boxes), device=gt_boxes.device),
gt_match_idx,
)
for gt_idx in torch.arange(
len(gt_boxes), device=gt_boxes.device
)[unmatched_gt_mask]:
gt_class = gt_classes[gt_idx]
cm[self.n_classes, gt_class] += 1

Check warning on line 258 in luxonis_train/attached_modules/metrics/confusion_matrix.py

View check run for this annotation

Codecov / codecov/patch

luxonis_train/attached_modules/metrics/confusion_matrix.py#L257-L258

Added lines #L257 - L258 were not covered by tests

unmatched_pred_mask = ~torch.isin(
torch.arange(len(pred_boxes), device=gt_boxes.device),
pred_match_idx,
)
for pred_idx in torch.arange(
len(pred_boxes), device=gt_boxes.device
)[unmatched_pred_mask]:
pred_class = pred_classes[pred_idx]
cm[pred_class, self.n_classes] += 1
else:
for gt_class in gt_classes:
cm[self.n_classes, gt_class] += 1
for pred_class in pred_classes:
cm[pred_class, self.n_classes] += 1

Check warning on line 273 in luxonis_train/attached_modules/metrics/confusion_matrix.py

View check run for this annotation

Codecov / codecov/patch

luxonis_train/attached_modules/metrics/confusion_matrix.py#L270-L273

Added lines #L270 - L273 were not covered by tests

return cm
17 changes: 13 additions & 4 deletions luxonis_train/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,15 +151,24 @@
def check_main_metric(self) -> Self:
for metric in self.metrics:
if metric.is_main_metric:
if "matrix" in metric.name.lower():
raise ValueError(

Check warning on line 155 in luxonis_train/config/config.py

View check run for this annotation

Codecov / codecov/patch

luxonis_train/config/config.py#L155

Added line #L155 was not covered by tests
f"Main metric cannot contain 'matrix' in its name: `{metric.name}`"
)
logger.info(f"Main metric: `{metric.name}`")
return self

logger.warning("No main metric specified.")
if self.metrics:
metric = self.metrics[0]
metric.is_main_metric = True
name = metric.alias or metric.name
logger.info(f"Setting '{name}' as main metric.")
for metric in self.metrics:
if "matrix" not in metric.name.lower():
metric.is_main_metric = True
name = metric.alias or metric.name
logger.info(f"Setting '{name}' as main metric.")
return self
raise ValueError(

Check warning on line 169 in luxonis_train/config/config.py

View check run for this annotation

Codecov / codecov/patch

luxonis_train/config/config.py#L169

Added line #L169 was not covered by tests
"[Configuration Error] No valid main metric can be set as all metrics contain 'matrix' in their names."
)
else:
logger.warning(
"[Ignore if using predefined model] "
Expand Down
23 changes: 15 additions & 8 deletions luxonis_train/models/luxonis_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,14 +810,21 @@ def _evaluation_epoch_end(self, mode: Literal["test", "val"]) -> None:
logger.info("Metrics computed.")
for node_name, metrics in computed_metrics.items():
for metric_name, metric_value in metrics.items():
metric_results[node_name][metric_name] = (
metric_value.cpu().item()
)
self.log(
f"{mode}/metric/{node_name}/{metric_name}",
metric_value,
sync_dist=True,
)
if "matrix" in metric_name.lower():
self.logger.log_matrix(
matrix=metric_value.cpu().numpy(),
name=f"{mode}/metrics/{self.current_epoch}/{metric_name}",
step=self.current_epoch,
)
else:
metric_results[node_name][metric_name] = (
metric_value.cpu().item()
)
self.log(
f"{mode}/metric/{node_name}/{metric_name}",
metric_value,
sync_dist=True,
)

if self.cfg.trainer.verbose:
self._print_results(
Expand Down
2 changes: 1 addition & 1 deletion requirements-config.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
luxonis-ml[data,utils]@git+https://github.com/luxonis/luxonis-ml.git@dev
luxonis-ml[data,utils]@git+https://github.com/luxonis/luxonis-ml.git@main
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
blobconverter>=1.4.2
lightning>=2.4.0
luxonis-ml[data,tracker]>=0.5.0
luxonis-ml[data,tracker]>=0.5.1
onnx>=1.12.0
onnxruntime>=1.13.1
onnxsim>=0.4.10
Expand Down
Loading
Loading