From 1b419899342a83c6dce653b50ecd29874c4c42c9 Mon Sep 17 00:00:00 2001
From: Rockey <41846794+RockeyCoss@users.noreply.github.com>
Date: Fri, 3 Dec 2021 16:18:40 +0800
Subject: [PATCH] [Feature] Add focal loss (#1024)

* [Feature] add focal loss

* fix the bug of 'non' reduction type

* refine the implementation

* add class_weight and ignore_index; support different alpha values for
  different classes

* fixed some bugs

* fix bugs

* add comments

* modify test

* Update mmseg/models/losses/focal_loss.py

Co-authored-by: Junjun2016

* update test_focal_loss.py

* modified the implementation

* Update mmseg/models/losses/focal_loss.py

Co-authored-by: Jerry Jiarui XU

* update focal_loss.py

Co-authored-by: Junjun2016
Co-authored-by: Jerry Jiarui XU
---
 mmseg/models/losses/__init__.py              |   4 +-
 mmseg/models/losses/focal_loss.py            | 327 ++++++++++++++++++
 .../test_losses/test_focal_loss.py           | 216 ++++++++++++
 3 files changed, 546 insertions(+), 1 deletion(-)
 create mode 100644 mmseg/models/losses/focal_loss.py
 create mode 100644 tests/test_models/test_losses/test_focal_loss.py

diff --git a/mmseg/models/losses/__init__.py b/mmseg/models/losses/__init__.py
index e85d8e0..fbc5b2d 100644
--- a/mmseg/models/losses/__init__.py
+++ b/mmseg/models/losses/__init__.py
@@ -3,11 +3,13 @@ from .accuracy import Accuracy, accuracy
 from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
                                  cross_entropy, mask_cross_entropy)
 from .dice_loss import DiceLoss
+from .focal_loss import FocalLoss
 from .lovasz_loss import LovaszLoss
 from .utils import reduce_loss, weight_reduce_loss, weighted_loss
 
 __all__ = [
     'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
     'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
-    'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss'
+    'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss',
+    'FocalLoss'
 ]
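
With the export in place, FocalLoss is available through the LOSSES registry.
A minimal usage sketch (assuming mmseg with this patch installed; the shapes
and values are illustrative):

    import torch
    from mmseg.models import build_loss

    # 'FocalLoss' is resolved through the LOSSES registry populated above.
    loss_cfg = dict(type='FocalLoss', gamma=2.0, alpha=0.5, loss_weight=1.0)
    focal_loss = build_loss(loss_cfg)

    pred = torch.rand(2, 4, 8, 8)            # (B, C, H, W) logits
    target = torch.randint(0, 4, (2, 8, 8))  # (B, H, W) class indices
    print(focal_loss(pred, target))          # scalar ('mean' reduction)
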
diff --git a/mmseg/models/losses/focal_loss.py b/mmseg/models/losses/focal_loss.py
new file mode 100644
index 0000000..af1c711
--- /dev/null
+++ b/mmseg/models/losses/focal_loss.py
@@ -0,0 +1,327 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Modified from https://github.com/open-mmlab/mmdetection
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss
+
+from ..builder import LOSSES
+from .utils import weight_reduce_loss
+
+
+# This method is used when CUDA is not available
+def py_sigmoid_focal_loss(pred,
+                          target,
+                          one_hot_target=None,
+                          weight=None,
+                          gamma=2.0,
+                          alpha=0.5,
+                          class_weight=None,
+                          valid_mask=None,
+                          reduction='mean',
+                          avg_factor=None):
+    """PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.
+
+    Args:
+        pred (torch.Tensor): The prediction with shape (N, C), C is the
+            number of classes.
+        target (torch.Tensor): The learning label of the prediction with
+            shape (N, C).
+        one_hot_target (None): Placeholder. It should be None.
+        weight (torch.Tensor, optional): Sample-wise loss weight.
+        gamma (float, optional): The gamma for calculating the modulating
+            factor. Defaults to 2.0.
+        alpha (float | list[float], optional): A balanced form for Focal Loss.
+            Defaults to 0.5.
+        class_weight (list[float], optional): Weight of each class.
+            Defaults to None.
+        valid_mask (torch.Tensor, optional): A mask that uses 1 to mark the
+            valid samples and 0 to mark the ignored samples. Default: None.
+        reduction (str, optional): The method used to reduce the loss into
+            a scalar. Defaults to 'mean'.
+        avg_factor (int, optional): Average factor that is used to average
+            the loss. Defaults to None.
+    """
+    if isinstance(alpha, list):
+        alpha = pred.new_tensor(alpha)
+    pred_sigmoid = pred.sigmoid()
+    target = target.type_as(pred)
+    one_minus_pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
+    focal_weight = (alpha * target + (1 - alpha) *
+                    (1 - target)) * one_minus_pt.pow(gamma)
+
+    loss = F.binary_cross_entropy_with_logits(
+        pred, target, reduction='none') * focal_weight
+    final_weight = torch.ones(1, pred.size(1)).type_as(loss)
+    if weight is not None:
+        if weight.shape != loss.shape and weight.size(0) == loss.size(0):
+            # For most cases, weight is of shape (N, ),
+            # which means it does not have the second axis num_class
+            weight = weight.view(-1, 1)
+        assert weight.dim() == loss.dim()
+        final_weight = final_weight * weight
+    if class_weight is not None:
+        final_weight = final_weight * pred.new_tensor(class_weight)
+    if valid_mask is not None:
+        final_weight = final_weight * valid_mask
+    loss = weight_reduce_loss(loss, final_weight, reduction, avg_factor)
+    return loss
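
For intuition on the focal term: with gamma=2 the modulating factor
(1 - p_t)^gamma suppresses well-classified samples far more than borderline
ones. A small self-contained check in plain PyTorch (independent of this
module; the tensors are illustrative):

    import torch
    import torch.nn.functional as F

    pred = torch.tensor([[3.0], [0.1]])    # confident vs. borderline logit
    target = torch.tensor([[1.0], [1.0]])  # both are positives

    bce = F.binary_cross_entropy_with_logits(pred, target, reduction='none')
    p_t = pred.sigmoid() * target + (1 - pred.sigmoid()) * (1 - target)
    focal = (1 - p_t).pow(2.0) * bce       # alpha omitted for clarity

    # The ratio focal/bce is ~0.002 for the confident sample but ~0.23 for
    # the borderline one, i.e. easy samples are strongly down-weighted.
    print((focal / bce).squeeze())
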
+
+
+def sigmoid_focal_loss(pred,
+                       target,
+                       one_hot_target,
+                       weight=None,
+                       gamma=2.0,
+                       alpha=0.5,
+                       class_weight=None,
+                       valid_mask=None,
+                       reduction='mean',
+                       avg_factor=None):
+    r"""A wrapper of the CUDA version of `Focal Loss
+    <https://arxiv.org/abs/1708.02002>`_.
+
+    Args:
+        pred (torch.Tensor): The prediction with shape (N, C), C is the number
+            of classes.
+        target (torch.Tensor): The learning label of the prediction. Its shape
+            should be (N, ).
+        one_hot_target (torch.Tensor): The learning label with shape (N, C).
+        weight (torch.Tensor, optional): Sample-wise loss weight.
+        gamma (float, optional): The gamma for calculating the modulating
+            factor. Defaults to 2.0.
+        alpha (float | list[float], optional): A balanced form for Focal Loss.
+            Defaults to 0.5.
+        class_weight (list[float], optional): Weight of each class.
+            Defaults to None.
+        valid_mask (torch.Tensor, optional): A mask that uses 1 to mark the
+            valid samples and 0 to mark the ignored samples. Default: None.
+        reduction (str, optional): The method used to reduce the loss into
+            a scalar. Defaults to 'mean'. Options are "none", "mean" and
+            "sum".
+        avg_factor (int, optional): Average factor that is used to average
+            the loss. Defaults to None.
+    """
+    # Function.apply does not accept keyword arguments, so the decorator
+    # "weighted_loss" is not applicable
+    final_weight = torch.ones(1, pred.size(1)).type_as(pred)
+    if isinstance(alpha, list):
+        # _sigmoid_focal_loss doesn't accept alpha of list type. Therefore,
+        # if a list is given, we set the input alpha as 0.5. This means
+        # setting equal weight for the foreground class and background class.
+        # By multiplying the loss by 2, the effect of setting alpha as 0.5 is
+        # undone. The alpha of type list is used to regulate the loss in the
+        # post-processing process.
+        loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(),
+                                   gamma, 0.5, None, 'none') * 2
+        alpha = pred.new_tensor(alpha)
+        final_weight = final_weight * (
+            alpha * one_hot_target + (1 - alpha) * (1 - one_hot_target))
+    else:
+        loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(),
+                                   gamma, alpha, None, 'none')
+    if weight is not None:
+        if weight.shape != loss.shape and weight.size(0) == loss.size(0):
+            # For most cases, weight is of shape (N, ),
+            # which means it does not have the second axis num_class
+            weight = weight.view(-1, 1)
+        assert weight.dim() == loss.dim()
+        final_weight = final_weight * weight
+    if class_weight is not None:
+        final_weight = final_weight * pred.new_tensor(class_weight)
+    if valid_mask is not None:
+        final_weight = final_weight * valid_mask
+    loss = weight_reduce_loss(loss, final_weight, reduction, avg_factor)
+    return loss
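
The rescaling in the list-alpha branch rests on the identity
alpha * t + (1 - alpha) * (1 - t) = 0.5 when alpha = 0.5: doubling the
alpha=0.5 loss recovers the unweighted focal term, and final_weight then
re-applies the per-class alphas. A quick equivalence check against the
pure-PyTorch path (no CUDA op required; assumes this patch is installed):

    import torch
    from mmseg.models.losses.focal_loss import py_sigmoid_focal_loss

    torch.manual_seed(0)
    pred = torch.randn(6, 4)
    one_hot = torch.eye(4)[torch.randint(0, 4, (6, ))]

    alpha = [0.1, 0.3, 0.2, 0.1]
    direct = py_sigmoid_focal_loss(pred, one_hot, alpha=alpha,
                                   reduction='none')

    # Same result via the trick used above: double the alpha=0.5 loss,
    # then re-balance each class column afterwards.
    base = py_sigmoid_focal_loss(pred, one_hot, alpha=0.5,
                                 reduction='none') * 2
    a = pred.new_tensor(alpha)
    torch.testing.assert_close(
        direct, base * (a * one_hot + (1 - a) * (1 - one_hot)))
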
+
+
+@LOSSES.register_module()
+class FocalLoss(nn.Module):
+
+    def __init__(self,
+                 use_sigmoid=True,
+                 gamma=2.0,
+                 alpha=0.5,
+                 reduction='mean',
+                 class_weight=None,
+                 loss_weight=1.0,
+                 loss_name='loss_focal'):
+        """`Focal Loss <https://arxiv.org/abs/1708.02002>`_
+
+        Args:
+            use_sigmoid (bool, optional): Whether the prediction uses sigmoid
+                instead of softmax. Only True is supported now.
+                Defaults to True.
+            gamma (float, optional): The gamma for calculating the modulating
+                factor. Defaults to 2.0.
+            alpha (float | list[float], optional): A balanced form for Focal
+                Loss. Defaults to 0.5. When a list is provided, the length
+                of the list should be equal to the number of classes.
+                Please be careful that this parameter is not the
+                class-wise weight but the weight of a binary classification
+                problem. This binary classification problem regards the
+                pixels which belong to one class as the foreground
+                and the other pixels as the background, and each element
+                in the list is the weight of the corresponding foreground
+                class. The value of alpha or each element of alpha should
+                be a float in the interval [0, 1]. If you want to specify
+                the class-wise weight, please use the `class_weight`
+                parameter.
+            reduction (str, optional): The method used to reduce the loss
+                into a scalar. Defaults to 'mean'. Options are "none",
+                "mean" and "sum".
+            class_weight (list[float], optional): Weight of each class.
+                Defaults to None.
+            loss_weight (float, optional): Weight of loss. Defaults to 1.0.
+            loss_name (str, optional): Name of the loss item. If you want
+                this loss item to be included into the backward graph,
+                `loss_` must be the prefix of the name.
+                Defaults to 'loss_focal'.
+        """
+        super(FocalLoss, self).__init__()
+        assert use_sigmoid is True, \
+            'AssertionError: Only sigmoid focal loss supported now.'
+        assert reduction in ('none', 'mean', 'sum'), \
+            "AssertionError: reduction should be 'none', 'mean' or " \
+            "'sum'"
+        assert isinstance(alpha, (float, list)), \
+            'AssertionError: alpha should be of type float or list'
+        assert isinstance(gamma, float), \
+            'AssertionError: gamma should be of type float'
+        assert isinstance(loss_weight, float), \
+            'AssertionError: loss_weight should be of type float'
+        assert isinstance(loss_name, str), \
+            'AssertionError: loss_name should be of type str'
+        assert isinstance(class_weight, list) or class_weight is None, \
+            'AssertionError: class_weight must be None or of type list'
+        self.use_sigmoid = use_sigmoid
+        self.gamma = gamma
+        self.alpha = alpha
+        self.reduction = reduction
+        self.class_weight = class_weight
+        self.loss_weight = loss_weight
+        self._loss_name = loss_name
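
A sketch of how the two kinds of weighting differ in practice (the class
count and values are illustrative): alpha balances foreground vs. background
inside each per-class binary problem, while class_weight simply scales each
class column of the loss:

    import torch
    from mmseg.models.losses import FocalLoss

    criterion = FocalLoss(
        gamma=2.0,
        alpha=[0.25, 0.25, 0.5, 0.75],      # len(alpha) == num_classes
        class_weight=[1.0, 2.0, 1.0, 1.0])  # plain per-class scaling

    pred = torch.rand(2, 4, 8, 8)            # (B, C, H, W) logits
    target = torch.randint(0, 4, (2, 8, 8))  # (B, H, W) class indices
    print(criterion(pred, target))
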
+
+    def forward(self,
+                pred,
+                target,
+                weight=None,
+                avg_factor=None,
+                reduction_override=None,
+                ignore_index=255,
+                **kwargs):
+        """Forward function.
+
+        Args:
+            pred (torch.Tensor): The prediction with shape
+                (N, C) where C = number of classes, or
+                (N, C, d_1, d_2, ..., d_K) with K≥1 in the
+                case of K-dimensional loss.
+            target (torch.Tensor): The ground truth. If containing class
+                indices, shape (N) where each value is 0≤targets[i]≤C−1,
+                or (N, d_1, d_2, ..., d_K) with K≥1 in the case of
+                K-dimensional loss. If containing class probabilities,
+                same shape as the input.
+            weight (torch.Tensor, optional): The weight of loss for each
+                prediction. Defaults to None.
+            avg_factor (int, optional): Average factor that is used to
+                average the loss. Defaults to None.
+            reduction_override (str, optional): The reduction method used
+                to override the original reduction method of the loss.
+                Options are "none", "mean" and "sum".
+            ignore_index (int, optional): The label index to be ignored.
+                Default: 255.
+        Returns:
+            torch.Tensor: The calculated loss
+        """
+        assert isinstance(ignore_index, int), \
+            'ignore_index must be of type int'
+        assert reduction_override in (None, 'none', 'mean', 'sum'), \
+            "AssertionError: reduction should be 'none', 'mean' or " \
+            "'sum'"
+        assert pred.shape == target.shape or \
+               (pred.size(0) == target.size(0) and
+                pred.shape[2:] == target.shape[1:]), \
+               "The shape of pred doesn't match the shape of target"
+
+        original_shape = pred.shape
+
+        # [B, C, d_1, d_2, ..., d_k] -> [C, B, d_1, d_2, ..., d_k]
+        pred = pred.transpose(0, 1)
+        # [C, B, d_1, d_2, ..., d_k] -> [C, N]
+        pred = pred.reshape(pred.size(0), -1)
+        # [C, N] -> [N, C]
+        pred = pred.transpose(0, 1).contiguous()
+
+        if original_shape == target.shape:
+            # target with shape [B, C, d_1, d_2, ...]
+            # transform its shape into [N, C]
+            # [B, C, d_1, d_2, ...] -> [C, B, d_1, d_2, ..., d_k]
+            target = target.transpose(0, 1)
+            # [C, B, d_1, d_2, ..., d_k] -> [C, N]
+            target = target.reshape(target.size(0), -1)
+            # [C, N] -> [N, C]
+            target = target.transpose(0, 1).contiguous()
+        else:
+            # target with shape [B, d_1, d_2, ...]
+            # transform its shape into [N, ]
+            target = target.view(-1).contiguous()
+            valid_mask = (target != ignore_index).view(-1, 1)
+            # avoid raising error when using F.one_hot()
+            target = torch.where(target == ignore_index,
+                                 target.new_tensor(0), target)
+
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
+        if self.use_sigmoid:
+            num_classes = pred.size(1)
+            if torch.cuda.is_available() and pred.is_cuda:
+                if target.dim() == 1:
+                    one_hot_target = F.one_hot(
+                        target, num_classes=num_classes)
+                else:
+                    one_hot_target = target
+                    target = target.argmax(dim=1)
+                    valid_mask = (target != ignore_index).view(-1, 1)
+                calculate_loss_func = sigmoid_focal_loss
+            else:
+                one_hot_target = None
+                if target.dim() == 1:
+                    target = F.one_hot(target, num_classes=num_classes)
+                else:
+                    valid_mask = (target.argmax(dim=1) != ignore_index).view(
+                        -1, 1)
+                calculate_loss_func = py_sigmoid_focal_loss
+
+            loss_cls = self.loss_weight * calculate_loss_func(
+                pred,
+                target,
+                one_hot_target,
+                weight,
+                gamma=self.gamma,
+                alpha=self.alpha,
+                class_weight=self.class_weight,
+                valid_mask=valid_mask,
+                reduction=reduction,
+                avg_factor=avg_factor)
+
+            if reduction == 'none':
+                # [N, C] -> [C, N]
+                loss_cls = loss_cls.transpose(0, 1)
+                # [C, N] -> [C, B, d1, d2, ...]
+                # original_shape: [B, C, d1, d2, ...]
+                loss_cls = loss_cls.reshape(original_shape[1],
+                                            original_shape[0],
+                                            *original_shape[2:])
+                # [C, B, d1, d2, ...] -> [B, C, d1, d2, ...]
+                loss_cls = loss_cls.transpose(0, 1).contiguous()
+        else:
+            raise NotImplementedError
+        return loss_cls
+
+    @property
+    def loss_name(self):
+        """Loss Name.
+
+        This function must be implemented and will return the name of this
+        loss function. This name will be used to combine different loss items
+        by simple sum operation. In addition, if you want this loss item to
+        be included into the backward graph, `loss_` must be the prefix of
+        the name.
+        Returns:
+            str: The name of this loss item.
+        """
+        return self._loss_name
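
Because the forward pass flattens everything to (N, C) and restores the
original layout for 'none' reduction, an element-wise loss map comes back in
the input shape, with ignored pixels contributing zero. A brief check
(assumes this patch is installed):

    import torch
    from mmseg.models.losses import FocalLoss

    criterion = FocalLoss(reduction='none')
    pred = torch.rand(3, 4, 5, 6)            # (B, C, H, W)
    target = torch.randint(0, 4, (3, 5, 6))  # (B, H, W)
    target[0, 0, 0] = 255                    # default ignore_index

    loss = criterion(pred, target)
    assert loss.shape == pred.shape          # per-pixel, per-class map
    assert (loss[0, :, 0, 0] == 0).all()     # ignored pixel is zeroed
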
diff --git a/tests/test_models/test_losses/test_focal_loss.py b/tests/test_models/test_losses/test_focal_loss.py
new file mode 100644
index 0000000..687312b
--- /dev/null
+++ b/tests/test_models/test_losses/test_focal_loss.py
@@ -0,0 +1,216 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import pytest
+import torch
+import torch.nn.functional as F
+
+from mmseg.models import build_loss
+
+
+# test focal loss with use_sigmoid=False
+def test_use_sigmoid():
+    # can't init with use_sigmoid=False
+    with pytest.raises(AssertionError):
+        loss_cfg = dict(type='FocalLoss', use_sigmoid=False)
+        build_loss(loss_cfg)
+
+    # can't forward with use_sigmoid=False
+    with pytest.raises(NotImplementedError):
+        loss_cfg = dict(type='FocalLoss', use_sigmoid=True)
+        focal_loss = build_loss(loss_cfg)
+        focal_loss.use_sigmoid = False
+        fake_pred = torch.rand(3, 4, 5, 6)
+        fake_target = torch.randint(0, 4, (3, 5, 6))
+        focal_loss(fake_pred, fake_target)
+
+
+# reduction type must be 'none', 'mean' or 'sum'
+def test_wrong_reduction_type():
+    # can't init with wrong reduction
+    with pytest.raises(AssertionError):
+        loss_cfg = dict(type='FocalLoss', reduction='test')
+        build_loss(loss_cfg)
+
+    # can't forward with wrong reduction override
+    with pytest.raises(AssertionError):
+        loss_cfg = dict(type='FocalLoss')
+        focal_loss = build_loss(loss_cfg)
+        fake_pred = torch.rand(3, 4, 5, 6)
+        fake_target = torch.randint(0, 4, (3, 5, 6))
+        focal_loss(fake_pred, fake_target, reduction_override='test')
+
+
+# test focal loss can handle input parameters with
+# unacceptable types
+def test_unacceptable_parameters():
+    with pytest.raises(AssertionError):
+        loss_cfg = dict(type='FocalLoss', gamma='test')
+        build_loss(loss_cfg)
+    with pytest.raises(AssertionError):
+        loss_cfg = dict(type='FocalLoss', alpha='test')
+        build_loss(loss_cfg)
+    with pytest.raises(AssertionError):
+        loss_cfg = dict(type='FocalLoss', class_weight='test')
+        build_loss(loss_cfg)
+    with pytest.raises(AssertionError):
+        loss_cfg = dict(type='FocalLoss', loss_weight='test')
+        build_loss(loss_cfg)
+    with pytest.raises(AssertionError):
+        loss_cfg = dict(type='FocalLoss', loss_name=123)
+        build_loss(loss_cfg)
+
+
+# test if focal loss can be correctly initialized
+def test_init_focal_loss():
+    loss_cfg = dict(
+        type='FocalLoss',
+        use_sigmoid=True,
+        gamma=3.0,
+        alpha=3.0,
+        class_weight=[1, 2, 3, 4],
+        reduction='sum')
+    focal_loss = build_loss(loss_cfg)
+    assert focal_loss.use_sigmoid is True
+    assert focal_loss.gamma == 3.0
+    assert focal_loss.alpha == 3.0
+    assert focal_loss.reduction == 'sum'
+    assert focal_loss.class_weight == [1, 2, 3, 4]
+    assert focal_loss.loss_weight == 1.0
+    assert focal_loss.loss_name == 'loss_focal'
+
+
+# test reduction override
+def test_reduction_override():
+    loss_cfg = dict(type='FocalLoss', reduction='mean')
+    focal_loss = build_loss(loss_cfg)
+    fake_pred = torch.rand(3, 4, 5, 6)
+    fake_target = torch.randint(0, 4, (3, 5, 6))
+    loss = focal_loss(fake_pred, fake_target, reduction_override='none')
+    assert loss.shape == fake_pred.shape
+
+
+# test wrong pred and target shape
+def test_wrong_pred_and_target_shape():
+    loss_cfg = dict(type='FocalLoss')
+    focal_loss = build_loss(loss_cfg)
+    fake_pred = torch.rand(3, 4, 5, 6)
+    fake_target = torch.randint(0, 4, (3, 2, 2))
+    fake_target = F.one_hot(fake_target, num_classes=4)
+    fake_target = fake_target.permute(0, 3, 1, 2)
+    with pytest.raises(AssertionError):
+        focal_loss(fake_pred, fake_target)
+
+
+# test forward with different shape of target
+def test_forward_with_different_shape_of_target():
+    loss_cfg = dict(type='FocalLoss')
+    focal_loss = build_loss(loss_cfg)
+
+    fake_pred = torch.rand(3, 4, 5, 6)
+    fake_target = torch.randint(0, 4, (3, 5, 6))
+    loss1 = focal_loss(fake_pred, fake_target)
+
+    fake_target = F.one_hot(fake_target, num_classes=4)
+    fake_target = fake_target.permute(0, 3, 1, 2)
+    loss2 = focal_loss(fake_pred, fake_target)
+    assert loss1 == loss2
+
+
+# test forward with weight
+def test_forward_with_weight():
+    loss_cfg = dict(type='FocalLoss')
+    focal_loss = build_loss(loss_cfg)
+    fake_pred = torch.rand(3, 4, 5, 6)
+    fake_target = torch.randint(0, 4, (3, 5, 6))
+    weight = torch.rand(3 * 5 * 6, 1)
+    loss1 = focal_loss(fake_pred, fake_target, weight=weight)
+
+    weight2 = weight.view(-1)
+    loss2 = focal_loss(fake_pred, fake_target, weight=weight2)
+
+    weight3 = weight.expand(3 * 5 * 6, 4)
+    loss3 = focal_loss(fake_pred, fake_target, weight=weight3)
+    assert loss1 == loss2 == loss3
+
+
+# test none reduction type
+def test_none_reduction_type():
+    loss_cfg = dict(type='FocalLoss', reduction='none')
+    focal_loss = build_loss(loss_cfg)
+    fake_pred = torch.rand(3, 4, 5, 6)
+    fake_target = torch.randint(0, 4, (3, 5, 6))
+    loss = focal_loss(fake_pred, fake_target)
+    assert loss.shape == fake_pred.shape
+
+
+# test the usage of class weight
+def test_class_weight():
+    loss_cfg_cw = dict(
+        type='FocalLoss', reduction='none', class_weight=[1.0, 2.0, 3.0, 4.0])
+    loss_cfg = dict(type='FocalLoss', reduction='none')
+    focal_loss_cw = build_loss(loss_cfg_cw)
+    focal_loss = build_loss(loss_cfg)
+    fake_pred = torch.rand(3, 4, 5, 6)
+    fake_target = torch.randint(0, 4, (3, 5, 6))
+    loss_cw = focal_loss_cw(fake_pred, fake_target)
+    loss = focal_loss(fake_pred, fake_target)
+    weight = torch.tensor([1, 2, 3, 4]).view(1, 4, 1, 1)
+    assert (loss * weight == loss_cw).all()
+
+
+# test ignore index
+def test_ignore_index():
+    loss_cfg = dict(type='FocalLoss', reduction='none')
+    # ignore_index within C classes
+    focal_loss = build_loss(loss_cfg)
+    fake_pred = torch.rand(3, 5, 5, 6)
+    fake_target = torch.randint(0, 4, (3, 5, 6))
+    dim1 = torch.randint(0, 3, (4, ))
+    dim2 = torch.randint(0, 5, (4, ))
+    dim3 = torch.randint(0, 6, (4, ))
+    fake_target[dim1, dim2, dim3] = 4
+    loss1 = focal_loss(fake_pred, fake_target, ignore_index=4)
+    one_hot_target = F.one_hot(fake_target, num_classes=5)
+    one_hot_target = one_hot_target.permute(0, 3, 1, 2)
+    loss2 = focal_loss(fake_pred, one_hot_target, ignore_index=4)
+    assert (loss1 == loss2).all()
+    assert (loss1[dim1, :, dim2, dim3] == 0).all()
+    assert (loss2[dim1, :, dim2, dim3] == 0).all()
+
+    fake_pred = torch.rand(3, 4, 5, 6)
+    fake_target = torch.randint(0, 4, (3, 5, 6))
+    loss1 = focal_loss(fake_pred, fake_target, ignore_index=2)
+    one_hot_target = F.one_hot(fake_target, num_classes=4)
+    one_hot_target = one_hot_target.permute(0, 3, 1, 2)
+    loss2 = focal_loss(fake_pred, one_hot_target, ignore_index=2)
+    # mask out every class channel at the ignored pixels
+    ignore_mask = (fake_target == 2).unsqueeze(1)
+    assert (loss1 == loss2).all()
+    assert torch.sum(loss1 * ignore_mask) == 0
+    assert torch.sum(loss2 * ignore_mask) == 0
+
+    # ignore index is not in prediction's classes
+    fake_pred = torch.rand(3, 4, 5, 6)
+    fake_target = torch.randint(0, 4, (3, 5, 6))
+    dim1 = torch.randint(0, 3, (4, ))
+    dim2 = torch.randint(0, 5, (4, ))
+    dim3 = torch.randint(0, 6, (4, ))
+    fake_target[dim1, dim2, dim3] = 255
+    loss1 = focal_loss(fake_pred, fake_target, ignore_index=255)
+    assert (loss1[dim1, :, dim2, dim3] == 0).all()
+
+
+# test list alpha
+def test_alpha():
+    loss_cfg = dict(type='FocalLoss')
+    focal_loss = build_loss(loss_cfg)
+    alpha_float = 0.4
+    alpha = [0.4, 0.4, 0.4, 0.4]
+    alpha2 = [0.1, 0.3, 0.2, 0.1]
+    fake_pred = torch.rand(3, 4, 5, 6)
+    fake_target = torch.randint(0, 4, (3, 5, 6))
+    focal_loss.alpha = alpha_float
+    loss1 = focal_loss(fake_pred, fake_target)
+    focal_loss.alpha = alpha
+    loss2 = focal_loss(fake_pred, fake_target)
+    assert loss1 == loss2
+    focal_loss.alpha = alpha2
+    focal_loss(fake_pred, fake_target)
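
For reference, a typical way to enable the new loss in a segmentation model
is through the decode head's loss_decode field. An abridged config fragment
(the head type and omitted fields are illustrative, not part of this patch):

    model = dict(
        type='EncoderDecoder',
        decode_head=dict(
            type='FCNHead',
            # ... other backbone/head fields omitted ...
            num_classes=19,
            loss_decode=dict(
                type='FocalLoss',
                use_sigmoid=True,
                gamma=2.0,
                alpha=0.5,
                loss_weight=1.0)))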