Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/confusion matrix metric #138

Merged
merged 21 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions luxonis_train/attached_modules/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .base_metric import BaseMetric
from .confusion_matrix import ConfusionMatrix
from .mean_average_precision import MeanAveragePrecision
from .mean_average_precision_keypoints import MeanAveragePrecisionKeypoints
from .object_keypoint_similarity import ObjectKeypointSimilarity
Expand All @@ -14,4 +15,5 @@
"ObjectKeypointSimilarity",
"Precision",
"Recall",
"ConfusionMatrix",
]
275 changes: 275 additions & 0 deletions luxonis_train/attached_modules/metrics/confusion_matrix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
from typing import Literal

import torch
from torch import Tensor
from torchmetrics.classification import MulticlassConfusionMatrix
from torchvision.ops import box_convert, box_iou

from luxonis_train.enums import TaskType
from luxonis_train.utils import Labels, Packet

from .base_metric import BaseMetric


class ConfusionMatrix(BaseMetric[Tensor, Tensor]):
klemen1999 marked this conversation as resolved.
Show resolved Hide resolved
def __init__(
self,
box_format: Literal["xyxy", "xywh", "cxcywh"] = "xyxy",
iou_threshold: float = 0.45,
confidence_threshold: float = 0.25,
**kwargs,
):
"""Compute the confusion matrix for classification,
segmentation, and object detection tasks.

@type box_format: Literal["xyxy", "xywh", "cxcywh"]
@param box_format: The format of the bounding boxes. Can be one
of "xyxy", "xywh", or "cxcywh".
@type iou_threshold: float
@param iou_threshold: The IoU threshold for matching predictions
to ground truth.
@type confidence_threshold: float
@param confidence_threshold: The confidence threshold for
filtering predictions.
"""
super().__init__(**kwargs)
allowed_box_formats = ("xyxy", "xywh", "cxcywh")
if box_format not in allowed_box_formats:
raise ValueError(
f"Expected argument `box_format` to be one of {allowed_box_formats} but got {box_format}"
)
self.box_format = box_format
self.iou_threshold = iou_threshold
self.confidence_threshold = confidence_threshold

self.is_classification = (
self.node.tasks is not None
and TaskType.CLASSIFICATION in self.node.tasks
)
self.is_detection = (
self.node.tasks is not None
and TaskType.BOUNDINGBOX in self.node.tasks
)
self.is_segmentation = (
self.node.tasks is not None
and TaskType.SEGMENTATION in self.node.tasks
)

if (
sum(
[
self.is_classification,
self.is_detection,
self.is_segmentation,
]
)
> 1
):
raise ValueError(
"Multiple tasks detected in self.node.tasks. Only one task is allowed."
)

self.metric_cm = None
if self.is_classification or self.is_segmentation:
self.metric_cm = MulticlassConfusionMatrix(
num_classes=self.n_classes
)

if self.is_detection:
self.add_state(
"detection_cm",
default=torch.zeros(
self.n_classes + 1, self.n_classes + 1, dtype=torch.int64
), # +1 for background
JSabadin marked this conversation as resolved.
Show resolved Hide resolved
dist_reduce_fx="sum",
)

def prepare(
self, inputs: Packet[Tensor], labels: Labels
) -> tuple[dict[str, Tensor], dict[str, Tensor]]:
"""Prepare data for classification, segmentation, and detection
tasks.

@type inputs: Packet[Tensor]
@param inputs: The inputs to the model.
@type labels: Labels
@param labels: The ground-truth labels.
@return: A tuple of two dictionaries: one for predictions and
one for targets.
"""
predictions = {}
targets = {}

