Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re-Identification Model #141

Closed
wants to merge 62 commits into from
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
241c74a
Added necessary requirements. Added integration with pytorch-metric-l…
CaptainTrojan Dec 6, 2024
be4c2d2
Add detailed docstring for GhostFaceNetsV2 backbone class
CaptainTrojan Dec 6, 2024
c5c4f16
fix: update docstring for pairwise_distances function in pml_metrics.py
CaptainTrojan Dec 6, 2024
c360f15
Fixed type errors
CaptainTrojan Dec 6, 2024
6eda12a
Implemented improvements and suggestions. Separated GFN into class, b…
CaptainTrojan Dec 16, 2024
0689935
refactor: update type hint for GhostFaceNetsV2 class to use Tensor fr…
CaptainTrojan Dec 16, 2024
9463997
refactor: remove unused unwrap and wrap methods from GhostFaceNetsV2 …
CaptainTrojan Dec 16, 2024
9d5c418
Merge branch 'main' into feat/reid-support
klemen1999 Dec 16, 2024
555fe2a
fix: correct formatting in __all__ list in metrics module
CaptainTrojan Dec 16, 2024
9fe0b79
Improved coverage, explicitly set mdformat github version
CaptainTrojan Dec 16, 2024
b47e79e
Reduced mdformat-gfm version to 0.3.6 to support Python 3.8
CaptainTrojan Dec 17, 2024
8e376a0
Coverage fixes
CaptainTrojan Jan 1, 2025
708691b
Merge branch 'feat/reid-support' of https://github.com/luxonis/luxoni…
CaptainTrojan Jan 1, 2025
23892f9
Merge branch 'main' into feat/reid-support
CaptainTrojan Jan 1, 2025
23e7500
fix: return a model copy for the specified GhostFaceNets variant
CaptainTrojan Jan 1, 2025
a356093
Merge branch 'feat/reid-support' of https://github.com/luxonis/luxoni…
CaptainTrojan Jan 1, 2025
df89eef
initial labels refactor support
kozlov721 Jan 11, 2025
d01816b
updated docs
kozlov721 Jan 14, 2025
e34e893
updated predefined models
kozlov721 Jan 14, 2025
82abeae
updated attached modules
kozlov721 Jan 14, 2025
f48622b
small changes
kozlov721 Jan 14, 2025
7c244af
updated tests
kozlov721 Jan 14, 2025
1de6f74
fixed predefined classification
kozlov721 Jan 14, 2025
8c32014
docs
kozlov721 Jan 14, 2025
8d7685b
fix inspect
kozlov721 Jan 14, 2025
44198cd
Merge branch 'main' into feature/nested-labels
kozlov721 Jan 14, 2025
fbbbc26
fixed tests
kozlov721 Jan 15, 2025
bb5e882
fix debug config
kozlov721 Jan 16, 2025
785f2f8
updated perlin
kozlov721 Jan 16, 2025
c093363
missing doc
kozlov721 Jan 16, 2025
f2cdfa3
reverted bacj to train_rgb
kozlov721 Jan 16, 2025
e32f6ea
fix type issues
kozlov721 Jan 16, 2025
eef219a
replaced deprecated `register_module`
kozlov721 Jan 16, 2025
0379b2a
removed init arguments
kozlov721 Jan 16, 2025
d6344ef
added missing types
kozlov721 Jan 17, 2025
44adfcb
fixed anomaly detection
kozlov721 Jan 17, 2025
c76135c
converting to float
kozlov721 Jan 17, 2025
058f449
helper function
kozlov721 Jan 17, 2025
732ad1f
changes for latest luxonis-ml
kozlov721 Jan 17, 2025
09b1e58
fixed tests
kozlov721 Jan 17, 2025
bcdb303
Merge branch 'feature/nested-labels' into feat/reid-support
kozlov721 Jan 21, 2025
5a10d61
reid fixes
kozlov721 Jan 23, 2025
3dfb8b2
renamed
kozlov721 Jan 23, 2025
f16aad4
simplified
kozlov721 Jan 23, 2025
2fe723e
separated head
kozlov721 Jan 23, 2025
45ade94
simplified
kozlov721 Jan 23, 2025
368188b
updated config
kozlov721 Jan 23, 2025
0041af2
Merge branch 'main' into feat/reid-support
kozlov721 Jan 23, 2025
7b96ab8
small changes
kozlov721 Jan 23, 2025
b14a76c
fix for rectangular images
kozlov721 Jan 23, 2025
1bce803
renamed
kozlov721 Jan 23, 2025
3c0423e
type simplification
kozlov721 Jan 23, 2025
cb970fa
added cross batch memory
kozlov721 Jan 23, 2025
c368723
attached modules improvememnt
kozlov721 Jan 24, 2025
2383283
metadata task override
kozlov721 Jan 25, 2025
dec365b
fix automatic inputs
kozlov721 Jan 25, 2025
1f05da0
cleaned
kozlov721 Jan 25, 2025
6537153
metadata overriding
kozlov721 Jan 25, 2025
fb15dff
type checking
kozlov721 Jan 25, 2025
d680406
embedding tests
kozlov721 Jan 25, 2025
f844d19
fix
kozlov721 Jan 25, 2025
1670566
parametrized tests
kozlov721 Jan 25, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions luxonis_train/attached_modules/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .ohem_bce_with_logits import OHEMBCEWithLogitsLoss
from .ohem_cross_entropy import OHEMCrossEntropyLoss
from .ohem_loss import OHEMLoss
from .pml_loss import MetricLearningLoss
klemen1999 marked this conversation as resolved.
Show resolved Hide resolved
from .reconstruction_segmentation_loss import ReconstructionSegmentationLoss
from .sigmoid_focal_loss import SigmoidFocalLoss
from .smooth_bce_with_logits import SmoothBCEWithLogitsLoss
Expand All @@ -26,4 +27,5 @@
"OHEMCrossEntropyLoss",
"OHEMBCEWithLogitsLoss",
"FOMOLocalizationLoss",
"MetricLearningLoss",
]
122 changes: 122 additions & 0 deletions luxonis_train/attached_modules/losses/pml_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import warnings

