-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsensitivity.py
230 lines (177 loc) · 7.69 KB
/
sensitivity.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
from typing import Callable
import torch
from nnunetv2.utilities.ddp_allgather import AllGatherGrad
from torch import nn
class SensitivityLoss(nn.Module):
    """Soft sensitivity (recall / true positive rate) loss.

    ``forward`` returns the negated mean of the smoothed ratio
    ``(tp + smooth) / (tp + fn + smooth)``, so minimising the loss maximises
    the sensitivity.
    """

    def __init__(self, apply_nonlin: Callable = None, batch_dice: bool = False, do_bg: bool = True, smooth: float = 1.,
                 ddp: bool = True, clip_tp: float = None):
        """
        :param apply_nonlin: optional callable applied to the network output before the statistics are computed
        :param batch_dice: if True the batch axis is reduced together with the trailing (spatial) axes
        :param do_bg: if False the first channel is dropped before averaging
        :param smooth: constant added to both numerator and denominator
        :param ddp: if True (and batch_dice is True) tp/fn are passed through AllGatherGrad and summed over axis 0
        :param clip_tp: optional lower bound applied to tp
        """
        super().__init__()

        self.apply_nonlin = apply_nonlin
        self.batch_dice = batch_dice
        self.do_bg = do_bg
        self.smooth = smooth
        self.ddp = ddp
        self.clip_tp = clip_tp

    def forward(self, x, y, loss_mask=None):
        """
        :param x: network output, channel axis at position 1
        :param y: target, forwarded unchanged to get_tp_fp_fn_tn
        :param loss_mask: optional mask, forwarded unchanged to get_tp_fp_fn_tn
        :return: scalar tensor holding the negative mean sensitivity
        """
        # Axes 2.. are always reduced; the batch axis only joins in batch mode.
        reduce_axes = list(range(2, len(x.shape)))
        if self.batch_dice:
            reduce_axes = [0] + reduce_axes

        probs = x if self.apply_nonlin is None else self.apply_nonlin(x)

        true_pos, _, false_neg, _ = get_tp_fp_fn_tn(probs, y, reduce_axes, loss_mask, False)

        if self.batch_dice and self.ddp:
            true_pos = AllGatherGrad.apply(true_pos).sum(0)
            false_neg = AllGatherGrad.apply(false_neg).sum(0)

        if self.clip_tp is not None:
            true_pos = torch.clip(true_pos, min=self.clip_tp, max=None)

        # TPR (true positive rate) = sensitivity = recall
        safe_denominator = torch.clip(true_pos + false_neg + self.smooth, 1e-8)
        sensitivity = (true_pos + self.smooth) / safe_denominator

        if not self.do_bg:
            # In batch mode the result is 1-D (channels only), otherwise (batch, channels).
            sensitivity = sensitivity[1:] if self.batch_dice else sensitivity[:, 1:]

        return -sensitivity.mean()
