Commit abf0a9f (1 parent: 81d562e)
Showing 1 changed file with 389 additions and 0 deletions.
@@ -0,0 +1,389 @@
""" | ||
Created on April 8, 2022. | ||
generalizeddice.py | ||
@ref: https://github.com/wolny/pytorch-3dunet | ||
different types of Dice loss (multi label and multi class) | ||
""" | ||
import pdb | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
from torch import nn as nn | ||
from torch.autograd import Variable | ||
from torch.nn import MSELoss, SmoothL1Loss, L1Loss | ||
|
||
|
||
def compute_per_channel_dice(input, target, epsilon=1e-6, weight=None):
    """
    Computes DiceCoefficient as defined in https://arxiv.org/abs/1606.04797 given a multi-channel input and target.
    Assumes the input is a normalized probability, e.g. a result of the Sigmoid or Softmax function.

    Args:
        input (torch.Tensor): NxCxSpatial input tensor
        target (torch.Tensor): NxCxSpatial target tensor
        epsilon (float): prevents division by zero
        weight (torch.Tensor): Cx1 tensor of weight per channel/class
    """
    # input and target shapes must match
    assert input.size() == target.size(), "'input' and 'target' must have the same shape"

    input = flatten(input)
    target = flatten(target)
    target = target.float()

    # compute per channel Dice Coefficient
    intersect = (input * target).sum(-1)
    if weight is not None:
        intersect = weight * intersect

    # here we can use standard dice (input + target).sum(-1) or extension (see V-Net) (input^2 + target^2).sum(-1)
    denominator = (input * input).sum(-1) + (target * target).sum(-1)
    return 2 * (intersect / denominator.clamp(min=epsilon))
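

# Hypothetical usage sketch (not part of the original commit): computing the
# per-channel Dice coefficient for a random NxCxDxHxW prediction/target pair.
# All shapes and values below are illustrative assumptions only.
def _demo_compute_per_channel_dice():
    probs = torch.sigmoid(torch.randn(2, 3, 8, 8, 8))    # normalized probabilities
    target = (torch.rand(2, 3, 8, 8, 8) > 0.5).float()   # binary ground truth
    dice = compute_per_channel_dice(probs, target)       # one score per channel
    assert dice.shape == (3,)
    return dice
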
class _MaskingLossWrapper(nn.Module):
    """
    Loss wrapper which prevents the gradient of the loss from being computed where target is equal to `ignore_index`.
    """

    def __init__(self, loss, ignore_index):
        super(_MaskingLossWrapper, self).__init__()
        assert ignore_index is not None, 'ignore_index cannot be None'
        self.loss = loss
        self.ignore_index = ignore_index

    def forward(self, input, target):
        mask = target.clone().ne_(self.ignore_index)
        mask.requires_grad = False

        # mask out input/target so that the gradient is zero where the mask is zero (i.e. at ignore_index locations)
        input = input * mask
        target = target * mask

        # forward masked input and target to the loss
        return self.loss(input, target)
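

# Hypothetical usage sketch (not part of the original commit): wrapping MSELoss
# so that voxels labelled with ignore_index (-1 here, an assumed value) do not
# contribute to the loss or its gradient.
def _demo_masking_loss_wrapper():
    loss_fn = _MaskingLossWrapper(nn.MSELoss(), ignore_index=-1)
    input = torch.randn(2, 1, 4, 4, 4, requires_grad=True)
    target = torch.randint(-1, 2, (2, 1, 4, 4, 4)).float()  # labels in {-1, 0, 1}
    return loss_fn(input, target)
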
class SkipLastTargetChannelWrapper(nn.Module):
    """
    Loss wrapper which removes an additional (trailing) target channel before applying the loss.
    """

    def __init__(self, loss, squeeze_channel=False):
        super(SkipLastTargetChannelWrapper, self).__init__()
        self.loss = loss
        self.squeeze_channel = squeeze_channel

    def forward(self, input, target):
        assert target.size(1) > 1, 'Target tensor has a singleton channel dimension, cannot remove channel'

        # skip the last target channel
        target = target[:, :-1, ...]

        if self.squeeze_channel:
            # squeeze channel dimension if singleton
            target = torch.squeeze(target, dim=1)
        return self.loss(input, target)
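

# Hypothetical usage sketch (not part of the original commit): the target carries
# an extra trailing channel (e.g. a weight map) that the wrapped loss should ignore.
def _demo_skip_last_target_channel():
    loss_fn = SkipLastTargetChannelWrapper(nn.MSELoss())
    input = torch.randn(2, 1, 4, 4, 4)
    target = torch.randn(2, 2, 4, 4, 4)  # second channel is dropped by the wrapper
    return loss_fn(input, target)
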
class _AbstractDiceLoss(nn.Module):
    """
    Base class for different implementations of Dice loss.
    """

    def __init__(self, weight=None, normalization='sigmoid'):
        super(_AbstractDiceLoss, self).__init__()
        self.register_buffer('weight', weight)
        # The output from the network during training is assumed to be un-normalized probabilities and we would
        # like to normalize the logits. Since Dice (or soft Dice in this case) is usually used for binary data,
        # normalizing the channels with Sigmoid is the default choice even for multi-class segmentation problems.
        # However, if one would like to apply Softmax in order to get a proper probability distribution over the
        # channels, just specify `normalization='softmax'`.
        assert normalization in ['sigmoid', 'softmax', 'none']
        if normalization == 'sigmoid':
            self.normalization = nn.Sigmoid()
        elif normalization == 'softmax':
            self.normalization = nn.Softmax(dim=1)
        else:
            self.normalization = lambda x: x

    def dice(self, input, target, weight):
        # actual Dice score computation; to be implemented by the subclass
        raise NotImplementedError

    def forward(self, input, target):
        # get probabilities from logits
        input = self.normalization(input)

        # compute per channel Dice coefficient
        per_channel_dice = self.dice(input, target, weight=self.weight)

        # average Dice score across all channels/classes
        return 1. - torch.mean(per_channel_dice)
class DiceLoss(_AbstractDiceLoss):
    """Computes Dice Loss according to https://arxiv.org/abs/1606.04797.
    For multi-class segmentation the `weight` parameter can be used to assign different weights per class.
    The input to the loss function is assumed to be a logit and will be normalized by the Sigmoid function.
    """

    def __init__(self, weight=None, normalization='sigmoid'):
        super().__init__(weight, normalization)

    def dice(self, input, target, weight):
        # use the `weight` argument passed in by the base class (which forwards self.weight)
        return compute_per_channel_dice(input, target, weight=weight)
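

# Hypothetical usage sketch (not part of the original commit): DiceLoss applied
# to raw logits; sigmoid normalization happens internally in _AbstractDiceLoss.
def _demo_dice_loss():
    criterion = DiceLoss()
    logits = torch.randn(2, 3, 8, 8, 8, requires_grad=True)  # un-normalized scores
    target = (torch.rand(2, 3, 8, 8, 8) > 0.5).float()
    loss = criterion(logits, target)  # scalar: 1 - mean per-channel Dice
    loss.backward()
    return loss
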
class GeneralizedDiceLoss(_AbstractDiceLoss):
    """Computes Generalized Dice Loss (GDL) as described in https://arxiv.org/pdf/1707.03237.pdf.
    """

    def __init__(self, normalization='sigmoid', epsilon=1e-6):
        super().__init__(weight=None, normalization=normalization)
        self.epsilon = epsilon

    def dice(self, input, target, weight):
        assert input.size() == target.size(), "'input' and 'target' must have the same shape"

        input = flatten(input)
        target = flatten(target)
        target = target.float()

        if input.size(0) == 1:
            # for GDL to make sense we need at least 2 channels (see https://arxiv.org/pdf/1707.03237.pdf)
            # put foreground and background voxels in separate channels
            input = torch.cat((input, 1 - input), dim=0)
            target = torch.cat((target, 1 - target), dim=0)

        # GDL weighting: the contribution of each label is corrected by the inverse of its volume
        w_l = target.sum(-1)
        w_l = 1 / (w_l * w_l).clamp(min=self.epsilon)
        w_l.requires_grad = False

        intersect = (input * target).sum(-1)
        intersect = intersect * w_l

        denominator = (input + target).sum(-1)
        denominator = (denominator * w_l).clamp(min=self.epsilon)

        return 2 * (intersect.sum() / denominator.sum())
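

# Hypothetical usage sketch (not part of the original commit): GDL on a
# single-channel problem; the loss internally adds a background channel so the
# inverse-volume weighting is well defined.
def _demo_generalized_dice_loss():
    criterion = GeneralizedDiceLoss(normalization='sigmoid')
    logits = torch.randn(2, 1, 8, 8, 8)
    target = (torch.rand(2, 1, 8, 8, 8) > 0.5).float()
    return criterion(logits, target)  # scalar loss
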
class BCEDiceLoss(nn.Module):
    """Linear combination of BCE and Dice losses."""

    def __init__(self, alpha, beta):
        super(BCEDiceLoss, self).__init__()
        self.alpha = alpha
        self.bce = nn.BCEWithLogitsLoss()
        self.beta = beta
        self.dice = DiceLoss()

    def forward(self, input, target):
        return self.alpha * self.bce(input, target) + self.beta * self.dice(input, target)
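

# Hypothetical usage sketch (not part of the original commit): an equal-weight
# combination of BCEWithLogitsLoss and DiceLoss; the alpha/beta values are assumed.
def _demo_bce_dice_loss():
    criterion = BCEDiceLoss(alpha=1.0, beta=1.0)
    logits = torch.randn(2, 1, 8, 8, 8)
    target = (torch.rand(2, 1, 8, 8, 8) > 0.5).float()
    return criterion(logits, target)
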
class WeightedCrossEntropyLoss(nn.Module):
    """WeightedCrossEntropyLoss (WCE) as described in https://arxiv.org/pdf/1707.03237.pdf
    """

    def __init__(self, ignore_index=-1):
        super(WeightedCrossEntropyLoss, self).__init__()
        self.ignore_index = ignore_index

    def forward(self, input, target):
        weight = self._class_weights(input)
        return F.cross_entropy(input, target, weight=weight, ignore_index=self.ignore_index)

    @staticmethod
    def _class_weights(input):
        # normalize the input first
        input = F.softmax(input, dim=1)
        flattened = flatten(input)
        numerator = (1. - flattened).sum(-1)
        denominator = flattened.sum(-1)
        # the class weights are constants, so detach them from the graph
        # (replaces the deprecated `Variable(..., requires_grad=False)`)
        class_weights = (numerator / denominator).detach()
        return class_weights
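

# Hypothetical usage sketch (not part of the original commit): WCE expects raw
# logits of shape NxCxSPATIAL and an integer label map of shape NxSPATIAL.
def _demo_weighted_cross_entropy():
    criterion = WeightedCrossEntropyLoss()
    logits = torch.randn(2, 3, 8, 8, 8)
    target = torch.randint(0, 3, (2, 8, 8, 8))
    return criterion(logits, target)
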
class PixelWiseCrossEntropyLoss(nn.Module):
    def __init__(self, class_weights=None, ignore_index=None):
        super(PixelWiseCrossEntropyLoss, self).__init__()
        self.register_buffer('class_weights', class_weights)
        self.ignore_index = ignore_index
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, target, weights):
        assert target.size() == weights.size()
        # normalize the input
        log_probabilities = self.log_softmax(input)
        # standard CrossEntropyLoss requires the target to be (NxDxHxW), so we need to expand it to (NxCxDxHxW)
        target = expand_as_one_hot(target, C=input.size()[1], ignore_index=self.ignore_index)
        # expand weights over the channel dimension: (NxDxHxW) -> (Nx1xDxHxW) -> (NxCxDxHxW)
        weights = weights.unsqueeze(1)
        weights = weights.expand_as(input)

        # create default class_weights if None
        if self.class_weights is None:
            class_weights = torch.ones(input.size()[1]).float().to(input.device)
        else:
            class_weights = self.class_weights

        # resize class_weights to be broadcastable into the weights
        class_weights = class_weights.view(1, -1, 1, 1, 1)

        # multiply weights tensor by class weights
        weights = class_weights * weights

        # compute the losses
        result = -weights * target * log_probabilities
        # average the losses
        return result.mean()
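

# Hypothetical usage sketch (not part of the original commit): a per-voxel weight
# map with the same shape as the label map; class weights default to ones.
def _demo_pixel_wise_cross_entropy():
    criterion = PixelWiseCrossEntropyLoss()
    logits = torch.randn(2, 3, 8, 8, 8)
    target = torch.randint(0, 3, (2, 8, 8, 8))
    weights = torch.ones(2, 8, 8, 8)  # uniform per-voxel weights
    return criterion(logits, target, weights)
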
class WeightedSmoothL1Loss(nn.SmoothL1Loss):
    def __init__(self, threshold, initial_weight, apply_below_threshold=True):
        super().__init__(reduction="none")
        self.threshold = threshold
        self.apply_below_threshold = apply_below_threshold
        self.weight = initial_weight

    def forward(self, input, target):
        l1 = super().forward(input, target)

        if self.apply_below_threshold:
            mask = target < self.threshold
        else:
            mask = target >= self.threshold

        l1[mask] = l1[mask] * self.weight

        return l1.mean()
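

# Hypothetical usage sketch (not part of the original commit): regression targets
# below the (assumed) threshold 0.5 have their elementwise loss scaled by initial_weight.
def _demo_weighted_smooth_l1():
    criterion = WeightedSmoothL1Loss(threshold=0.5, initial_weight=2.0)
    input = torch.randn(16)
    target = torch.rand(16)
    return criterion(input, target)
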
def flatten(tensor):
    """Flattens a given tensor such that the channel axis is first.
    The shapes are transformed as follows:
       (N, C, D, H, W) -> (C, N * D * H * W)
    """
    # number of channels
    C = tensor.shape[1]
    # new axis order
    axis_order = (1, 0) + tuple(range(2, tensor.dim()))
    # Transpose: (N, C, D, H, W) -> (C, N, D, H, W)
    transposed = tensor.permute(axis_order)
    # Flatten: (C, N, D, H, W) -> (C, N * D * H * W)
    return transposed.contiguous().view(C, -1)
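

# Hypothetical usage sketch (not part of the original commit): the channel axis
# moves to the front and all remaining axes are merged into one.
def _demo_flatten():
    t = torch.randn(2, 3, 4, 5, 6)                  # (N, C, D, H, W)
    assert flatten(t).shape == (3, 2 * 4 * 5 * 6)   # (C, N*D*H*W)
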
def get_loss_criterion(config):
    """
    Returns the loss function based on the provided configuration.
    :param config: (dict) a top-level configuration object containing the 'loss' key
    :return: an instance of the loss function
    """
    assert 'loss' in config, 'Could not find loss function configuration'
    loss_config = config['loss']
    name = loss_config.pop('name')

    ignore_index = loss_config.pop('ignore_index', None)
    skip_last_target = loss_config.pop('skip_last_target', False)
    weight = loss_config.pop('weight', None)

    if weight is not None:
        # convert to cuda tensor if necessary
        weight = torch.tensor(weight).to(config['device'])

    pos_weight = loss_config.pop('pos_weight', None)
    if pos_weight is not None:
        # convert to cuda tensor if necessary
        pos_weight = torch.tensor(pos_weight).to(config['device'])

    loss = _create_loss(name, loss_config, weight, ignore_index, pos_weight)

    if not (ignore_index is None or name in ['CrossEntropyLoss', 'WeightedCrossEntropyLoss']):
        # use MaskingLossWrapper only for non-cross-entropy losses, since CE losses allow specifying 'ignore_index' directly
        loss = _MaskingLossWrapper(loss, ignore_index)

    if skip_last_target:
        loss = SkipLastTargetChannelWrapper(loss, loss_config.get('squeeze_channel', False))

    return loss
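

# Hypothetical usage sketch (not part of the original commit): a minimal config
# dict; the 'device' key is only consulted when 'weight'/'pos_weight' are given.
def _demo_get_loss_criterion():
    config = {'loss': {'name': 'DiceLoss', 'normalization': 'sigmoid'}}
    criterion = get_loss_criterion(config)
    return criterion  # a DiceLoss instance, no wrappers applied
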
def expand_as_one_hot(input, C, ignore_index=None):
    """
    Converts an NxSPATIAL label image to NxCxSPATIAL, where each label gets converted to its corresponding one-hot vector.
    It is assumed that the batch dimension is present and that each sample is a 3D volume, i.e. the input is NxDxHxW.

    Args:
        input (torch.Tensor): 4D input label image (NxDxHxW)
        C (int): number of channels/labels
        ignore_index (int): ignore index to be kept during the expansion
    Returns:
        5D output torch.Tensor (NxCxDxHxW)
    """
    assert input.dim() == 4

    # expand the input tensor to Nx1xSPATIAL before scattering
    input = input.unsqueeze(1)
    # create output tensor shape (NxCxSPATIAL)
    shape = list(input.size())
    shape[1] = C

    if ignore_index is not None:
        # create ignore_index mask for the result
        mask = input.expand(shape) == ignore_index
        # clone the src tensor and zero out ignore_index in the input
        input = input.clone()
        input[input == ignore_index] = 0
        # scatter to get the one-hot tensor
        result = torch.zeros(shape).to(input.device).scatter_(1, input, 1)
        # bring back the ignore_index in the result
        result[mask] = ignore_index
        return result
    else:
        # scatter to get the one-hot tensor
        return torch.zeros(shape).to(input.device).scatter_(1, input, 1)
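

# Hypothetical usage sketch (not part of the original commit): a 4D label volume
# (NxDxHxW, integer labels) expanded to a 5D one-hot tensor (NxCxDxHxW).
def _demo_expand_as_one_hot():
    labels = torch.randint(0, 3, (2, 4, 4, 4))
    one_hot = expand_as_one_hot(labels, C=3)
    assert one_hot.shape == (2, 3, 4, 4, 4)
    return one_hot
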
#######################################################################################################################

def _create_loss(name, loss_config, weight, ignore_index, pos_weight):
    if name == 'BCEWithLogitsLoss':
        return nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    elif name == 'BCEDiceLoss':
        alpha = loss_config.get('alpha', 1.)
        beta = loss_config.get('beta', 1.)
        return BCEDiceLoss(alpha, beta)
    elif name == 'CrossEntropyLoss':
        if ignore_index is None:
            ignore_index = -100  # use the default 'ignore_index' as defined in the CrossEntropyLoss
        return nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index)
    elif name == 'WeightedCrossEntropyLoss':
        if ignore_index is None:
            ignore_index = -100  # use the default 'ignore_index' as defined in the CrossEntropyLoss
        return WeightedCrossEntropyLoss(ignore_index=ignore_index)
    elif name == 'PixelWiseCrossEntropyLoss':
        return PixelWiseCrossEntropyLoss(class_weights=weight, ignore_index=ignore_index)
    elif name == 'GeneralizedDiceLoss':
        normalization = loss_config.get('normalization', 'sigmoid')
        return GeneralizedDiceLoss(normalization=normalization)
    elif name == 'DiceLoss':
        normalization = loss_config.get('normalization', 'sigmoid')
        return DiceLoss(weight=weight, normalization=normalization)
    elif name == 'MSELoss':
        return MSELoss()
    elif name == 'SmoothL1Loss':
        return SmoothL1Loss()
    elif name == 'L1Loss':
        return L1Loss()
    elif name == 'WeightedSmoothL1Loss':
        return WeightedSmoothL1Loss(threshold=loss_config['threshold'],
                                    initial_weight=loss_config['initial_weight'],
                                    apply_below_threshold=loss_config.get('apply_below_threshold', True))
    else:
        raise RuntimeError(f"Unsupported loss function: '{name}'")