from pytorch_metric_learning.losses import (
AngularLoss,
ArcFaceLoss,
CircleLoss,
ContrastiveLoss,
CosFaceLoss,
CrossBatchMemory,
DynamicSoftMarginLoss,
FastAPLoss,
GeneralizedLiftedStructureLoss,
HistogramLoss,
InstanceLoss,
IntraPairVarianceLoss,
LargeMarginSoftmaxLoss,
LiftedStructureLoss,
ManifoldLoss,
MarginLoss,
MultiSimilarityLoss,
NCALoss,
NormalizedSoftmaxLoss,
NPairsLoss,
NTXentLoss,
P2SGradLoss,
PNPLoss,
ProxyAnchorLoss,
ProxyNCALoss,
RankedListLoss,
SignalToNoiseRatioContrastiveLoss,
SoftTripleLoss,
SphereFaceLoss,
SubCenterArcFaceLoss,
SupConLoss,
TripletMarginLoss,
TupletMarginLoss,
)
from torch import Tensor

from .base_loss import BaseLoss

# Dictionary mapping string keys to loss classes
loss_dict = {
"AngularLoss": AngularLoss,
"ArcFaceLoss": ArcFaceLoss,
"CircleLoss": CircleLoss,
"ContrastiveLoss": ContrastiveLoss,
"CosFaceLoss": CosFaceLoss,
"DynamicSoftMarginLoss": DynamicSoftMarginLoss,
"FastAPLoss": FastAPLoss,
"GeneralizedLiftedStructureLoss": GeneralizedLiftedStructureLoss,
"InstanceLoss": InstanceLoss,
"HistogramLoss": HistogramLoss,
"IntraPairVarianceLoss": IntraPairVarianceLoss,
"LargeMarginSoftmaxLoss": LargeMarginSoftmaxLoss,
"LiftedStructureLoss": LiftedStructureLoss,
"ManifoldLoss": ManifoldLoss,
"MarginLoss": MarginLoss,
"MultiSimilarityLoss": MultiSimilarityLoss,
"NCALoss": NCALoss,
"NormalizedSoftmaxLoss": NormalizedSoftmaxLoss,
"NPairsLoss": NPairsLoss,
"NTXentLoss": NTXentLoss,
"P2SGradLoss": P2SGradLoss,
"PNPLoss": PNPLoss,
"ProxyAnchorLoss": ProxyAnchorLoss,
"ProxyNCALoss": ProxyNCALoss,
"RankedListLoss": RankedListLoss,
"SignalToNoiseRatioContrastiveLoss": SignalToNoiseRatioContrastiveLoss,
"SoftTripleLoss": SoftTripleLoss,
"SphereFaceLoss": SphereFaceLoss,
"SubCenterArcFaceLoss": SubCenterArcFaceLoss,
"SupConLoss": SupConLoss,
"TripletMarginLoss": TripletMarginLoss,
"TupletMarginLoss": TupletMarginLoss,
}


