diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 97189ad9..c9355abb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,4 +20,4 @@ repos: hooks: - id: mdformat additional_dependencies: - - mdformat-gfm==0.3.6 \ No newline at end of file + - mdformat-gfm==0.3.6 diff --git a/luxonis_train/attached_modules/losses/__init__.py b/luxonis_train/attached_modules/losses/__init__.py index ff0bafc8..2d0c77e1 100644 --- a/luxonis_train/attached_modules/losses/__init__.py +++ b/luxonis_train/attached_modules/losses/__init__.py @@ -7,6 +7,7 @@ from .ohem_bce_with_logits import OHEMBCEWithLogitsLoss from .ohem_cross_entropy import OHEMCrossEntropyLoss from .ohem_loss import OHEMLoss +from .pml_loss import EmbeddingLossWrapper from .reconstruction_segmentation_loss import ReconstructionSegmentationLoss from .sigmoid_focal_loss import SigmoidFocalLoss from .smooth_bce_with_logits import SmoothBCEWithLogitsLoss @@ -26,4 +27,5 @@ "OHEMCrossEntropyLoss", "OHEMBCEWithLogitsLoss", "FOMOLocalizationLoss", + "EmbeddingLossWrapper", ] diff --git a/luxonis_train/attached_modules/losses/pml_loss.py b/luxonis_train/attached_modules/losses/pml_loss.py new file mode 100644 index 00000000..959a5f68 --- /dev/null +++ b/luxonis_train/attached_modules/losses/pml_loss.py @@ -0,0 +1,119 @@ +import logging + +import pytorch_metric_learning.losses as pml_losses +from pytorch_metric_learning.losses import CrossBatchMemory +from torch import Tensor + +from .base_loss import BaseLoss + +logger = logging.getLogger(__name__) + +ALL_EMBEDDING_LOSSES = [ + "AngularLoss", + "ArcFaceLoss", + "CircleLoss", + "ContrastiveLoss", + "CosFaceLoss", + "DynamicSoftMarginLoss", + "FastAPLoss", + "HistogramLoss", + "InstanceLoss", + "IntraPairVarianceLoss", + "LargeMarginSoftmaxLoss", + "GeneralizedLiftedStructureLoss", + "LiftedStructureLoss", + "MarginLoss", + "MultiSimilarityLoss", + "NPairsLoss", + "NCALoss", + "NormalizedSoftmaxLoss", + "NTXentLoss", + "PNPLoss", + "ProxyAnchorLoss", + "ProxyNCALoss", + "RankedListLoss", + "SignalToNoiseRatioContrastiveLoss", + "SoftTripleLoss", + "SphereFaceLoss", + "SubCenterArcFaceLoss", + "SupConLoss", + "ThresholdConsistentMarginLoss", + "TripletMarginLoss", + "TupletMarginLoss", +] + +CLASS_EMBEDDING_LOSSES = [ + "ArcFaceLoss", + "CosFaceLoss", + "LargeMarginSoftmaxLoss", + "NormalizedSoftmaxLoss", + "ProxyAnchorLoss", + "ProxyNCALoss", + "SoftTripleLoss", + "SphereFaceLoss", + "SubCenterArcFaceLoss", +] + + +class EmbeddingLossWrapper(BaseLoss): + def __init__( + self, + loss_name: str, + embedding_size: int = 512, + cross_batch_memory_size=0, + num_classes: int = 0, + loss_kwargs: dict | None = None, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + if loss_kwargs is None: + loss_kwargs = {} + + try: + loss_cls = getattr(pml_losses, loss_name) + except AttributeError as e: + raise ValueError( + f"Loss {loss_name} not found in pytorch_metric_learning" + ) from e + + if loss_name in CLASS_EMBEDDING_LOSSES: + if num_classes < 0: + raise ValueError( + f"Loss {loss_name} requires num_classes to be set to a positive value" + ) + loss_kwargs["num_classes"] = num_classes + loss_kwargs["embedding_size"] = embedding_size + + # If we wanted to support these losses, we would need to add a separate optimizer for them. + # They may be useful in some scenarios, so leaving this here for future reference. + raise ValueError( + f"Loss {loss_name} requires its own optimizer, and that is not currently supported." 
+ ) + + self.loss_func = loss_cls(**loss_kwargs) + + if cross_batch_memory_size > 0: + if loss_name in CrossBatchMemory.supported_losses(): + self.loss_func = CrossBatchMemory( + self.loss_func, embedding_size=embedding_size + ) + else: + logger.warning( + f"Cross batch memory is not supported for {loss_name}. Ignoring cross_batch_memory_size." + ) + + def prepare( + self, inputs: dict[str, list[Tensor]], labels: dict[str, list[Tensor]] + ) -> tuple[Tensor, Tensor]: + embeddings = self.get_input_tensors(inputs, "features")[0] + + if labels is None or "id" not in labels: + raise ValueError("Labels must contain 'id' key") + + ids = labels["id"][0][:, 0] + return embeddings, ids + + def forward(self, inputs: Tensor, target: Tensor) -> Tensor: + loss = self.loss_func(inputs, target) + return loss diff --git a/luxonis_train/attached_modules/metrics/__init__.py b/luxonis_train/attached_modules/metrics/__init__.py index cdd0b3ac..59e9cc57 100644 --- a/luxonis_train/attached_modules/metrics/__init__.py +++ b/luxonis_train/attached_modules/metrics/__init__.py @@ -3,6 +3,7 @@ from .mean_average_precision import MeanAveragePrecision from .mean_average_precision_keypoints import MeanAveragePrecisionKeypoints from .object_keypoint_similarity import ObjectKeypointSimilarity +from .pml_metrics import ClosestIsPositiveAccuracy, MedianDistances from .torchmetrics import Accuracy, F1Score, JaccardIndex, Precision, Recall __all__ = [ @@ -15,5 +16,7 @@ "ObjectKeypointSimilarity", "Precision", "Recall", + "ClosestIsPositiveAccuracy", "ConfusionMatrix", + "MedianDistances", ] diff --git a/luxonis_train/attached_modules/metrics/pml_metrics.py b/luxonis_train/attached_modules/metrics/pml_metrics.py new file mode 100644 index 00000000..ad8b0d88 --- /dev/null +++ b/luxonis_train/attached_modules/metrics/pml_metrics.py @@ -0,0 +1,263 @@ +import torch +from torch import Tensor + +from .base_metric import BaseMetric + +# Converted from https://omoindrot.github.io/triplet-loss#offline-and-online-triplet-mining +# to PyTorch from TensorFlow + + +class ClosestIsPositiveAccuracy(BaseMetric): + def __init__(self, cross_batch_memory_size=0, **kwargs): + super().__init__(**kwargs) + self.cross_batch_memory_size = cross_batch_memory_size + self.add_state("cross_batch_memory", default=[], dist_reduce_fx="cat") + self.add_state( + "correct_predictions", + default=torch.tensor(0), + dist_reduce_fx="sum", + ) + self.add_state( + "total_predictions", default=torch.tensor(0), dist_reduce_fx="sum" + ) + + def prepare(self, inputs, labels): + embeddings = inputs["features"][0] + + assert ( + labels is not None and "id" in labels + ), "ID labels are required for metric learning losses" + ids = labels["id"][0][:, 0] + return embeddings, ids + + def update(self, inputs: Tensor, target: Tensor): + embeddings, labels = inputs, target + + if self.cross_batch_memory_size > 0: + # Append embedding and labels to the memory + self.cross_batch_memory.extend(list(zip(embeddings, labels))) + + # If the memory is full, remove the oldest elements + if len(self.cross_batch_memory) > self.cross_batch_memory_size: + self.cross_batch_memory = self.cross_batch_memory[ + -self.cross_batch_memory_size : + ] + + # If the memory is not full, return + if len(self.cross_batch_memory) < self.cross_batch_memory_size: + return + + # Get the embeddings and labels from the memory + embeddings, labels = zip(*self.cross_batch_memory) + embeddings = torch.stack(embeddings) + labels = torch.stack(labels) + + # print(f"Calculating accuracy for {len(embeddings)} 
embeddings") + + # Get the pairwise distances between all embeddings + pairwise_distances = _pairwise_distances(embeddings) + + # Set diagonal to infinity so that the closest embedding is not the same embedding + pairwise_distances.fill_diagonal_(float("inf")) + + # Find the closest embedding for each query embedding + closest_indices = torch.argmin(pairwise_distances, dim=1) + + # Get the labels of the closest embeddings + closest_labels = labels[closest_indices] + + # Filter out embeddings that don't have both positive and negative examples + positive_mask = _get_anchor_positive_triplet_mask(labels) + num_positives = positive_mask.sum(dim=1) + has_at_least_one_positive_and_negative = (num_positives > 0) & ( + num_positives < len(labels) + ) + + # Filter embeddings, labels, and closest indices based on valid indices + filtered_labels = labels[has_at_least_one_positive_and_negative] + filtered_closest_labels = closest_labels[ + has_at_least_one_positive_and_negative + ] + + # Calculate the number of correct predictions where the closest is positive + correct_predictions = ( + filtered_labels == filtered_closest_labels + ).sum() + + # Update the metric state + self.correct_predictions += correct_predictions + self.total_predictions += len(filtered_labels) + + def compute(self): + return self.correct_predictions / self.total_predictions + + +class MedianDistances(BaseMetric): + def __init__(self, cross_batch_memory_size=0, **kwargs): + super().__init__(**kwargs) + self.cross_batch_memory_size = cross_batch_memory_size + self.add_state("cross_batch_memory", default=[], dist_reduce_fx="cat") + self.add_state("all_distances", default=[], dist_reduce_fx="cat") + self.add_state("closest_distances", default=[], dist_reduce_fx="cat") + self.add_state("positive_distances", default=[], dist_reduce_fx="cat") + self.add_state( + "closest_vs_positive_distances", default=[], dist_reduce_fx="cat" + ) + + def prepare(self, inputs, labels): + embeddings = inputs["features"][0] + + assert ( + labels is not None and "id" in labels + ), "ID labels are required for metric learning losses" + ids = labels["id"][0][:, 0] + return embeddings, ids + + def update(self, inputs: Tensor, target: Tensor): + embeddings, labels = inputs, target + + if self.cross_batch_memory_size > 0: + # Append embedding and labels to the memory + self.cross_batch_memory.extend(list(zip(embeddings, labels))) + + # If the memory is full, remove the oldest elements + if len(self.cross_batch_memory) > self.cross_batch_memory_size: + self.cross_batch_memory = self.cross_batch_memory[ + -self.cross_batch_memory_size : + ] + + # If the memory is not full, return + if len(self.cross_batch_memory) < self.cross_batch_memory_size: + return + + # Get the embeddings and labels from the memory + embeddings, labels = zip(*self.cross_batch_memory) + embeddings = torch.stack(embeddings) + labels = torch.stack(labels) + + # Get the pairwise distances between all embeddings + pairwise_distances = _pairwise_distances(embeddings) + # Append only upper triangular part of the matrix + self.all_distances.append( + pairwise_distances[ + torch.triu(torch.ones_like(pairwise_distances), diagonal=1) + == 1 + ].flatten() + ) + + # Set diagonal to infinity so that the closest embedding is not the same embedding + pairwise_distances.fill_diagonal_(float("inf")) + + # Get the closest distance for each query embedding + closest_distances, _ = torch.min(pairwise_distances, dim=1) + self.closest_distances.append(closest_distances) + + # Get the positive mask and convert it 
to boolean + positive_mask = _get_anchor_positive_triplet_mask(labels).bool() + + # Filter out distances to negative elements w.r.t. each query embedding + only_positive_distances = pairwise_distances.clone() + only_positive_distances[~positive_mask] = float("inf") + + # From the positive distances, get the closest positive distance for each query embedding + closest_positive_distances, _ = torch.min( + only_positive_distances, dim=1 + ) + + # Calculate the difference between the closest distance (any) and closest positive distances + # - this tells us how much closer should the closest positive be in order for the embedding + # to be considered correct + non_inf_mask = closest_positive_distances != float("inf") + difference = closest_positive_distances - closest_distances + difference = difference[non_inf_mask] + + # Update the metric state + self.closest_vs_positive_distances.append(difference) + self.positive_distances.append( + closest_positive_distances[non_inf_mask] + ) + + def compute(self): + if len(self.all_distances) == 0: + # Return NaN tensor if no distances were calculated + return { + "MedianDistance": torch.tensor(float("nan")), + "MedianClosestDistance": torch.tensor(float("nan")), + "MedianClosestPositiveDistance": torch.tensor(float("nan")), + "MedianClosestVsClosestPositiveDistance": torch.tensor( + float("nan") + ), + } + + all_distances = torch.cat(self.all_distances) + closest_distances = torch.cat(self.closest_distances) + positive_distances = torch.cat(self.positive_distances) + closest_vs_positive_distances = torch.cat( + self.closest_vs_positive_distances + ) + + # Return medians + return { + "MedianDistance": torch.median(all_distances), + "MedianClosestDistance": torch.median(closest_distances), + "MedianClosestPositiveDistance": torch.median(positive_distances), + "MedianClosestVsClosestPositiveDistance": torch.median( + closest_vs_positive_distances + ), + } + + +def _pairwise_distances(embeddings, squared=False): + """Compute the 2D matrix of distances between all the embeddings. + + @param embeddings: tensor of shape (batch_size, embed_dim) + @type embeddings: torch.Tensor + @param squared: If true, output is the pairwise squared euclidean + distance matrix. If false, output is the pairwise euclidean + distance matrix. + @type squared: bool + @return: pairwise_distances: tensor of shape (batch_size, + batch_size) + @rtype: torch.Tensor + """ + # Get the dot product between all embeddings + # shape (batch_size, batch_size) + dot_product = torch.matmul(embeddings, embeddings.t()) + + # Get squared L2 norm for each embedding. We can just take the diagonal of `dot_product`. + # This also provides more numerical stability (the diagonal of the result will be exactly 0). 
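+    # (The pairwise computation below follows from the identity ||a - b||^2 = ||a||^2 - 2 * <a, b> + ||b||^2.)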
+ # shape (batch_size,) + square_norm = torch.diag(dot_product) + + # Compute the pairwise distance matrix as we have: + # ||a - b||^2 = ||a||^2 - 2 + ||b||^2 + # shape (batch_size, batch_size) + distances = ( + square_norm.unsqueeze(0) - 2.0 * dot_product + square_norm.unsqueeze(1) + ) + + # Because of computation errors, some distances might be negative so we put everything >= 0.0 + distances = torch.max(distances, torch.tensor(0.0)) + + if not squared: + # Because the gradient of sqrt is infinite when distances == 0.0 (ex: on the diagonal) + # we need to add a small epsilon where distances == 0.0 + mask = (distances == 0.0).float() + distances = distances + mask * 1e-16 + + distances = torch.sqrt(distances) + + # Correct the epsilon added: set the distances on the mask to be exactly 0.0 + distances = distances * (1.0 - mask) + + return distances + + +def _get_anchor_positive_triplet_mask(labels): + indices_equal = torch.eye( + labels.shape[0], dtype=torch.uint8, device=labels.device + ) + indices_not_equal = ~indices_equal + labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1) + mask = indices_not_equal & labels_equal + return mask diff --git a/luxonis_train/attached_modules/visualizers/__init__.py b/luxonis_train/attached_modules/visualizers/__init__.py index 50b90471..69ecc3c4 100644 --- a/luxonis_train/attached_modules/visualizers/__init__.py +++ b/luxonis_train/attached_modules/visualizers/__init__.py @@ -1,6 +1,7 @@ from .base_visualizer import BaseVisualizer from .bbox_visualizer import BBoxVisualizer from .classification_visualizer import ClassificationVisualizer +from .embeddings_visualizer import EmbeddingsVisualizer from .keypoint_visualizer import KeypointVisualizer from .multi_visualizer import MultiVisualizer from .segmentation_visualizer import SegmentationVisualizer @@ -23,6 +24,7 @@ "KeypointVisualizer", "MultiVisualizer", "SegmentationVisualizer", + "EmbeddingsVisualizer", "combine_visualizations", "draw_bounding_box_labels", "draw_keypoint_labels", diff --git a/luxonis_train/attached_modules/visualizers/embeddings_visualizer.py b/luxonis_train/attached_modules/visualizers/embeddings_visualizer.py new file mode 100644 index 00000000..f3591c83 --- /dev/null +++ b/luxonis_train/attached_modules/visualizers/embeddings_visualizer.py @@ -0,0 +1,100 @@ +import logging + +from matplotlib import pyplot as plt +from sklearn.manifold import TSNE +from torch import Tensor + +from luxonis_train.utils import Labels, Packet + +from .base_visualizer import BaseVisualizer +from .utils import ( + figure_to_torch, +) + +logger = logging.getLogger(__name__) +log_disable = False + + +class EmbeddingsVisualizer(BaseVisualizer[Tensor, Tensor]): + # supported_tasks: list[TaskType] = [TaskType.LABEL] + + def __init__( + self, + **kwargs, + ): + """Visualizer for embedding tasks like reID.""" + super().__init__(**kwargs) + + def prepare( + self, inputs: Packet[Tensor], labels: Labels | None + ) -> tuple[Tensor, Tensor]: + embeddings = inputs["features"][0] + + assert ( + labels is not None and "id" in labels + ), "ID labels are required for metric learning losses" + ids = labels["id"][0] + return embeddings, ids + + def forward( + self, + label_canvas: Tensor, + prediction_canvas: Tensor, + embeddings: Tensor, + ids: Tensor, + **kwargs, + ) -> Tensor: + """Creates a visualization of the embeddings. + + @type label_canvas: Tensor + @param label_canvas: The canvas to draw the labels on. + @type prediction_canvas: Tensor + @param prediction_canvas: The canvas to draw the predictions on. 
+ @type embeddings: Tensor + @param embeddings: The embeddings to visualize. + @type ids: Tensor + @param ids: The ids to visualize. + @rtype: Tensor + @return: An embedding space projection. + """ + + # Embeddings: [B, D], D = e.g. 512 + # ids: [B, 1], corresponding to the embeddings + + # Convert embeddings to numpy array + embeddings_np = embeddings.detach().cpu().numpy() + + # Perplexity must be less than the number of samples + perplexity = min(30, embeddings_np.shape[0] - 1) + + # Reduce dimensionality to 2D using t-SNE + tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity) + embeddings_2d = tsne.fit_transform(embeddings_np) + + # Plot the embeddings + fig, ax = plt.subplots(figsize=(10, 10)) + scatter = ax.scatter( + embeddings_2d[:, 0], + embeddings_2d[:, 1], + c=ids.detach().cpu().numpy(), + cmap="viridis", + s=5, + ) + + fig.colorbar(scatter, ax=ax) + ax.set_title("Embeddings Visualization") + ax.set_xlabel("Dimension 1") + ax.set_ylabel("Dimension 2") + + # Convert figure to tensor + image_tensor = figure_to_torch( + fig, width=label_canvas.shape[3], height=label_canvas.shape[2] + ) + + # Close the figure to free memory + plt.close(fig) + + # Add fake batch dimension + image_tensor = image_tensor.unsqueeze(0) + + return image_tensor diff --git a/luxonis_train/loaders/utils.py b/luxonis_train/loaders/utils.py index b030e218..2782500e 100644 --- a/luxonis_train/loaders/utils.py +++ b/luxonis_train/loaders/utils.py @@ -38,6 +38,7 @@ def collate_fn( TaskType.CLASSIFICATION, TaskType.SEGMENTATION, TaskType.ARRAY, + TaskType.LABEL, ]: out_labels[task] = torch.stack(annos, 0), task_type diff --git a/luxonis_train/nodes/backbones/__init__.py b/luxonis_train/nodes/backbones/__init__.py index cc621625..da063a5e 100644 --- a/luxonis_train/nodes/backbones/__init__.py +++ b/luxonis_train/nodes/backbones/__init__.py @@ -2,6 +2,7 @@ from .ddrnet import DDRNet from .efficientnet import EfficientNet from .efficientrep import EfficientRep +from .ghostfacenet.ghostfacenet import GhostFaceNetsV2 from .micronet import MicroNet from .mobilenetv2 import MobileNetV2 from .mobileone import MobileOne @@ -22,4 +23,5 @@ "ResNet", "DDRNet", "RecSubNet", + "GhostFaceNetsV2", ] diff --git a/luxonis_train/nodes/backbones/ghostfacenet/__init__.py b/luxonis_train/nodes/backbones/ghostfacenet/__init__.py new file mode 100644 index 00000000..85ed4447 --- /dev/null +++ b/luxonis_train/nodes/backbones/ghostfacenet/__init__.py @@ -0,0 +1,3 @@ +from .ghostfacenet import GhostFaceNetsV2 + +__all__ = ["GhostFaceNetsV2"] diff --git a/luxonis_train/nodes/backbones/ghostfacenet/blocks.py b/luxonis_train/nodes/backbones/ghostfacenet/blocks.py new file mode 100644 index 00000000..46a9ba27 --- /dev/null +++ b/luxonis_train/nodes/backbones/ghostfacenet/blocks.py @@ -0,0 +1,256 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from luxonis_train.nodes.backbones.micronet.blocks import _make_divisible +from luxonis_train.nodes.blocks import SqueezeExciteBlock + + +class ModifiedGDC(nn.Module): + def __init__(self, image_size, in_chs, num_classes, dropout, emb=512): + super().__init__() + + if image_size % 32 == 0: + self.conv_dw = nn.Conv2d( + in_chs, + in_chs, + kernel_size=(image_size // 32), + groups=in_chs, + bias=False, + ) + else: + self.conv_dw = nn.Conv2d( + in_chs, + in_chs, + kernel_size=(image_size // 32 + 1), + groups=in_chs, + bias=False, + ) + self.bn1 = nn.BatchNorm2d(in_chs) + self.dropout = nn.Dropout(dropout) + + self.conv = nn.Conv2d(in_chs, emb, 
kernel_size=1, bias=False) + self.bn2 = nn.BatchNorm1d(emb) + self.linear = ( + nn.Linear(emb, num_classes) if num_classes else nn.Identity() + ) + + def forward(self, inps): + x = inps + x = self.conv_dw(x) + x = self.bn1(x) + x = self.dropout(x) + x = self.conv(x) + x = x.view(x.size(0), -1) + x = self.bn2(x) + x = self.linear(x) + return x + + +class GhostModuleV2(nn.Module): + def __init__( + self, + inp, + oup, + kernel_size=1, + ratio=2, + dw_size=3, + stride=1, + prelu=True, + mode=None, + args=None, + ): + super(GhostModuleV2, self).__init__() + self.mode = mode + self.gate_fn = nn.Sigmoid() + + if self.mode in ["original"]: + self.oup = oup + init_channels = math.ceil(oup / ratio) + new_channels = init_channels * (ratio - 1) + self.primary_conv = nn.Sequential( + nn.Conv2d( + inp, + init_channels, + kernel_size, + stride, + kernel_size // 2, + bias=False, + ), + nn.BatchNorm2d(init_channels), + nn.PReLU() if prelu else nn.Sequential(), + ) + self.cheap_operation = nn.Sequential( + nn.Conv2d( + init_channels, + new_channels, + dw_size, + 1, + dw_size // 2, + groups=init_channels, + bias=False, + ), + nn.BatchNorm2d(new_channels), + nn.PReLU() if prelu else nn.Sequential(), + ) + elif self.mode in ["attn"]: # DFC + self.oup = oup + init_channels = math.ceil(oup / ratio) + new_channels = init_channels * (ratio - 1) + self.primary_conv = nn.Sequential( + nn.Conv2d( + inp, + init_channels, + kernel_size, + stride, + kernel_size // 2, + bias=False, + ), + nn.BatchNorm2d(init_channels), + nn.PReLU() if prelu else nn.Sequential(), + ) + self.cheap_operation = nn.Sequential( + nn.Conv2d( + init_channels, + new_channels, + dw_size, + 1, + dw_size // 2, + groups=init_channels, + bias=False, + ), + nn.BatchNorm2d(new_channels), + nn.PReLU() if prelu else nn.Sequential(), + ) + self.short_conv = nn.Sequential( + nn.Conv2d( + inp, oup, kernel_size, stride, kernel_size // 2, bias=False + ), + nn.BatchNorm2d(oup), + nn.Conv2d( + oup, + oup, + kernel_size=(1, 5), + stride=1, + padding=(0, 2), + groups=oup, + bias=False, + ), + nn.BatchNorm2d(oup), + nn.Conv2d( + oup, + oup, + kernel_size=(5, 1), + stride=1, + padding=(2, 0), + groups=oup, + bias=False, + ), + nn.BatchNorm2d(oup), + ) + + def forward(self, x): + if self.mode in ["original"]: + x1 = self.primary_conv(x) + x2 = self.cheap_operation(x1) + out = torch.cat([x1, x2], dim=1) + return out[:, : self.oup, :, :] + elif self.mode in ["attn"]: + res = self.short_conv(F.avg_pool2d(x, kernel_size=2, stride=2)) + x1 = self.primary_conv(x) + x2 = self.cheap_operation(x1) + out = torch.cat([x1, x2], dim=1) + return out[:, : self.oup, :, :] * F.interpolate( + self.gate_fn(res), + size=(out.shape[-2], out.shape[-1]), + mode="nearest", + ) + + +class GhostBottleneckV2(nn.Module): + def __init__( + self, + in_chs, + mid_chs, + out_chs, + dw_kernel_size=3, + stride=1, + act_layer=nn.PReLU, + se_ratio=0.0, + layer_id=None, + args=None, + ): + super(GhostBottleneckV2, self).__init__() + has_se = se_ratio is not None and se_ratio > 0.0 + self.stride = stride + + assert layer_id is not None, "Layer ID must be explicitly provided" + + # Point-wise expansion + if layer_id <= 1: + self.ghost1 = GhostModuleV2( + in_chs, mid_chs, prelu=True, mode="original", args=args + ) + else: + self.ghost1 = GhostModuleV2( + in_chs, mid_chs, prelu=True, mode="attn", args=args + ) + + # Depth-wise convolution + if self.stride > 1: + self.conv_dw = nn.Conv2d( + mid_chs, + mid_chs, + dw_kernel_size, + stride=stride, + padding=(dw_kernel_size - 1) // 2, + groups=mid_chs, + 
bias=False, + ) + self.bn_dw = nn.BatchNorm2d(mid_chs) + + # Squeeze-and-excitation + if has_se: + reduced_chs = _make_divisible(mid_chs * se_ratio, 4) + self.se = SqueezeExciteBlock( + mid_chs, reduced_chs, True, activation=nn.PReLU() + ) + else: + self.se = None + + self.ghost2 = GhostModuleV2( + mid_chs, out_chs, prelu=False, mode="original", args=args + ) + + # shortcut + if in_chs == out_chs and self.stride == 1: + self.shortcut = nn.Sequential() + else: + self.shortcut = nn.Sequential( + nn.Conv2d( + in_chs, + in_chs, + dw_kernel_size, + stride=stride, + padding=(dw_kernel_size - 1) // 2, + groups=in_chs, + bias=False, + ), + nn.BatchNorm2d(in_chs), + nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False), + nn.BatchNorm2d(out_chs), + ) + + def forward(self, x): + residual = x + x = self.ghost1(x) + if self.stride > 1: + x = self.conv_dw(x) + x = self.bn_dw(x) + if self.se is not None: + x = self.se(x) + x = self.ghost2(x) + x += self.shortcut(residual) + return x diff --git a/luxonis_train/nodes/backbones/ghostfacenet/ghostfacenet.py b/luxonis_train/nodes/backbones/ghostfacenet/ghostfacenet.py new file mode 100644 index 00000000..5a99ae28 --- /dev/null +++ b/luxonis_train/nodes/backbones/ghostfacenet/ghostfacenet.py @@ -0,0 +1,144 @@ +# Original source: https://github.com/Hazqeel09/ellzaf_ml/blob/main/ellzaf_ml/models/ghostfacenetsv2.py +import math +from typing import Literal + +import torch.nn as nn +from torch import Tensor + +from luxonis_train.nodes.backbones.ghostfacenet.blocks import ( + GhostBottleneckV2, + ModifiedGDC, +) +from luxonis_train.nodes.backbones.ghostfacenet.variants import get_variant +from luxonis_train.nodes.backbones.micronet.blocks import _make_divisible +from luxonis_train.nodes.base_node import BaseNode +from luxonis_train.nodes.blocks import ConvModule + + +class GhostFaceNetsV2(BaseNode[Tensor, list[Tensor]]): + in_channels: list[int] + in_width: list[int] + + def __init__( + self, + embedding_size=512, + num_classes=-1, + variant: Literal["V2"] = "V2", + *args, + **kwargs, + ): + """GhostFaceNetsV2 backbone. + + GhostFaceNetsV2 is a convolutional neural network architecture focused on face recognition, but it is + adaptable to generic embedding tasks. It is based on the GhostNet architecture and uses Ghost BottleneckV2 blocks. + + Source: U{https://github.com/Hazqeel09/ellzaf_ml/blob/main/ellzaf_ml/models/ghostfacenetsv2.py} + + @license: U{MIT License + } + + @see: U{GhostFaceNets: Lightweight Face Recognition Model From Cheap Operations + } + + @type embedding_size: int + @param embedding_size: Size of the embedding. Defaults to 512. + @type num_classes: int + @param num_classes: Number of classes. Defaults to -1, which leaves the default variant value in. Otherwise it can be used to + have the network return raw embeddings (=0) or add another linear layer to the network, which is useful for training using + ArcFace or similar classification-based losses that require the user to drop the last layer of the network. + @type variant: Literal["V2"] + @param variant: Variant of the GhostFaceNets embedding model. Defaults to "V2" (which is the only variant available). 
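+
+        Example (illustrative): for a 256x256 RGB input with embedding_size=512 and the
+        default num_classes, the node outputs a 512-dimensional embedding per image; setting
+        num_classes > 0 instead appends a linear layer that maps the embedding to class logits.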
+ """ + super().__init__(*args, **kwargs) + + image_size = self.in_width[0] + channels = self.in_channels[0] + var = get_variant(variant) + if num_classes >= 0: + var.num_classes = num_classes + self.cfgs = var.cfgs + + # Building first layer + output_channel = _make_divisible(int(16 * var.width), 4) + self.conv_stem = nn.Conv2d( + channels, output_channel, 3, 2, 1, bias=False + ) + self.bn1 = nn.BatchNorm2d(output_channel) + self.act1 = nn.PReLU() + input_channel = output_channel + + # Building Ghost BottleneckV2 blocks + stages = [] + layer_id = 0 + for cfg in self.cfgs: + layers = [] + for b_cfg in cfg: + output_channel = _make_divisible( + b_cfg.output_channels * var.width, 4 + ) + hidden_channel = _make_divisible( + b_cfg.expand_size * var.width, 4 + ) + if var.block == GhostBottleneckV2: + layers.append( + var.block( + input_channel, + hidden_channel, + output_channel, + b_cfg.kernel_size, + b_cfg.stride, + se_ratio=b_cfg.se_ratio, + layer_id=layer_id, + args=var.block_args, + ) + ) + input_channel = output_channel + layer_id += 1 + stages.append(nn.Sequential(*layers)) + + output_channel = _make_divisible(b_cfg.expand_size * var.width, 4) + stages.append( + nn.Sequential( + ConvModule( + input_channel, + output_channel, + kernel_size=1, + activation=nn.PReLU(), + ) + ) + ) + + self.blocks = nn.Sequential(*stages) + + # Building pointwise convolution + pointwise_conv = [nn.Sequential()] + self.pointwise_conv = nn.Sequential(*pointwise_conv) + self.classifier = ModifiedGDC( + image_size, + output_channel, + var.num_classes, + var.dropout, + embedding_size, + ) + + # Initializing weights + for m in self.modules(): + if var.init_kaiming: + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight) + negative_slope = 0.25 + m.weight.data.normal_( + 0, math.sqrt(2.0 / (fan_in * (1 + negative_slope**2))) + ) + if isinstance(m, nn.BatchNorm2d): + m.momentum, m.eps = var.bn_momentum, var.bn_epsilon + + def forward(self, inps): + x = inps[0] + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + x = self.blocks(x) + x = self.pointwise_conv(x) + x = self.classifier(x) + return x diff --git a/luxonis_train/nodes/backbones/ghostfacenet/variants.py b/luxonis_train/nodes/backbones/ghostfacenet/variants.py new file mode 100644 index 00000000..aa78daf8 --- /dev/null +++ b/luxonis_train/nodes/backbones/ghostfacenet/variants.py @@ -0,0 +1,214 @@ +from typing import List, Literal + +from pydantic import BaseModel +from torch import nn + +from luxonis_train.nodes.backbones.ghostfacenet.blocks import GhostBottleneckV2 + + +class BlockConfig(BaseModel): + kernel_size: int + expand_size: int + output_channels: int + se_ratio: float + stride: int + + +class GhostFaceNetsVariant(BaseModel): + """Variant of the GhostFaceNets embedding model. + + @type cfgs: List[List[BlockConfig]] + @param cfgs: List of Ghost BottleneckV2 configurations. + @type num_classes: int + @param num_classes: Number of classes. Defaults to 0, which makes + the network output the raw embeddings. Otherwise it can be used + to add another linear layer to the network, which is useful for + training using ArcFace or similar classification-based losses + that require the user to drop the last layer of the network. + @type width: int + @param width: Width multiplier. Increases complexity and number of + parameters. Defaults to 1.0. + @type dropout: float + @param dropout: Dropout rate. Defaults to 0.2. + @type block: nn.Module + @param block: Ghost BottleneckV2 block. 
Defaults to + GhostBottleneckV2. + @type add_pointwise_conv: bool + @param add_pointwise_conv: If True, adds a pointwise convolution + layer at the end of the network. Defaults to False. + @type bn_momentum: float + @param bn_momentum: Batch normalization momentum. Defaults to 0.9. + @type bn_epsilon: float + @param bn_epsilon: Batch normalization epsilon. Defaults to 1e-5. + @type init_kaiming: bool + @param init_kaiming: If True, initializes the weights using the + Kaiming initialization. Defaults to True. + @type block_args: dict + @param block_args: Arguments to pass to the block. Defaults to None. + """ + + num_classes: int + width: int + dropout: float + block: type[nn.Module] + add_pointwise_conv: bool + bn_momentum: float + bn_epsilon: float + init_kaiming: bool + block_args: dict | None + cfgs: List[List[BlockConfig]] + + +V2 = GhostFaceNetsVariant( + num_classes=0, + width=1, + dropout=0.2, + block=GhostBottleneckV2, + add_pointwise_conv=False, + bn_momentum=0.9, + bn_epsilon=1e-5, + init_kaiming=True, + block_args=None, + cfgs=[ + [ + BlockConfig( + kernel_size=3, + expand_size=16, + output_channels=16, + se_ratio=0.0, + stride=1, + ) + ], + [ + BlockConfig( + kernel_size=3, + expand_size=48, + output_channels=24, + se_ratio=0.0, + stride=2, + ) + ], + [ + BlockConfig( + kernel_size=3, + expand_size=72, + output_channels=24, + se_ratio=0.0, + stride=1, + ) + ], + [ + BlockConfig( + kernel_size=5, + expand_size=72, + output_channels=40, + se_ratio=0.25, + stride=2, + ) + ], + [ + BlockConfig( + kernel_size=5, + expand_size=120, + output_channels=40, + se_ratio=0.25, + stride=1, + ) + ], + [ + BlockConfig( + kernel_size=3, + expand_size=240, + output_channels=80, + se_ratio=0.0, + stride=2, + ) + ], + [ + BlockConfig( + kernel_size=3, + expand_size=200, + output_channels=80, + se_ratio=0.0, + stride=1, + ), + BlockConfig( + kernel_size=3, + expand_size=184, + output_channels=80, + se_ratio=0.0, + stride=1, + ), + BlockConfig( + kernel_size=3, + expand_size=184, + output_channels=80, + se_ratio=0.0, + stride=1, + ), + BlockConfig( + kernel_size=3, + expand_size=480, + output_channels=112, + se_ratio=0.25, + stride=1, + ), + BlockConfig( + kernel_size=3, + expand_size=672, + output_channels=112, + se_ratio=0.25, + stride=1, + ), + ], + [ + BlockConfig( + kernel_size=5, + expand_size=672, + output_channels=160, + se_ratio=0.25, + stride=2, + ) + ], + [ + BlockConfig( + kernel_size=5, + expand_size=960, + output_channels=160, + se_ratio=0.0, + stride=1, + ), + BlockConfig( + kernel_size=5, + expand_size=960, + output_channels=160, + se_ratio=0.25, + stride=1, + ), + BlockConfig( + kernel_size=5, + expand_size=960, + output_channels=160, + se_ratio=0.0, + stride=1, + ), + BlockConfig( + kernel_size=5, + expand_size=960, + output_channels=160, + se_ratio=0.25, + stride=1, + ), + ], + ], +) + + +def get_variant(variant: Literal["V2"]) -> GhostFaceNetsVariant: + variants = {"V2": V2} + if variant not in variants: # pragma: no cover + raise ValueError( + "GhostFaceNets model variant should be in " + f"{list(variants.keys())}, got {variant}." 
+ ) + return variants[variant].model_copy() diff --git a/luxonis_train/nodes/backbones/micronet/blocks.py b/luxonis_train/nodes/backbones/micronet/blocks.py index 3da5e15e..b29082cf 100644 --- a/luxonis_train/nodes/backbones/micronet/blocks.py +++ b/luxonis_train/nodes/backbones/micronet/blocks.py @@ -357,7 +357,7 @@ def __init__( self.avg_pool = nn.AdaptiveAvgPool2d(1) - squeeze_channels = self._make_divisible(in_channels // reduction, 4) + squeeze_channels = _make_divisible(in_channels // reduction, 4) self.fc = nn.Sequential( nn.Linear(in_channels, squeeze_channels), @@ -413,16 +413,17 @@ def forward(self, x: Tensor) -> Tensor: return out - def _make_divisible( - self, value: int, divisor: int, min_value: int | None = None - ) -> int: - if min_value is None: - min_value = divisor - new_v = max(min_value, int(value + divisor / 2) // divisor * divisor) - # Make sure that round down does not go down by more than 10%. - if new_v < 0.9 * value: - new_v += divisor - return new_v + +def _make_divisible( + value: int, divisor: int, min_value: int | None = None +) -> int: + if min_value is None: + min_value = divisor + new_v = max(min_value, int(value + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * value: + new_v += divisor + return new_v class SpatialSepConvSF(nn.Module): diff --git a/requirements.txt b/requirements.txt index 5ef87b3a..b49d9b3e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,3 +18,5 @@ mlflow>=2.10.0 psutil>=5.0.0 tabulate>=0.9.0 grad-cam>=1.5.4 +pytorch_metric_learning>=2.7.0 +scikit-learn>=1.5.0 \ No newline at end of file diff --git a/tests/configs/reid.yaml b/tests/configs/reid.yaml new file mode 100644 index 00000000..c79e4f8e --- /dev/null +++ b/tests/configs/reid.yaml @@ -0,0 +1,60 @@ +loader: + name: CustomReIDLoader + +model: + name: reid_test + nodes: + - name: GhostFaceNetsV2 + input_sources: + - image + params: + embedding_size: &embedding_size 512 + + losses: + - name: EmbeddingLossWrapper + params: + loss_name: SupConLoss + embedding_size: *embedding_size + cross_batch_memory_size: &memory_size 4 + attached_to: GhostFaceNetsV2 + + metrics: + - name: ClosestIsPositiveAccuracy + params: + cross_batch_memory_size: *memory_size + attached_to: GhostFaceNetsV2 + is_main_metric: True + - name: MedianDistances + params: + cross_batch_memory_size: *memory_size + attached_to: GhostFaceNetsV2 + is_main_metric: False + + visualizers: + - name: EmbeddingsVisualizer + attached_to: GhostFaceNetsV2 + +trainer: + preprocessing: + train_image_size: [256, 256] + + batch_size: 16 + epochs: 10 + n_workers: 0 + validation_interval: 10 + + callbacks: + - name: ExportOnTrainEnd + + optimizer: + name: Adam + params: + lr: 0.01 + +tracker: + project_name: reid_example + is_tensorboard: True + +exporter: + onnx: + opset_version: 11 \ No newline at end of file diff --git a/tests/integration/test_detection.py b/tests/integration/test_detection.py index 45e83f0a..060e84e2 100644 --- a/tests/integration/test_detection.py +++ b/tests/integration/test_detection.py @@ -103,7 +103,9 @@ def train_and_test( assert value > 0.8, f"{name} = {value} (expected > 0.8)" -@pytest.mark.parametrize("backbone", BACKBONES) +@pytest.mark.parametrize( + "backbone", [b for b in BACKBONES if b != "GhostFaceNetsV2"] +) def test_backbones( backbone: str, config: dict[str, Any], diff --git a/tests/integration/test_reid.py b/tests/integration/test_reid.py new file mode 100644 index 00000000..53355025 --- /dev/null +++ 
b/tests/integration/test_reid.py @@ -0,0 +1,184 @@ +import shutil +from pathlib import Path +from typing import Any + +import pytest +import torch + +from luxonis_train.attached_modules.losses.pml_loss import ( + ALL_EMBEDDING_LOSSES, + CLASS_EMBEDDING_LOSSES, +) +from luxonis_train.core import LuxonisModel +from luxonis_train.enums import TaskType +from luxonis_train.loaders import BaseLoaderTorch + +from .multi_input_modules import * + +INFER_PATH = Path("tests/integration/infer-save-directory") +ONNX_PATH = Path("tests/integration/_model.onnx") +STUDY_PATH = Path("study_local.db") + +NUM_INDIVIDUALS = 100 + + +class CustomReIDLoader(BaseLoaderTorch): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @property + def input_shapes(self): + return { + "image": torch.Size([3, 256, 256]), + "id": torch.Size([1]), + } + + def __getitem__(self, _): # pragma: no cover + # Fake data + image = torch.rand(self.input_shapes["image"], dtype=torch.float32) + inputs = { + "image": image, + } + + # Fake labels + id = torch.randint(0, NUM_INDIVIDUALS, (1,), dtype=torch.int64) + labels = { + "id": (id, TaskType.LABEL), + } + + return inputs, labels + + def __len__(self): + return 10 + + def get_classes(self) -> dict[TaskType, list[str]]: + return {TaskType.LABEL: ["id"]} + + +class CustomReIDLoaderNoID(CustomReIDLoader): + def __getitem__(self, _): + inputs, labels = super().__getitem__(_) + labels["something_else"] = labels["id"] + del labels["id"] + + return inputs, labels + + +class CustomReIDLoaderImageSize2(CustomReIDLoader): + @property + def input_shapes(self): + return { + "image": torch.Size([3, 200, 200]), + "id": torch.Size([1]), + } + + +@pytest.fixture +def infer_path() -> Path: + if INFER_PATH.exists(): + shutil.rmtree(INFER_PATH) + INFER_PATH.mkdir() + return INFER_PATH + + +@pytest.fixture +def opts(test_output_dir: Path) -> dict[str, Any]: + return { + "trainer.epochs": 1, + "trainer.batch_size": 2, + "trainer.validation_interval": 1, + "trainer.callbacks": "[]", + "tracker.save_directory": str(test_output_dir), + "tuner.n_trials": 4, + } + + +@pytest.fixture(scope="function", autouse=True) +def clear_files(): + yield + STUDY_PATH.unlink(missing_ok=True) + ONNX_PATH.unlink(missing_ok=True) + + +not_class_based_losses = ALL_EMBEDDING_LOSSES.copy() +for loss in CLASS_EMBEDDING_LOSSES: + not_class_based_losses.remove(loss) + + +@pytest.mark.parametrize("loss_name", not_class_based_losses) +def test_available_losses( + opts: dict[str, Any], infer_path: Path, loss_name: str +): + config_file = "tests/configs/reid.yaml" + opts["model.losses.0.params.loss_name"] = loss_name + + # if loss_name in CLASS_EMBEDDING_LOSSES: + # opts["model.losses.0.params.num_classes"] = NUM_INDIVIDUALS + # opts["model.nodes.0.params.num_classes"] = NUM_INDIVIDUALS + # else: + # opts["model.losses.0.params.num_classes"] = 0 + # opts["model.nodes.0.params.num_classes"] = 0 + + if loss_name == "RankedListLoss": + opts["model.losses.0.params.loss_kwargs"] = {"margin": 1.0, "Tn": 0.5} + + model = LuxonisModel(config_file, opts) + model.train() + model.test(view="val") + + assert not ONNX_PATH.exists() + model.export(str(ONNX_PATH)) + assert ONNX_PATH.exists() + + assert len(list(infer_path.iterdir())) == 0 + model.infer(view="val", save_dir=infer_path) + assert infer_path.exists() + + +@pytest.mark.parametrize("loss_name", CLASS_EMBEDDING_LOSSES) +@pytest.mark.parametrize("num_classes", [-2, NUM_INDIVIDUALS]) +def test_unsupported_class_based_losses( + opts: dict[str, Any], loss_name: str, 
num_classes: int +): + config_file = "tests/configs/reid.yaml" + opts["model.losses.0.params.loss_name"] = loss_name + opts["model.losses.0.params.num_classes"] = num_classes + opts["model.nodes.0.params.num_classes"] = num_classes + + with pytest.raises(ValueError): + LuxonisModel(config_file, opts) + + +@pytest.mark.parametrize("loss_name", ["NonExistentLoss"]) +def test_nonexistent_losses(opts: dict[str, Any], loss_name: str): + config_file = "tests/configs/reid.yaml" + opts["model.losses.0.params.loss_name"] = loss_name + + with pytest.raises(ValueError): + LuxonisModel(config_file, opts) + + +def test_bad_loader(opts: dict[str, Any]): + config_file = "tests/configs/reid.yaml" + opts["loader.name"] = "CustomReIDLoaderNoID" + + with pytest.raises(ValueError): + model = LuxonisModel(config_file, opts) + model.train() + + +def test_not_enough_samples_for_metrics(opts: dict[str, Any]): + config_file = "tests/configs/reid.yaml" + opts["model.metrics.1.params.cross_batch_memory_size"] = 100 + + model = LuxonisModel(config_file, opts) + model.train() + + +def test_image_size_not_divisible_by_32(opts: dict[str, Any]): + config_file = "tests/configs/reid.yaml" + opts["loader.name"] = "CustomReIDLoaderImageSize2" + + # with pytest.raises(ValueError): + model = LuxonisModel(config_file, opts) + model.train() diff --git a/tests/integration/test_segmentation.py b/tests/integration/test_segmentation.py index a8b4df91..4ab4478a 100644 --- a/tests/integration/test_segmentation.py +++ b/tests/integration/test_segmentation.py @@ -123,7 +123,9 @@ def train_and_test( assert value > 0.8, f"{name} = {value} (expected > 0.8)" -@pytest.mark.parametrize("backbone", BACKBONES) +@pytest.mark.parametrize( + "backbone", [b for b in BACKBONES if b != "GhostFaceNetsV2"] +) def test_backbones( backbone: str, config: dict[str, Any],