Skip to content

Commit

Permalink
feat: remove MatryoshkaCELoss (incorrect) for an optimal matryoshka…
Browse files Browse the repository at this point in the history
… loss computation
  • Loading branch information
tonywu71 committed Sep 23, 2024
1 parent 23d61c5 commit cc3b31d
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 88 deletions.
57 changes: 46 additions & 11 deletions colpali_engine/loss/colpali_2_losses.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from dataclasses import dataclass
from typing import Optional, cast
from typing import List, Optional

import torch
import torch.nn.functional as F # noqa: N812
from torch.nn import CrossEntropyLoss, KLDivLoss

from colpali_engine.loss.base_late_interaction_loss import BaseColbertLoss
from colpali_engine.loss.matryoshka_loss import MatryoshkaCELoss
from colpali_engine.models.paligemma.colpali_2.modeling_colpali_2 import ColPali2LossOutputs, ColPali2ModelOutput


Expand All @@ -30,17 +29,22 @@ def __init__(
self,
alpha: float = 0.5,
use_matryoshka_loss: bool = True,
matryoshka_dims: Optional[List[int]] = None,
matryoshka_weights: Optional[List[float]] = None,
use_distillation_loss: bool = True,
beta: float = 0.5,
temperature: float = 2.0,
):
super().__init__()
self.alpha = alpha

self.use_matryoshka_loss = use_matryoshka_loss
self.matryoshka_dims = matryoshka_dims
self.matryoshka_weights = matryoshka_weights

self.use_distillation_loss = use_distillation_loss
self.beta = beta
self.temperature = temperature
self.single_vector_loss_fn = MatryoshkaCELoss() if self.use_matryoshka_loss else CrossEntropyLoss()

def single_vector_loss(
self,
Expand All @@ -62,15 +66,46 @@ def single_vector_loss(
if query_embeddings.shape[0] != doc_embeddings.shape[0]:
raise ValueError("Batch size mismatch between query and document embeddings.")

scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings) # (batch_size, batch_size)
if query_embeddings.shape[1] != doc_embeddings.shape[1]:
raise ValueError("Dimensionality mismatch between query and document embeddings.")

batch_size = query_embeddings.shape[0]
device = query_embeddings.device

loss = cast(
torch.Tensor,
self.single_vector_loss_fn(
ce_loss_fn = CrossEntropyLoss()

if not self.use_matryoshka_loss:
scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings) # (batch_size, batch_size)
loss = ce_loss_fn.forward(
input=scores,
target=torch.arange(scores.shape[0], device=scores.device, dtype=torch.long),
),
) # (1,)
target=torch.arange(scores.shape[0], device=scores.device),
) # ()

else:
if not self.matryoshka_dims:
raise ValueError("Matryoshka dimensions must be provided when using Matryoshka loss.")

# The target is independent of the Matryoshka dimensionality
target = torch.arange(
query_embeddings.shape[0],
dtype=torch.long,
device=query_embeddings.device,
)

# Initialize the scores matrix and the loss
prev_scores = torch.zeros(batch_size, batch_size, device=device)
scores = torch.zeros(batch_size, batch_size, device=device)
loss = torch.tensor(0.0, device=device) # ()

# To efficiently compute the scores, we need the Matryoshka dimensions to be sorted.
matryoshka_dims = [0] + sorted(self.matryoshka_dims)

for prev_dim, dim in zip(matryoshka_dims, matryoshka_dims[1:]):
scores = prev_scores + torch.einsum(
"bd,cd->bc", query_embeddings[:, prev_dim:dim], doc_embeddings[:, prev_dim:dim]
) # (batch_size, batch_size)
loss += ce_loss_fn.forward(input=scores, target=target) # ()
prev_scores = scores

return ColPali2IntermediateLossOutputs(
loss=loss,
Expand Down Expand Up @@ -109,7 +144,7 @@ def multi_vector_loss(
neg_scores = neg_scores.max(dim=1)[0] # (batch_size,)

# Compute the margin loss
loss = F.softplus(neg_scores - pos_scores).mean() # (1,)
loss = F.softplus(neg_scores - pos_scores).mean().squeeze() # ()

return ColPali2IntermediateLossOutputs(
loss=loss,
Expand Down
75 changes: 0 additions & 75 deletions colpali_engine/loss/matryoshka_loss.py

This file was deleted.

5 changes: 3 additions & 2 deletions tests/loss/test_colpali_2_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def model_output(

@pytest.fixture(
params=[
# (False, False),
# (False, True),
(False, False),
(False, True),
(True, False),
(True, True),
]
Expand All @@ -42,6 +42,7 @@ def colpali_2_loss(request) -> ColPali2Loss:
return ColPali2Loss(
use_matryoshka_loss=use_matryoshka_loss,
use_distillation_loss=use_distillation_loss,
matryoshka_dims=[EMBEDDING_DIM, EMBEDDING_DIM // 2],
)


Expand Down

0 comments on commit cc3b31d

Please sign in to comment.