# loss.py (forked from bui-thanh-lam/cps-segment)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchgeometry.losses import dice_loss
from utils import DEVICE, IGNORE_INDEX


class DiceCELoss(nn.Module):
    """Mix of Dice loss and cross-entropy loss."""
def __init__(self, dice_weight=1, reduction="mean", ignore_index=IGNORE_INDEX):
super().__init__()
self.dice_weight = dice_weight
self.reduction = reduction
self.ignore_index = ignore_index

    def forward(self, preds, targets):
        # ignored elements are marked with the value -1 (IGNORE_INDEX)
        ce = F.cross_entropy(preds, targets, ignore_index=self.ignore_index, reduction=self.reduction)
        # map the -1 labels to 0 for compatibility with dice_loss, which has no ignore_index
        targets = torch.div((targets + 1), 2, rounding_mode='floor')
        dice = dice_loss(preds, targets)
        return (ce + dice * self.dice_weight) / (1 + self.dice_weight)
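
# A minimal usage sketch for DiceCELoss (illustrative shapes and values, not
# taken from the original training code):
#   criterion = DiceCELoss(dice_weight=1)
#   preds = torch.randn(4, 2, 64, 64)            # bs * n_classes * h * w logits
#   targets = torch.randint(-1, 2, (4, 64, 64))  # -1 marks ignored pixels
#   loss = criterion(preds, targets)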


class CombinedCPSLoss(nn.Module):
    """Cross pseudo supervision (CPS) loss combined with supervised cross-entropy.

    On unlabeled data, each model is trained on hard pseudo labels derived
    from the other model(s), optionally through CutMix mixing, multiple
    teachers, or momentum (EMA) teacher predictions.
    """
def __init__(
self,
n_models=3,
trade_off_factor=1.5,
pseudo_label_confidence_threshold=0.7,
use_cutmix=False,
use_multiple_teachers=False,
use_momentum=False
):
super().__init__()
self.pseudo_label_confidence_threshold = pseudo_label_confidence_threshold
self.n_models = n_models
self.trade_off_factor = trade_off_factor
# self.loss = DiceCELoss()
self.loss = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
self.use_cutmix = use_cutmix
self.use_multiple_teachers = use_multiple_teachers
self.use_momentum = use_momentum
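
    # Flag summary (as inferred from the branches in forward below):
    #   use_cutmix            - compute the CPS term on CutMix-mixed unlabeled views
    #   use_multiple_teachers - pseudo labels come from the summed predictions
    #                           of all other models
    #   use_momentum          - pseudo labels come from momentum (EMA) teacher
    #                           predictions (the t_preds_* arguments)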
    def _multiple_teacher_correction(self, pseudo_labels):
        # pseudo_labels: bs * n_classes * h * w * n_models
        _sum = torch.sum(pseudo_labels, dim=-1)
        _pseudo_labels = torch.empty_like(pseudo_labels)
        # for each model i, aggregate the predictions of all other models
        for i in range(self.n_models):
            _pseudo_labels[:, :, :, :, i] = _sum - pseudo_labels[:, :, :, :, i]
        # same shape as the input, but entry i now holds the summed
        # predictions of every model except model i
        return _pseudo_labels
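
    # Note: the loop in _multiple_teacher_correction is equivalent to the
    # broadcasted one-liner below (a suggested simplification, not part of
    # the original code):
    #   _pseudo_labels = _sum.unsqueeze(-1) - pseudo_labels
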
    def _prune_pseudo_label_by_threshold(self, pseudo_labels):
        # pseudo_labels: logits of shape bs * n_classes * h * w * n_models
        _pseudo_labels = F.softmax(pseudo_labels, dim=1)
        _pseudo_labels = torch.max(_pseudo_labels, dim=1)[0]
        # keep a pixel only if its top-class confidence reaches the threshold
        _pseudo_labels = (_pseudo_labels >= self.pseudo_label_confidence_threshold).long()
        # confident pixels get their argmax class; all others become -1 (IGNORE_INDEX)
        _pseudo_labels = _pseudo_labels * (torch.argmax(pseudo_labels, dim=1) + 1) - 1
        # returns hard pseudo labels of shape bs * h * w * n_models
        return _pseudo_labels
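
    # Worked example (illustrative numbers): with threshold 0.7, a pixel whose
    # softmax over two classes is (0.1, 0.9) keeps its argmax class,
    # 1 * (1 + 1) - 1 = 1, while a pixel with softmax (0.4, 0.6) is pruned,
    # 0 * (1 + 1) - 1 = -1.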

    def forward(
self,
targets,
preds_L,
preds_U=None,
preds_U_1=None,
preds_U_2=None,
preds_m=None,
M=None,
t_preds_U=None,
t_preds_U_1=None,
t_preds_U_2=None,
):
        # preds_*: bs * n_classes * h * w * n_models
        # targets: bs * h * w (class indices, -1 for ignored pixels)
ce_loss = torch.zeros(1).to(DEVICE)
cps_loss = torch.zeros(1).to(DEVICE)
if self.use_cutmix:
            if preds_U_1 is None or preds_U_2 is None or preds_m is None or M is None:
                raise ValueError("preds_U_1, preds_U_2, preds_m and M must be provided when use_cutmix=True")
if self.use_multiple_teachers:
if self.use_momentum:
Y_1 = self._multiple_teacher_correction(t_preds_U_1)
Y_2 = self._multiple_teacher_correction(t_preds_U_2)
else:
Y_1 = self._multiple_teacher_correction(preds_U_1)
Y_2 = self._multiple_teacher_correction(preds_U_2)
for j in range(self.n_models):
P_m_j = preds_m[:, :, :, :, j]
Y_1_j = Y_1[:, :, :, :, j]
Y_2_j = Y_2[:, :, :, :, j]
# disable gradient passing
with torch.no_grad():
                        # M is the binary CutMix mask: mix pseudo labels from the
                        # two views (view 1 outside the mask, view 2 inside)
                        ones = torch.ones_like(M)
# if threshold <= 0.5, don't use threshold clipping
if self.pseudo_label_confidence_threshold <= 0.5:
tmp = Y_1_j * (ones - M) + Y_2_j * M
Y_j = torch.argmax(tmp, dim=1)
                            # otherwise, keep only pseudo labels with high confidence
else:
tmp = Y_1_j * (ones - M) + Y_2_j * M
Y_j = self._prune_pseudo_label_by_threshold(tmp)
cps_loss += self.loss(P_m_j, Y_j)
else:
for r in range(self.n_models):
for l in range(r):
if self.use_momentum:
P_U_l_1 = t_preds_U_1[:, :, :, :, l]
P_U_r_1 = t_preds_U_1[:, :, :, :, r]
P_U_l_2 = t_preds_U_2[:, :, :, :, l]
P_U_r_2 = t_preds_U_2[:, :, :, :, r]
else:
P_U_l_1 = preds_U_1[:, :, :, :, l]
P_U_r_1 = preds_U_1[:, :, :, :, r]
P_U_l_2 = preds_U_2[:, :, :, :, l]
P_U_r_2 = preds_U_2[:, :, :, :, r]
P_m_l = preds_m[:, :, :, :, l]
P_m_r = preds_m[:, :, :, :, r]
# compute cps loss, disable gradient passing
with torch.no_grad():
ones = torch.ones_like(M)
# if threshold <= 0.5, don't use threshold clipping
if self.pseudo_label_confidence_threshold <= 0.5:
tmp = P_U_l_1 * (ones - M) + P_U_l_2 * M
Y_l = torch.argmax(tmp, dim=1)
                                tmp = P_U_r_1 * (ones - M) + P_U_r_2 * M
                                Y_r = torch.argmax(tmp, dim=1)
                            # otherwise, keep only pseudo labels with high confidence
else:
tmp = P_U_l_1 * (ones - M) + P_U_l_2 * M
Y_l = self._prune_pseudo_label_by_threshold(tmp)
tmp = P_U_r_1 * (ones - M) + P_U_r_2 * M
Y_r = self._prune_pseudo_label_by_threshold(tmp)
                        # cross supervision: model l learns from model r's pseudo
                        # labels and vice versa
                        cps_loss += self.loss(P_m_l, Y_r) + self.loss(P_m_r, Y_l)
else:
if preds_L is None or preds_U is None:
raise ValueError("preds_U and preds_L must be provided when use_cutmix=False")
if self.use_multiple_teachers:
if self.use_momentum:
Y_U = self._multiple_teacher_correction(t_preds_U)
else:
Y_U = self._multiple_teacher_correction(preds_U)
for j in range(self.n_models):
P_U_j = preds_U[:, :, :, :, j]
# disable gradient passing
with torch.no_grad():
Y_U_j = Y_U[:, :, :, :, j]
if self.pseudo_label_confidence_threshold <= 0.5:
Y_U_j = torch.argmax(Y_U_j, dim=1)
else:
Y_U_j = self._prune_pseudo_label_by_threshold(Y_U_j)
cps_loss += self.loss(P_U_j, Y_U_j)
else:
for r in range(self.n_models):
for l in range(r):
P_U_l = preds_U[:, :, :, :, l]
P_U_r = preds_U[:, :, :, :, r]
if self.use_momentum:
Y_U_l = t_preds_U[:, :, :, :, l]
Y_U_r = t_preds_U[:, :, :, :, r]
# compute cps loss, disable gradient passing
with torch.no_grad():
# if threshold <= 0.5, don't use threshold clipping
if self.pseudo_label_confidence_threshold <= 0.5:
if self.use_momentum:
Y_U_l = torch.argmax(Y_U_l, dim=1)
Y_U_r = torch.argmax(Y_U_r, dim=1)
else:
Y_U_l = torch.argmax(P_U_l, dim=1)
Y_U_r = torch.argmax(P_U_r, dim=1)
                            # otherwise, keep only pseudo labels with high confidence
else:
if self.use_momentum:
Y_U_l = self._prune_pseudo_label_by_threshold(Y_U_l)
Y_U_r = self._prune_pseudo_label_by_threshold(Y_U_r)
else:
Y_U_l = self._prune_pseudo_label_by_threshold(P_U_l)
Y_U_r = self._prune_pseudo_label_by_threshold(P_U_r)
cps_loss += self.loss(P_U_l, Y_U_r) + self.loss(P_U_r, Y_U_l)
        # if every pseudo label was pruned to the ignore index, cross-entropy
        # returns NaN; reset the CPS term to zero in that case
        if torch.isnan(cps_loss):
            cps_loss.zero_()
        # compute the supervised loss on labeled data
        for j in range(self.n_models):
            P_ce_j = preds_L[:, :, :, :, j]
            ce_loss += self.loss(P_ce_j, targets)
        # combine the two losses, rescaling the CPS term by the trade-off factor
        combined_loss = (ce_loss + self.trade_off_factor / (self.n_models - 1) * cps_loss) / self.n_models
        return combined_loss
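

if __name__ == "__main__":
    # Smoke test: a minimal sketch with illustrative shapes, assuming a
    # 2-class problem, 3 models, and default settings; the real shapes come
    # from the training pipeline, which is not part of this file.
    bs, n_classes, h, w, n_models = 2, 2, 16, 16, 3
    criterion = CombinedCPSLoss(n_models=n_models, use_cutmix=False)
    targets = torch.randint(0, n_classes, (bs, h, w), device=DEVICE)
    preds_L = torch.randn(bs, n_classes, h, w, n_models, device=DEVICE)
    preds_U = torch.randn(bs, n_classes, h, w, n_models, device=DEVICE)
    loss = criterion(targets, preds_L, preds_U=preds_U)
    print("combined loss:", loss.item())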