-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy patharcface.py
30 lines (24 loc) · 965 Bytes
/
arcface.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import math

import torch
import torch.nn as nn
from torch import linalg
from torch.nn import functional as F
class ArcFace(nn.Module):
    """Linear classification head with an additive angular margin (ArcFace).

    The logits are ``s * cos(theta_j)``, where ``theta_j`` is the angle between
    the input embedding and the j-th row of ``fc.weight``. When ``label`` is
    given, the logit of each sample's target class is replaced by
    ``s * cos(theta_y + m)``, i.e. the target class is penalised by an angular
    margin ``m``.

    Args:
        cin: Size of the input embedding (last dimension of ``x``).
        cout: Number of classes.
        s: Scale applied to the cosine logits (in both modes).
        m: Additive angular margin, in radians.
        eps: Cosines are clamped to ``[-1 + eps, 1 - eps]`` before the margin
            is applied, so ``sqrt(1 - cos**2)`` and its gradient stay finite.
    """

    def __init__(self, cin, cout, s=8, m=0.5, eps=1e-6):
        super().__init__()
        self.s = s
        self.m = m
        self.eps = eps
        # Python floats carry no device/dtype, so nothing needs moving when the
        # module is sent elsewhere (plain tensor attributes are not buffers and
        # would be left behind by `.to()` / `.half()`).
        self.sin_m = math.sin(m)
        self.cos_m = math.cos(m)
        self.cout = cout
        self.fc = nn.Linear(cin, cout, bias=False)

    def forward(self, x, label=None):
        """Return the scaled (and, with ``label``, margin-adjusted) cosine logits.

        Args:
            x: Embeddings of shape ``(..., cin)``; typically ``(batch, cin)``.
            label: Optional integer class indices of shape ``x.shape[:-1]``.
                When given, the target-class logit receives the angular margin.

        Returns:
            Tensor of shape ``(..., cout)``.
        """
        # Cosine similarity between each embedding and each class weight row.
        # F.normalize is eps-safe for all-zero vectors, and gradients flow
        # through the weight norm too (it is no longer detached).
        cos = F.linear(F.normalize(x, dim=-1), F.normalize(self.fc.weight, dim=-1))
        if label is not None:
            is_target = F.one_hot(label.long(), num_classes=self.cout).bool()
            # Rounding can push |cos| slightly past 1 (NaN in the sqrt), and the
            # sqrt has an infinite slope at 0, so keep cos strictly inside (-1, 1).
            cos_c = cos.clamp(-1 + self.eps, 1 - self.eps)
            sin = torch.sqrt(1 - cos_c ** 2)
            # cos(theta + m) = cos(theta) cos(m) - sin(theta) sin(m)
            cos_margin = cos_c * self.cos_m - sin * self.sin_m
            cos = torch.where(is_target, cos_margin, cos)
        return cos * self.s