class DC_and_TPR_loss(nn.Module):
    """Combined soft Dice and sensitivity (TPR / recall) loss.

    ``forward`` returns ``1 - mean(tpr) - mean(dice)``, where both terms are
    computed from the same soft tp/fp/fn statistics.
    """

    def __init__(self, apply_nonlin: Callable = None, batch_dice: bool = False, do_bg: bool = True, smooth: float = 1.,
                 ddp: bool = True, clip_tp: float = None):
        """
        :param apply_nonlin: optional callable applied to the network output before the statistics are computed
        :param batch_dice: if True the batch axis is reduced together with the trailing (spatial) axes
        :param do_bg: if False the first channel is dropped before averaging
        :param smooth: constant added to both numerator and denominator of each ratio
        :param ddp: if True (and batch_dice is True) tp/fp/fn are passed through AllGatherGrad and summed over axis 0
        :param clip_tp: optional lower bound applied to tp
        """
        super().__init__()

        self.apply_nonlin = apply_nonlin
        self.batch_dice = batch_dice
        self.do_bg = do_bg
        self.smooth = smooth
        self.ddp = ddp
        self.clip_tp = clip_tp

    def forward(self, x, y, loss_mask=None):
        """
        :param x: network output, channel axis at position 1
        :param y: target, forwarded unchanged to get_tp_fp_fn_tn
        :param loss_mask: optional mask, forwarded unchanged to get_tp_fp_fn_tn
        :return: scalar tensor ``1 - mean(tpr) - mean(dice)``
        """
        # Axes 2.. are always reduced; the batch axis only joins in batch mode.
        reduce_axes = list(range(2, len(x.shape)))
        if self.batch_dice:
            reduce_axes = [0] + reduce_axes

        probs = x if self.apply_nonlin is None else self.apply_nonlin(x)

        true_pos, false_pos, false_neg, _ = get_tp_fp_fn_tn(probs, y, reduce_axes, loss_mask, False)

        if self.batch_dice and self.ddp:
            true_pos = AllGatherGrad.apply(true_pos).sum(0)
            false_pos = AllGatherGrad.apply(false_pos).sum(0)
            false_neg = AllGatherGrad.apply(false_neg).sum(0)

        if self.clip_tp is not None:
            true_pos = torch.clip(true_pos, min=self.clip_tp, max=None)

        # TPR (true positive rate) = sensitivity = recall
        recall = (true_pos + self.smooth) / (torch.clip(true_pos + false_neg + self.smooth, 1e-8))
        dice = (2 * true_pos + self.smooth) / (
            torch.clip(2 * true_pos + false_pos + false_neg + self.smooth, 1e-8)
        )

        if not self.do_bg:
            # In batch mode the results are 1-D (channels only), otherwise (batch, channels).
            if self.batch_dice:
                recall, dice = recall[1:], dice[1:]
            else:
                recall, dice = recall[:, 1:], dice[:, 1:]

        return 1 - recall.mean() - dice.mean()
class MemoryEfficientSoftDiceLossAndTPR(nn.Module):
    """Soft Dice loss with an additional, down-weighted sensitivity (TPR) term.

    ``forward`` returns ``-mean(dice) - tpr_weight * mean(tpr)``. The Dice
    denominator uses the sum of *squared* predictions (``x * x``).
    """

    def __init__(self, apply_nonlin: Callable = None, batch_dice: bool = False, do_bg: bool = True, smooth: float = 1.,
                 ddp: bool = True, tpr_weight: float = 0.05):
        """
        :param apply_nonlin: optional callable applied to the network output before the statistics are computed
        :param batch_dice: if True the per-sample statistics are summed over the batch axis before the ratios
        :param do_bg: if False the first channel of prediction and target is dropped
        :param smooth: constant added to both numerator and denominator of each ratio
        :param ddp: if True (and batch_dice is True) the statistics are passed through AllGatherGrad first
        :param tpr_weight: weight of the sensitivity term in the returned loss (default 0.05)
        """
        super().__init__()
        self.do_bg = do_bg
        self.batch_dice = batch_dice
        self.apply_nonlin = apply_nonlin
        self.smooth = smooth
        self.ddp = ddp
        self.tpr_weight = tpr_weight

    def forward(self, x, y, loss_mask=None):
        """
        :param x: network output of shape (b, c, ...)
        :param y: either a label map with shape (b, 1, ...) or (b, ...), or a target with exactly the shape of x
            (then used as-is as one hot encoding)
        :param loss_mask: optional mask multiplied into every summed quantity; must broadcast against x
        :return: scalar tensor ``-dice - tpr_weight * tpr``
        """
        # Remember the full shape: x may lose its background channel below, but
        # the one hot target is always built with all channels first.
        shp_x = x.shape

        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)

        if not self.do_bg:
            x = x[:, 1:]

        # make everything shape (b, c)
        axes = list(range(2, len(shp_x)))

        with torch.no_grad():
            if len(shp_x) != len(y.shape):
                y = y.view((y.shape[0], 1, *y.shape[1:]))

            # Compare the complete shapes *after* the reshape above. Zipping
            # against the pre-reshape shape silently truncates and can mistake
            # a (b, x, y) label map for a one hot encoding when c == x == y.
            if shp_x == y.shape:
                # if this is the case then gt is probably already a one hot encoding
                y_onehot = y
            else:
                y_onehot = torch.zeros(shp_x, device=x.device, dtype=torch.bool)
                y_onehot.scatter_(1, y.long(), 1)

            if not self.do_bg:
                y_onehot = y_onehot[:, 1:]

            sum_gt = y_onehot.sum(axes) if loss_mask is None else (y_onehot * loss_mask).sum(axes)

        # These must stay outside the no_grad context, otherwise no gradient
        # reaches x.
        if loss_mask is None:
            intersect = (x * y_onehot).sum(axes)
            sum_pred = (x * x).sum(axes)
        else:
            intersect = (x * y_onehot * loss_mask).sum(axes)
            sum_pred = (x * x * loss_mask).sum(axes)

        if self.batch_dice:
            if self.ddp:
                intersect = AllGatherGrad.apply(intersect).sum(0)
                sum_pred = AllGatherGrad.apply(sum_pred).sum(0)
                sum_gt = AllGatherGrad.apply(sum_gt).sum(0)

            intersect = intersect.sum(0)
            sum_pred = sum_pred.sum(0)
            sum_gt = sum_gt.sum(0)

        dc = (2 * intersect + self.smooth) / (torch.clip(sum_gt + sum_pred + self.smooth, 1e-8))
        # TPR (true positive rate) = sensitivity = recall
        tpr = (intersect + self.smooth) / (torch.clip(sum_gt + self.smooth, 1e-8))

        return -dc.mean() - self.tpr_weight * tpr.mean()