if self.is_detection:
out_bbox = self.get_input_tensors(inputs, TaskType.BOUNDINGBOX)
bbox = self.get_label(labels, TaskType.BOUNDINGBOX)
bbox = bbox.to(out_bbox[0].device).clone()
bbox[..., 2:6] = box_convert(bbox[..., 2:6], "xywh", "xyxy")
scale_factors = torch.tensor(
[
self.original_in_shape[2],
self.original_in_shape[1],
self.original_in_shape[2],
self.original_in_shape[1],
],
device=bbox.device,
)
bbox[..., 2:6] *= scale_factors
predictions["detection"] = out_bbox
targets["detection"] = bbox

if self.is_classification:
prediction = self.get_input_tensors(
inputs, TaskType.CLASSIFICATION
)
target = self.get_label(labels, TaskType.CLASSIFICATION).to(
prediction[0].device
)
predictions["classification"] = prediction
targets["classification"] = target

if self.is_segmentation:
prediction = self.get_input_tensors(inputs, TaskType.SEGMENTATION)
target = self.get_label(labels, TaskType.SEGMENTATION).to(
prediction[0].device
)
predictions["segmentation"] = prediction
targets["segmentation"] = target

return predictions, targets

def update(
self, predictions: dict[str, Tensor], targets: dict[str, Tensor]
) -> None:
"""Update the confusion matrices for all tasks using prepared
data.

@type predictions: dict[str, Tensor]
@param predictions: A dictionary containing predictions for all
tasks.
@type targets: dict[str, Tensor]
@param targets: A dictionary containing targets for all tasks.
"""
if "classification" in predictions and "classification" in targets:
preds = predictions["classification"]
target = targets["classification"]
pred_classes = preds[0].argmax(dim=1) # [B]
target_classes = target.argmax(dim=1) # [B]
if self.metric_cm is not None:
self.metric_cm.update(pred_classes, target_classes)

if "segmentation" in predictions and "segmentation" in targets:
preds = predictions["segmentation"]
target = targets["segmentation"]
pred_masks = preds[0].argmax(dim=1) # [B, H, W]
target_masks = target.argmax(dim=1) # [B, H, W]
if self.metric_cm is not None:
self.metric_cm.update(
pred_masks.view(-1), target_masks.view(-1)
)

if "detection" in predictions and "detection" in targets:
preds = predictions["detection"] # type: ignore
target = targets["detection"]
self.detection_cm += self._compute_detection_confusion_matrix( # type: ignore
preds, # type: ignore
target, # type: ignore
)

def compute(self) -> dict[str, Tensor]:
"""Compute confusion matrices for classification, segmentation,
and detection tasks."""
results = {}
if self.metric_cm:
task_type = (
"classification" if self.is_classification else "segmentation"
)
results[f"{task_type}_confusion_matrix"] = self.metric_cm.compute()
self.metric_cm.reset()
if self.is_detection:
results["detection_confusion_matrix"] = self.detection_cm

klemen1999 marked this conversation as resolved.
Show resolved Hide resolved
return results

def _compute_detection_confusion_matrix(
self, preds: list[Tensor], targets: Tensor
) -> Tensor:
"""Compute a confusion matrix for object detection tasks.

@type preds: list[Tensor]
@param preds: List of predictions for each image. Each tensor
has shape [N, 6] where 6 is for [x1, y1, x2, y2, score,
class]
@type targets: Tensor
@param targets: Ground truth boxes and classes. Shape [M, 6]
where first column is image index.
"""
cm = torch.zeros(
self.n_classes + 1,
self.n_classes + 1,
dtype=torch.int64,
device=preds[0].device,
)

for img_idx, pred in enumerate(preds):
img_targets = targets[targets[:, 0] == img_idx]

if img_targets.shape[0] == 0:
for pred_class in pred[:, 5].int():
cm[pred_class, self.n_classes] += 1
continue

if pred.shape[0] == 0:
for gt_class in img_targets[:, 1].int():
cm[self.n_classes, gt_class] += 1
continue

pred = pred[pred[:, 4] > self.confidence_threshold]
pred_boxes = pred[:, :4]
pred_classes = pred[:, 5].int()

