-
Notifications
You must be signed in to change notification settings - Fork 31
/
Copy pathPairConLoss.py
30 lines (22 loc) · 1.14 KB
/
PairConLoss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
from torch import nn
class PairConLoss(nn.Module):
    """Instance-wise pair contrastive loss over two views of the same batch.

    The two views are stacked into ``2 * B`` rows. For each row, the positive
    is the matching row of the other view (dot-product similarity). The
    negatives are the remaining ``2 * B - 2`` rows; the row itself and its
    positive partner are masked out. The loss is the mean over all rows of
    ``-log(exp(pos / T) / (exp(pos / T) + sum(exp(neg / T))))``.

    Args:
        temperature: Divisor applied to dot-product similarities before
            they are used as logits.
    """

    def __init__(self, temperature=0.05):
        super(PairConLoss, self).__init__()
        self.temperature = temperature
        # Kept for backward compatibility; forward() does not use it.
        self.eps = 1e-08

    def forward(self, features_1, features_2, device=None):
        """Compute the loss and some detached diagnostics.

        Args:
            features_1: 2-D tensor of shape (B, D) for the first view. It must
                be 2-D because it goes through ``torch.mm``. It is presumably
                L2-normalized by the caller (not enforced here; confirm).
            features_2: Tensor of the same shape as ``features_1``, holding
                the second view. Row i of each view forms a positive pair.
            device: Device for the self/positive mask. Defaults to
                ``features_1.device``. Kept as a parameter for backward
                compatibility.

        Returns:
            dict with:
              "loss": scalar tensor (with grad) averaged over the 2*B rows.
              "pos_mean": numpy scalar, mean of exp(pos / T).
              "neg_mean": numpy scalar, mean of exp(neg / T).
              "pos": numpy array of shape (2*B,), exp(pos / T) per row.
              "neg": numpy array of shape (2*B, 2*B - 2), exp(neg / T) per row.

        Raises:
            ValueError: if the batch is empty.
        """
        batch_size = features_1.shape[0]
        if batch_size < 1:
            raise ValueError("PairConLoss requires a non-empty batch")
        if device is None:
            device = features_1.device

        features = torch.cat([features_1, features_2], dim=0)

        # The eye tiled 2x2 marks i == j and the cross-view positives
        # (i == j +/- B). Inverting it leaves exactly the negatives.
        mask = ~torch.eye(batch_size, dtype=torch.bool, device=device).repeat(2, 2)

        # Positive logits for rows 0..B-1, repeated for rows B..2B-1.
        pos_logits = torch.sum(features_1 * features_2, dim=-1) / self.temperature
        pos_logits = torch.cat([pos_logits, pos_logits], dim=0)

        sim_logits = torch.mm(features, features.t()) / self.temperature
        # Use an explicit width so B == 1 (zero negatives) reshapes to (2, 0)
        # instead of failing on an ambiguous -1.
        neg_logits = sim_logits.masked_select(mask).view(
            2 * batch_size, 2 * batch_size - 2
        )

        # -log(pos / (pos + sum(neg))) == logsumexp([pos, neg...]) - pos.
        # This avoids exp() overflow/underflow for large or very negative logits.
        all_logits = torch.cat([pos_logits.unsqueeze(1), neg_logits], dim=1)
        loss = (torch.logsumexp(all_logits, dim=1) - pos_logits).mean()

        # Exponentiated values are only reported, never differentiated.
        with torch.no_grad():
            pos = torch.exp(pos_logits)
            neg = torch.exp(neg_logits)
            pos_mean = torch.mean(pos)
            neg_mean = torch.mean(neg)

        return {
            "loss": loss,
            "pos_mean": pos_mean.detach().cpu().numpy(),
            "neg_mean": neg_mean.detach().cpu().numpy(),
            "pos": pos.detach().cpu().numpy(),
            "neg": neg.detach().cpu().numpy(),
        }