diff --git a/mmseg/models/losses/__init__.py b/mmseg/models/losses/__init__.py index d623887..beca720 100644 --- a/mmseg/models/losses/__init__.py +++ b/mmseg/models/losses/__init__.py @@ -1,11 +1,12 @@ from .accuracy import Accuracy, accuracy from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy, cross_entropy, mask_cross_entropy) +from .dice_loss import DiceLoss from .lovasz_loss import LovaszLoss from .utils import reduce_loss, weight_reduce_loss, weighted_loss __all__ = [ 'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy', 'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss', - 'weight_reduce_loss', 'weighted_loss', 'LovaszLoss' + 'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss' ] diff --git a/mmseg/models/losses/dice_loss.py b/mmseg/models/losses/dice_loss.py new file mode 100644 index 0000000..27da861 --- /dev/null +++ b/mmseg/models/losses/dice_loss.py @@ -0,0 +1,116 @@ +"""Modified from https://github.com/LikeLy-Journey/SegmenTron/blob/master/ +segmentron/solver/loss.py (Apache-2.0 License)""" +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..builder import LOSSES +from .utils import weighted_loss + + +@weighted_loss +def dice_loss(pred, + target, + valid_mask, + smooth=1, + exponent=2, + class_weight=None, + ignore_index=-1): + assert pred.shape[0] == target.shape[0] + total_loss = 0 + num_classes = pred.shape[1] + for i in range(num_classes): + if i != ignore_index: + dice_loss = binary_dice_loss( + pred[:, i], + target[..., i], + valid_mask=valid_mask, + smooth=smooth, + exponent=exponent) + if class_weight is not None: + dice_loss *= class_weight[i] + total_loss += dice_loss + return total_loss / num_classes + + +@weighted_loss +def binary_dice_loss(pred, target, valid_mask, smooth=1, exponent=2, **kwards): + assert pred.shape[0] == target.shape[0] + pred = pred.contiguous().view(pred.shape[0], -1) + target = target.contiguous().view(target.shape[0], -1) + valid_mask = valid_mask.contiguous().view(valid_mask.shape[0], -1) + + num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + smooth + den = torch.sum(pred.pow(exponent) + target.pow(exponent), dim=1) + smooth + + return 1 - num / den + + +@LOSSES.register_module() +class DiceLoss(nn.Module): + """DiceLoss. + + This loss is proposed in `V-Net: Fully Convolutional Neural Networks for + Volumetric Medical Image Segmentation `_. + + Args: + loss_type (str, optional): Binary or multi-class loss. + Default: 'multi_class'. Options are "binary" and "multi_class". + smooth (float): A float number to smooth loss, and avoid NaN error. + Default: 1 + exponent (float): An float number to calculate denominator + value: \\sum{x^exponent} + \\sum{y^exponent}. Default: 2. + reduction (str, optional): The method used to reduce the loss. Options + are "none", "mean" and "sum". This parameter only works when + per_image is True. Default: 'mean'. + class_weight (list[float], optional): The weight for each class. + Default: None. + loss_weight (float, optional): Weight of the loss. Default to 1.0. + ignore_index (int | None): The label index to be ignored. Default: 255. + """ + + def __init__(self, + loss_type='multi_class', + smooth=1, + exponent=2, + reduction='mean', + class_weight=None, + loss_weight=1.0, + ignore_index=255): + super(DiceLoss, self).__init__() + assert loss_type in ['multi_class', 'binary'] + if loss_type == 'multi_class': + self.cls_criterion = dice_loss + else: + self.cls_criterion = binary_dice_loss + self.smooth = smooth + self.exponent = exponent + self.reduction = reduction + self.class_weight = class_weight + self.loss_weight = loss_weight + self.ignore_index = ignore_index + + def forward(self, pred, target, avg_factor=None, reduction_override=None): + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if self.class_weight is not None: + class_weight = pred.new_tensor(self.class_weight) + else: + class_weight = None + + pred = F.softmax(pred, dim=1) + one_hot_target = F.one_hot(torch.clamp_min(target.long(), 0)) + valid_mask = (target != self.ignore_index).long() + + loss = self.loss_weight * self.cls_criterion( + pred, + one_hot_target, + valid_mask=valid_mask, + reduction=reduction, + avg_factor=avg_factor, + smooth=self.smooth, + exponent=self.exponent, + class_weight=class_weight, + ignore_index=self.ignore_index) + return loss diff --git a/tests/test_models/test_losses.py b/tests/test_models/test_losses.py index 005d939..481a8e9 100644 --- a/tests/test_models/test_losses.py +++ b/tests/test_models/test_losses.py @@ -202,3 +202,43 @@ def test_lovasz_loss(): logits = torch.rand(2, 4, 4) labels = (torch.rand(2, 4, 4)).long() lovasz_loss(logits, labels, ignore_index=None) + + +def test_dice_lose(): + from mmseg.models import build_loss + + # loss_type should be 'binary' or 'multi_class' + with pytest.raises(AssertionError): + loss_cfg = dict( + type='DiceLoss', + loss_type='Binary', + reduction='none', + loss_weight=1.0) + build_loss(loss_cfg) + + # test dice loss with loss_type = 'multi_class' + loss_cfg = dict( + type='DiceLoss', + loss_type='multi_class', + reduction='none', + class_weight=[1.0, 2.0, 3.0], + loss_weight=1.0, + ignore_index=1) + dice_loss = build_loss(loss_cfg) + logits = torch.rand(8, 3, 4, 4) + labels = (torch.rand(8, 4, 4) * 3).long() + dice_loss(logits, labels) + + # test dice loss with loss_type = 'binary' + loss_cfg = dict( + type='DiceLoss', + loss_type='binary', + smooth=2, + exponent=3, + reduction='sum', + loss_weight=1.0, + ignore_index=0) + dice_loss = build_loss(loss_cfg) + logits = torch.rand(16, 4, 4) + labels = (torch.rand(16, 4, 4)).long() + dice_loss(logits, labels)