gt_boxes = img_targets[:, 2:]
gt_classes = img_targets[:, 1].int()

iou = box_iou(gt_boxes, pred_boxes)
iou_thresholded = iou > self.iou_threshold

if iou_thresholded.any():
iou_max, pred_max_idx = torch.max(iou, dim=1)
iou_gt_mask = iou_max > self.iou_threshold
gt_match_idx = torch.arange(
len(gt_boxes), device=gt_boxes.device
)[iou_gt_mask]
pred_match_idx = pred_max_idx[iou_gt_mask]

for gt_idx, pred_idx in zip(gt_match_idx, pred_match_idx):
gt_class = gt_classes[gt_idx]
pred_class = pred_classes[pred_idx]
cm[pred_class, gt_class] += 1

unmatched_gt_mask = ~torch.isin(
torch.arange(len(gt_boxes), device=gt_boxes.device),
gt_match_idx,
)
for gt_idx in torch.arange(
len(gt_boxes), device=gt_boxes.device
)[unmatched_gt_mask]:
gt_class = gt_classes[gt_idx]
cm[self.n_classes, gt_class] += 1

unmatched_pred_mask = ~torch.isin(
torch.arange(len(pred_boxes), device=gt_boxes.device),
pred_match_idx,
)
for pred_idx in torch.arange(
len(pred_boxes), device=gt_boxes.device
)[unmatched_pred_mask]:
pred_class = pred_classes[pred_idx]
cm[pred_class, self.n_classes] += 1
else:
for gt_class in gt_classes:
cm[self.n_classes, gt_class] += 1
for pred_class in pred_classes:
cm[pred_class, self.n_classes] += 1

return cm
17 changes: 13 additions & 4 deletions luxonis_train/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,15 +151,24 @@ def check_predefined_model(self) -> Self:
def check_main_metric(self) -> Self:
for metric in self.metrics:
if metric.is_main_metric:
if "matrix" in metric.name.lower():
raise ValueError(
f"Main metric cannot contain 'matrix' in its name: `{metric.name}`"
)
logger.info(f"Main metric: `{metric.name}`")
return self

logger.warning("No main metric specified.")
if self.metrics:
metric = self.metrics[0]
metric.is_main_metric = True
name = metric.alias or metric.name
logger.info(f"Setting '{name}' as main metric.")
for metric in self.metrics:
if "matrix" not in metric.name.lower():
metric.is_main_metric = True
name = metric.alias or metric.name
logger.info(f"Setting '{name}' as main metric.")
return self
raise ValueError(
"[Configuration Error] No valid main metric can be set as all metrics contain 'matrix' in their names."
)
else:
logger.warning(
"[Ignore if using predefined model] "
Expand Down
23 changes: 15 additions & 8 deletions luxonis_train/models/luxonis_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,14 +810,21 @@
logger.info("Metrics computed.")
for node_name, metrics in computed_metrics.items():
for metric_name, metric_value in metrics.items():
metric_results[node_name][metric_name] = (
metric_value.cpu().item()
)
self.log(
f"{mode}/metric/{node_name}/{metric_name}",
metric_value,
sync_dist=True,
)
if "matrix" in metric_name.lower():
self.logger.log_matrix(

Check failure on line 814 in luxonis_train/models/luxonis_lightning.py

View workflow job for this annotation

GitHub Actions / type-check

Cannot access attribute "log_matrix" for class "LuxonisTrackerPL"   Attribute "log_matrix" is unknown (reportAttributeAccessIssue)
matrix=metric_value.cpu().numpy(),
name=f"{mode}/metrics/{self.current_epoch}/{metric_name}",
step=self.current_epoch,
)
else:
metric_results[node_name][metric_name] = (
metric_value.cpu().item()
)
self.log(
f"{mode}/metric/{node_name}/{metric_name}",
metric_value,
sync_dist=True,
)

if self.cfg.trainer.verbose:
self._print_results(
Expand Down
Loading
Loading