diff --git a/models/generalizeddice.py b/models/generalizeddice.py
new file mode 100644
index 0000000..617a991
--- /dev/null
+++ b/models/generalizeddice.py
@@ -0,0 +1,389 @@
+"""
+Created on April 8, 2022.
+generalizeddice.py
+
+@ref: https://github.com/wolny/pytorch-3dunet
+
+Different types of Dice loss (multi-label and multi-class).
+"""
+import pdb
+
+import torch
+import torch.nn.functional as F
+from torch import nn as nn
+from torch.autograd import Variable
+from torch.nn import MSELoss, SmoothL1Loss, L1Loss
+
+
+def compute_per_channel_dice(input, target, epsilon=1e-6, weight=None):
+    """
+    Computes the Dice coefficient as defined in https://arxiv.org/abs/1606.04797 given a multi-channel input and target.
+    Assumes the input is a normalized probability, e.g. the result of a Sigmoid or Softmax function.
+
+    Args:
+        input (torch.Tensor): NxCxSpatial input tensor
+        target (torch.Tensor): NxCxSpatial target tensor
+        epsilon (float): prevents division by zero
+        weight (torch.Tensor): Cx1 tensor of weight per channel/class
+    """
+
+    # input and target shapes must match
+    assert input.size() == target.size(), "'input' and 'target' must have the same shape"
+
+    input = flatten(input)
+    target = flatten(target)
+    target = target.float()
+
+    # compute per channel Dice coefficient
+    intersect = (input * target).sum(-1)
+    if weight is not None:
+        intersect = weight * intersect
+
+    # here we can use standard dice (input + target).sum(-1) or the extension (see V-Net) (input^2 + target^2).sum(-1)
+    denominator = (input * input).sum(-1) + (target * target).sum(-1)
+    return 2 * (intersect / denominator.clamp(min=epsilon))
+
+
+class _MaskingLossWrapper(nn.Module):
+    """
+    Loss wrapper which prevents the gradient of the loss from being computed where the target is equal to `ignore_index`.
+    """
+
+    def __init__(self, loss, ignore_index):
+        super(_MaskingLossWrapper, self).__init__()
+        assert ignore_index is not None, 'ignore_index cannot be None'
+        self.loss = loss
+        self.ignore_index = ignore_index
+
+    def forward(self, input, target):
+        mask = target.clone().ne_(self.ignore_index)
+        mask.requires_grad = False
+
+        # mask out input/target so that the loss (and hence the gradient) is zero wherever target == ignore_index
+        input = input * mask
+        target = target * mask
+
+        # forward masked input and target to the loss
+        return self.loss(input, target)
+
+
+class SkipLastTargetChannelWrapper(nn.Module):
+    """
+    Loss wrapper which removes an additional (last) target channel.
+    """
+
+    def __init__(self, loss, squeeze_channel=False):
+        super(SkipLastTargetChannelWrapper, self).__init__()
+        self.loss = loss
+        self.squeeze_channel = squeeze_channel
+
+    def forward(self, input, target):
+        assert target.size(1) > 1, 'Target tensor has a singleton channel dimension, cannot remove channel'
+
+        # skip the last target channel
+        target = target[:, :-1, ...]
+
+        if self.squeeze_channel:
+            # squeeze channel dimension if singleton
+            target = torch.squeeze(target, dim=1)
+        return self.loss(input, target)
+
+
+class _AbstractDiceLoss(nn.Module):
+    """
+    Base class for different implementations of Dice loss.
+    """
+
+    def __init__(self, weight=None, normalization='sigmoid'):
+        super(_AbstractDiceLoss, self).__init__()
+        self.register_buffer('weight', weight)
+        # The output from the network during training is assumed to be un-normalized probabilities and we would
+        # like to normalize the logits. Since Dice (or soft Dice in this case) is usually used for binary data,
+        # normalizing the channels with Sigmoid is the default choice even for multi-class segmentation problems.
+        # However, if one would like to apply Softmax in order to get a proper probability distribution over the
+        # output channels, just specify `normalization='softmax'`.
+        assert normalization in ['sigmoid', 'softmax', 'none']
+        if normalization == 'sigmoid':
+            self.normalization = nn.Sigmoid()
+        elif normalization == 'softmax':
+            self.normalization = nn.Softmax(dim=1)
+        else:
+            self.normalization = lambda x: x
+
+    def dice(self, input, target, weight):
+        # actual Dice score computation; to be implemented by the subclass
+        raise NotImplementedError
+
+    def forward(self, input, target):
+        # get probabilities from logits
+        input = self.normalization(input)
+
+        # compute per channel Dice coefficient
+        per_channel_dice = self.dice(input, target, weight=self.weight)
+
+        # average Dice score across all channels/classes
+        return 1. - torch.mean(per_channel_dice)
+
+
+class DiceLoss(_AbstractDiceLoss):
+    """Computes Dice Loss according to https://arxiv.org/abs/1606.04797.
+    For multi-class segmentation the `weight` parameter can be used to assign different weights per class.
+    The input to the loss function is assumed to be logits and is normalized by the Sigmoid function by default.
+    """
+
+    def __init__(self, weight=None, normalization='sigmoid'):
+        super().__init__(weight, normalization)
+
+    def dice(self, input, target, weight):
+        return compute_per_channel_dice(input, target, weight=self.weight)
+
+
+class GeneralizedDiceLoss(_AbstractDiceLoss):
+    """Computes Generalized Dice Loss (GDL) as described in https://arxiv.org/pdf/1707.03237.pdf.
+    """
+
+    def __init__(self, normalization='sigmoid', epsilon=1e-6):
+        super().__init__(weight=None, normalization=normalization)
+        self.epsilon = epsilon
+
+    def dice(self, input, target, weight):
+        assert input.size() == target.size(), "'input' and 'target' must have the same shape"
+
+        input = flatten(input)
+        target = flatten(target)
+        target = target.float()
+
+        if input.size(0) == 1:
+            # for GDL to make sense we need at least 2 channels (see https://arxiv.org/pdf/1707.03237.pdf)
+            # put foreground and background voxels in separate channels
+            input = torch.cat((input, 1 - input), dim=0)
+            target = torch.cat((target, 1 - target), dim=0)
+
+        # GDL weighting: the contribution of each label is corrected by the inverse of its volume
+        w_l = target.sum(-1)
+        w_l = 1 / (w_l * w_l).clamp(min=self.epsilon)
+        w_l.requires_grad = False
+
+        intersect = (input * target).sum(-1)
+        intersect = intersect * w_l
+
+        denominator = (input + target).sum(-1)
+        denominator = (denominator * w_l).clamp(min=self.epsilon)
+
+        return 2 * (intersect.sum() / denominator.sum())
+
+
+class BCEDiceLoss(nn.Module):
+    """Linear combination of BCE and Dice losses."""
+
+    def __init__(self, alpha, beta):
+        super(BCEDiceLoss, self).__init__()
+        self.alpha = alpha
+        self.bce = nn.BCEWithLogitsLoss()
+        self.beta = beta
+        self.dice = DiceLoss()
+
+    def forward(self, input, target):
+        return self.alpha * self.bce(input, target) + self.beta * self.dice(input, target)
+
+
+class WeightedCrossEntropyLoss(nn.Module):
+    """WeightedCrossEntropyLoss (WCE) as described in https://arxiv.org/pdf/1707.03237.pdf
+    """
+
+    def __init__(self, ignore_index=-1):
+        super(WeightedCrossEntropyLoss, self).__init__()
+        self.ignore_index = ignore_index
+
+    def forward(self, input, target):
+        weight = self._class_weights(input)
+        return F.cross_entropy(input, target, weight=weight, ignore_index=self.ignore_index)
+
+    @staticmethod
+    def _class_weights(input):
+        # normalize the input first
+        input = F.softmax(input, dim=1)
+        flattened = flatten(input)
+        numerator = (1. - flattened).sum(-1)
+        denominator = flattened.sum(-1)
+        class_weights = Variable(numerator / denominator, requires_grad=False)
+        return class_weights
+
+
+class PixelWiseCrossEntropyLoss(nn.Module):
+    def __init__(self, class_weights=None, ignore_index=None):
+        super(PixelWiseCrossEntropyLoss, self).__init__()
+        self.register_buffer('class_weights', class_weights)
+        self.ignore_index = ignore_index
+        self.log_softmax = nn.LogSoftmax(dim=1)
+
+    def forward(self, input, target, weights):
+        assert target.size() == weights.size()
+        # normalize the input
+        log_probabilities = self.log_softmax(input)
+        # standard CrossEntropyLoss requires the target to be (NxDxHxW), so we need to expand it to (NxCxDxHxW)
+        target = expand_as_one_hot(target, C=input.size()[1], ignore_index=self.ignore_index)
+        # expand weights
+        weights = weights.unsqueeze(0)
+        weights = weights.expand_as(input)
+
+        # create default class_weights if None
+        if self.class_weights is None:
+            class_weights = torch.ones(input.size()[1]).float().to(input.device)
+        else:
+            class_weights = self.class_weights
+
+        # resize class_weights to be broadcastable into the weights
+        class_weights = class_weights.view(1, -1, 1, 1, 1)
+
+        # multiply weights tensor by class weights
+        weights = class_weights * weights
+
+        # compute the losses
+        result = -weights * target * log_probabilities
+        # average the losses
+        return result.mean()
+
+
+class WeightedSmoothL1Loss(nn.SmoothL1Loss):
+    def __init__(self, threshold, initial_weight, apply_below_threshold=True):
+        super().__init__(reduction="none")
+        self.threshold = threshold
+        self.apply_below_threshold = apply_below_threshold
+        self.weight = initial_weight
+
+    def forward(self, input, target):
+        l1 = super().forward(input, target)
+
+        if self.apply_below_threshold:
+            mask = target < self.threshold
+        else:
+            mask = target >= self.threshold
+
+        l1[mask] = l1[mask] * self.weight
+
+        return l1.mean()
+
+
+def flatten(tensor):
+    """Flattens a given tensor such that the channel axis is first.
+    The shapes are transformed as follows:
+        (N, C, D, H, W) -> (C, N * D * H * W)
+    """
+    # number of channels
+    C = tensor.shape[1]
+    # new axis order
+    axis_order = (1, 0) + tuple(range(2, tensor.dim()))
+    # Transpose: (N, C, D, H, W) -> (C, N, D, H, W)
+    transposed = tensor.permute(axis_order)
+    # Flatten: (C, N, D, H, W) -> (C, N * D * H * W)
+    return transposed.contiguous().view(C, -1)
+
+
+def get_loss_criterion(config):
+    """
+    Returns the loss function based on the provided configuration.
+    :param config: (dict) a top level configuration object containing the 'loss' key
+    :return: an instance of the loss function
+    """
+    assert 'loss' in config, 'Could not find loss function configuration'
+    loss_config = config['loss']
+    name = loss_config.pop('name')
+
+    ignore_index = loss_config.pop('ignore_index', None)
+    skip_last_target = loss_config.pop('skip_last_target', False)
+    weight = loss_config.pop('weight', None)
+
+    if weight is not None:
+        # convert to a tensor on the configured device if necessary
+        weight = torch.tensor(weight).to(config['device'])
+
+    pos_weight = loss_config.pop('pos_weight', None)
+    if pos_weight is not None:
+        # convert to a tensor on the configured device if necessary
+        pos_weight = torch.tensor(pos_weight).to(config['device'])
+
+    loss = _create_loss(name, loss_config, weight, ignore_index, pos_weight)
+
+    if not (ignore_index is None or name in ['CrossEntropyLoss', 'WeightedCrossEntropyLoss']):
+        # use MaskingLossWrapper only for non-cross-entropy losses, since CE losses allow specifying 'ignore_index' directly
+        loss = _MaskingLossWrapper(loss, ignore_index)
+
+    if skip_last_target:
+        loss = SkipLastTargetChannelWrapper(loss, loss_config.get('squeeze_channel', False))
+
+    return loss
+
+
+def expand_as_one_hot(input, C, ignore_index=None):
+    """
+    Converts an NxSPATIAL label image to NxCxSPATIAL, where each label gets converted to its corresponding one-hot vector.
+    It is assumed that the batch dimension is present.
+    Args:
+        input (torch.Tensor): 4D input label image (NxDxHxW)
+        C (int): number of channels/labels
+        ignore_index (int): ignore index to be kept during the expansion
+    Returns:
+        5D output torch.Tensor (NxCxDxHxW)
+    """
+    assert input.dim() == 4
+
+    # expand the input tensor to Nx1xSPATIAL before scattering
+    input = input.unsqueeze(1)
+    # create output tensor shape (NxCxSPATIAL)
+    shape = list(input.size())
+    shape[1] = C
+
+    if ignore_index is not None:
+        # create ignore_index mask for the result
+        mask = input.expand(shape) == ignore_index
+        # clone the src tensor and zero out ignore_index in the input
+        input = input.clone()
+        input[input == ignore_index] = 0
+        # scatter to get the one-hot tensor
+        result = torch.zeros(shape).to(input.device).scatter_(1, input, 1)
+        # bring back the ignore_index in the result
+        result[mask] = ignore_index
+        return result
+    else:
+        # scatter to get the one-hot tensor
+        return torch.zeros(shape).to(input.device).scatter_(1, input, 1)
+
+#######################################################################################################################
+
+def _create_loss(name, loss_config, weight, ignore_index, pos_weight):
+    if name == 'BCEWithLogitsLoss':
+        return nn.BCEWithLogitsLoss(pos_weight=pos_weight)
+    elif name == 'BCEDiceLoss':
+        alpha = loss_config.get('alpha', 1.)
+        beta = loss_config.get('beta', 1.)
+        return BCEDiceLoss(alpha, beta)
+    elif name == 'CrossEntropyLoss':
+        if ignore_index is None:
+            ignore_index = -100  # use the default 'ignore_index' as defined in the CrossEntropyLoss
+        return nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index)
+    elif name == 'WeightedCrossEntropyLoss':
+        if ignore_index is None:
+            ignore_index = -100  # use the default 'ignore_index' as defined in the CrossEntropyLoss
+        return WeightedCrossEntropyLoss(ignore_index=ignore_index)
+    elif name == 'PixelWiseCrossEntropyLoss':
+        return PixelWiseCrossEntropyLoss(class_weights=weight, ignore_index=ignore_index)
+    elif name == 'GeneralizedDiceLoss':
+        normalization = loss_config.get('normalization', 'sigmoid')
+        return GeneralizedDiceLoss(normalization=normalization)
+    elif name == 'DiceLoss':
+        normalization = loss_config.get('normalization', 'sigmoid')
+        return DiceLoss(weight=weight, normalization=normalization)
+    elif name == 'MSELoss':
+        return MSELoss()
+    elif name == 'SmoothL1Loss':
+        return SmoothL1Loss()
+    elif name == 'L1Loss':
+        return L1Loss()
+    elif name == 'WeightedSmoothL1Loss':
+        return WeightedSmoothL1Loss(threshold=loss_config['threshold'],
+                                    initial_weight=loss_config['initial_weight'],
+                                    apply_below_threshold=loss_config.get('apply_below_threshold', True))
+    else:
+        raise RuntimeError(f"Unsupported loss function: '{name}'")
\ No newline at end of file
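Usage sketch (not part of the patch): a minimal, hypothetical smoke test for the module added above, driving it through get_loss_criterion as its docstring describes. The config keys ('loss', 'name', 'normalization', 'device') are the ones read by get_loss_criterion/_create_loss; the import path, tensor shapes, and the two-channel toy batch are assumptions for illustration only.

# hypothetical smoke test, assuming the repository root is on PYTHONPATH
import torch
from models.generalizeddice import get_loss_criterion

config = {
    'loss': {'name': 'GeneralizedDiceLoss', 'normalization': 'sigmoid'},
    'device': 'cpu',  # only consulted when 'weight' / 'pos_weight' are configured
}
criterion = get_loss_criterion(config)

# fake 3D segmentation batch: N=2, C=2, spatial 8x16x16 (shapes are illustrative)
logits = torch.randn(2, 2, 8, 16, 16, requires_grad=True)
target = torch.randint(0, 2, (2, 2, 8, 16, 16)).float()

loss = criterion(logits, target)  # Sigmoid is applied inside the loss
loss.backward()
print(loss.item())

The same config pattern selects any of the other losses registered in _create_loss (e.g. 'DiceLoss', or 'BCEDiceLoss' with 'alpha'/'beta'), which is why the 'name' key is popped before the remaining options are forwarded.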