class MetricLearningLoss(BaseLoss):
CaptainTrojan marked this conversation as resolved.
Show resolved Hide resolved
def __init__(
self,
loss_name: str,
embedding_size: int = 512,
cross_batch_memory_size=0,
loss_kwargs: dict | None = None,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
if loss_kwargs is None:
loss_kwargs = {}
self.loss_func = loss_dict[loss_name](
CaptainTrojan marked this conversation as resolved.
Show resolved Hide resolved
CaptainTrojan marked this conversation as resolved.
Show resolved Hide resolved
**loss_kwargs
) # Instantiate the loss object
if cross_batch_memory_size > 0:
if loss_name in CrossBatchMemory.supported_losses():
self.loss_func = CrossBatchMemory(
self.loss_func, embedding_size=embedding_size
)
else:
# Warn that cross_batch_memory_size is ignored
CaptainTrojan marked this conversation as resolved.
Show resolved Hide resolved
warnings.warn(

Check warning on line 102 in luxonis_train/attached_modules/losses/pml_loss.py

View check run for this annotation

Codecov / codecov/patch

luxonis_train/attached_modules/losses/pml_loss.py#L102

Added line #L102 was not covered by tests
CaptainTrojan marked this conversation as resolved.
Show resolved Hide resolved
f"Cross batch memory is not supported for {loss_name}. Ignoring cross_batch_memory_size"
)

# self.miner_func = miner_func
CaptainTrojan marked this conversation as resolved.
Show resolved Hide resolved

def prepare(self, inputs, labels):
CaptainTrojan marked this conversation as resolved.
Show resolved Hide resolved
embeddings = inputs["features"][0]
CaptainTrojan marked this conversation as resolved.
Show resolved Hide resolved

assert (
CaptainTrojan marked this conversation as resolved.
Show resolved Hide resolved
labels is not None and "id" in labels
), "ID labels are required for metric learning losses"
IDs = labels["id"][0][:, 0]
CaptainTrojan marked this conversation as resolved.
Show resolved Hide resolved
return embeddings, IDs

def forward(self, inputs: Tensor, target: Tensor):
CaptainTrojan marked this conversation as resolved.
Show resolved Hide resolved
# miner_output = self.miner_func(inputs, target)

loss = self.loss_func(inputs, target)

return loss
3 changes: 3 additions & 0 deletions luxonis_train/attached_modules/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .mean_average_precision import MeanAveragePrecision
from .mean_average_precision_keypoints import MeanAveragePrecisionKeypoints
from .object_keypoint_similarity import ObjectKeypointSimilarity
from .pml_metrics import ClosestIsPositiveAccuracy, MedianDistances
from .torchmetrics import Accuracy, F1Score, JaccardIndex, Precision, Recall

__all__ = [
Expand All @@ -14,4 +15,6 @@
"ObjectKeypointSimilarity",
"Precision",
"Recall",
"ClosestIsPositiveAccuracy",
"MedianDistances",
]
Loading
Loading