def get_tp_fp_fn_tn(net_output, gt, axes=None, mask=None, square=False):
    """
    Compute soft true positives, false positives, false negatives and true negatives.

    net_output must be (b, c, x, y(, z)))
    gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or one hot encoding (b, c, x, y(, z))
    if mask is provided it must have shape (b, 1, x, y(, z)))

    :param net_output: predictions; multiplied directly with the one hot target
    :param gt: label map or one hot encoding (see above). A boolean one hot encoding is cast to the dtype of
        net_output, because ``1 - bool_tensor`` is not supported by torch
    :param axes: axes to sum over; can be (, ) = no summation. Defaults to all axes from 2 onwards
    :param mask: mask must be 1 for valid pixels and 0 for invalid pixels
    :param square: if True then fp, tp and fn will be squared before summation
    :return: tuple (tp, fp, fn, tn)
    """
    if axes is None:
        axes = tuple(range(2, net_output.ndim))

    with torch.no_grad():
        if net_output.ndim != gt.ndim:
            gt = gt.view((gt.shape[0], 1, *gt.shape[1:]))

        if net_output.shape == gt.shape:
            # if this is the case then gt is probably already a one hot encoding
            y_onehot = gt
            if y_onehot.dtype == torch.bool:
                # Subtraction on bool tensors raises in torch; cast so that
                # "1 - y_onehot" below works for boolean targets as well.
                y_onehot = y_onehot.to(net_output.dtype)
        else:
            y_onehot = torch.zeros(net_output.shape, device=net_output.device)
            y_onehot.scatter_(1, gt.long(), 1)

    tp = net_output * y_onehot
    fp = net_output * (1 - y_onehot)
    fn = (1 - net_output) * y_onehot
    tn = (1 - net_output) * (1 - y_onehot)

    if mask is not None:
        with torch.no_grad():
            # Repeat the single-channel mask along the channel axis.
            # NOTE: per the original benchmark note, tiling was marginally faster and used less VRAM
            # than multiplying channel by channel via unbind/stack.
            mask_here = torch.tile(mask, (1, tp.shape[1]) + (1,) * (tp.ndim - 2))
        tp *= mask_here
        fp *= mask_here
        fn *= mask_here
        tn *= mask_here

    if square:
        tp = tp ** 2
        fp = fp ** 2
        fn = fn ** 2
        tn = tn ** 2

    if len(axes) > 0:
        tp = tp.sum(dim=axes, keepdim=False)
        fp = fp.sum(dim=axes, keepdim=False)
        fn = fn.sum(dim=axes, keepdim=False)
        tn = tn.sum(dim=axes, keepdim=False)

    return tp, fp, fn, tn