PolyLoss.py
#################################################
### THIS FILE WAS AUTOGENERATED! DO NOT EDIT! ###
#################################################
import warnings
from typing import Optional
import torch
import torch.nn as nn
from torch.nn.modules.loss import _Loss
def to_one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.float, dim: int = 1) -> torch.Tensor:
    # if `dim` is bigger, add singleton dims at the end
    if labels.ndim < dim + 1:
        shape = list(labels.shape) + [1] * (dim + 1 - len(labels.shape))
        labels = torch.reshape(labels, shape)

    sh = list(labels.shape)
    if sh[dim] != 1:
        raise AssertionError("labels should have a channel with length equal to one.")

    sh[dim] = num_classes
    o = torch.zeros(size=sh, dtype=dtype, device=labels.device)
    labels = o.scatter_(dim=dim, index=labels.long(), value=1)
    return labels
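
# A minimal illustrative check of `to_one_hot` (not part of the original module,
# just a sketch of the expected behaviour): class-index labels with a singleton
# channel at `dim` are expanded to one-hot vectors of length `num_classes`.
if __name__ == "__main__":
    _demo_labels = torch.tensor([[1], [0]])         # shape (B=2, 1), class indices
    print(to_one_hot(_demo_labels, num_classes=3))  # tensor([[0., 1., 0.], [1., 0., 0.]])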

class PolyLoss(_Loss):
    def __init__(self,
                 softmax: bool = True,
                 ce_weight: Optional[torch.Tensor] = None,
                 reduction: str = 'mean',
                 epsilon: float = 1.0,
                 ) -> None:
        super().__init__()
        self.softmax = softmax
        self.reduction = reduction
        self.epsilon = epsilon
        self.cross_entropy = nn.CrossEntropyLoss(weight=ce_weight, reduction='none')
    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input: the shape should be BNH[WD], where N is the number of classes.
                You can pass either logits or probabilities as input; if passing logits, ``softmax`` must be set to True.
            target: if target is in one-hot format, its shape should be BNH[WD],
                where N is the number of classes; if it is not one-hot encoded, it should have
                shape B1H[WD] or BH[WD] and contain class indices.
        Raises:
            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
        """
        if len(input.shape) - len(target.shape) == 1:
            target = target.unsqueeze(1).long()
        n_pred_ch, n_target_ch = input.shape[1], target.shape[1]
        if n_pred_ch != n_target_ch:
            # target is not one-hot encoded and has shape B1H[WD]:
            # squeeze out the channel dimension of size 1 to compute the CE loss
            self.ce_loss = self.cross_entropy(input, torch.squeeze(target, dim=1).long())
            # convert target into one-hot format to compute the poly term
            target = to_one_hot(target, num_classes=n_pred_ch)
        else:
            # target is in one-hot format, convert to BH[WD] class indices to compute the CE loss
            self.ce_loss = self.cross_entropy(input, torch.argmax(target, dim=1))

        if self.softmax:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `softmax=True` ignored.")
            else:
                input = torch.softmax(input, 1)

        pt = (input * target).sum(dim=1)  # BH[WD]
        poly_loss = self.ce_loss + self.epsilon * (1 - pt)

        if self.reduction == 'mean':
            polyl = torch.mean(poly_loss)  # the batch and channel average
        elif self.reduction == 'sum':
            polyl = torch.sum(poly_loss)  # sum over the batch and channel dims
        elif self.reduction == 'none':
            polyl = poly_loss  # BH[WD]
        else:
            raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
        return polyl
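
# A minimal usage sketch for PolyLoss (illustrative only; the shapes and values
# below are assumptions, not part of the original module): logits of shape
# BNH[W] with integer class-index targets of shape BH[W].
if __name__ == "__main__":
    _logits = torch.randn(2, 3, 4, 4)         # B=2, N=3 class logits, 4x4 spatial
    _labels = torch.randint(0, 3, (2, 4, 4))  # class indices per pixel
    _poly_ce = PolyLoss(softmax=True, epsilon=1.0, reduction='mean')
    print("PolyLoss:", _poly_ce(_logits, _labels).item())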

class PolyBCELoss(_Loss):
    def __init__(self,
                 reduction: str = 'mean',
                 epsilon: float = 1.0,
                 ) -> None:
        super().__init__()
        self.reduction = reduction
        self.epsilon = epsilon
        self.bce = nn.BCEWithLogitsLoss(reduction='none')
    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input: (*), where * means any number of dimensions; values are expected to be logits.
            target: same shape as the input, containing binary values.
        Raises:
            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
        """
        # element-wise BCE on the logits
        self.bce_loss = self.bce(input, target)
        # pt is the predicted probability of the ground-truth class
        pt = torch.sigmoid(input)
        pt = torch.where(target == 1, pt, 1 - pt)
        poly_loss = self.bce_loss + self.epsilon * (1 - pt)

        if self.reduction == 'mean':
            polyl = torch.mean(poly_loss)  # the batch and channel average
        elif self.reduction == 'sum':
            polyl = torch.sum(poly_loss)  # sum over the batch and channel dims
        elif self.reduction == 'none':
            polyl = poly_loss  # same shape as the input
        else:
            raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
        return polyl
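
# A minimal usage sketch for PolyBCELoss (illustrative only; shapes and values
# are assumptions): logits and binary targets of the same shape.
if __name__ == "__main__":
    _bce_logits = torch.randn(2, 1, 4, 4)
    _bce_targets = torch.randint(0, 2, (2, 1, 4, 4)).float()
    _poly_bce = PolyBCELoss(epsilon=1.0, reduction='mean')
    print("PolyBCELoss:", _poly_bce(_bce_logits, _bce_targets).item())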