[Enhancement] Support ignore_index for sigmoid BCE (#210)

* [Enhancement] Add args check for ignore_index

* Support ignore_index
This commit is contained in:
Jerry Jiarui XU 2020-11-17 00:14:03 -08:00 committed by GitHub
parent c2608b212a
commit 61e1d5c814
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 48 additions and 18 deletions

View File

@ -25,7 +25,7 @@ model = dict(
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.)),
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)),
auxiliary_head=[
dict(
type='FCNHead',
@ -38,7 +38,7 @@ model = dict(
concat_input=False,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)),
dict(
type='FCNHead',
in_channels=64,
@ -50,7 +50,7 @@ model = dict(
concat_input=False,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)),
])
# model training and testing settings

View File

@ -4,7 +4,7 @@ _base_ = [
]
# Re-config the data sampler.
data = dict(samples_per_gpu=8, workers_per_gpu=4)
data = dict(samples_per_gpu=2, workers_per_gpu=4)
# Re-config the optimizer.
optimizer = dict(type='SGD', lr=0.12, momentum=0.9, weight_decay=4e-5)

View File

@ -35,7 +35,8 @@ class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
Default: None.
loss_decode (dict): Config of decode loss.
Default: dict(type='CrossEntropyLoss').
ignore_index (int): The label index to be ignored. Default: 255
ignore_index (int | None): The label index to be ignored. When using
masked BCE loss, ignore_index should be set to None. Default: 255
sampler (dict|None): The config of segmentation map sampler.
Default: None.
align_corners (bool): align_corners argument of F.interpolate.

View File

@ -32,17 +32,25 @@ def cross_entropy(pred,
return loss
def _expand_onehot_labels(labels, label_weights, label_channels):
def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
"""Expand onehot labels to match the size of prediction."""
bin_labels = labels.new_full((labels.size(0), label_channels), 0)
inds = torch.nonzero(labels >= 1, as_tuple=False).squeeze()
if inds.numel() > 0:
bin_labels[inds, labels[inds] - 1] = 1
bin_labels = labels.new_zeros(target_shape)
valid_mask = (labels >= 0) & (labels != ignore_index)
inds = torch.nonzero(valid_mask, as_tuple=True)
if inds[0].numel() > 0:
if labels.dim() == 3:
bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1
else:
bin_labels[inds[0], labels[valid_mask]] = 1
valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()
if label_weights is None:
bin_label_weights = None
bin_label_weights = valid_mask
else:
bin_label_weights = label_weights.view(-1, 1).expand(
label_weights.size(0), label_channels)
bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
bin_label_weights *= valid_mask
return bin_labels, bin_label_weights
@ -51,7 +59,8 @@ def binary_cross_entropy(pred,
weight=None,
reduction='mean',
avg_factor=None,
class_weight=None):
class_weight=None,
ignore_index=255):
"""Calculate the binary CrossEntropy loss.
Args:
@ -63,18 +72,24 @@ def binary_cross_entropy(pred,
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
class_weight (list[float], optional): The weight for each class.
ignore_index (int | None): The label index to be ignored. Default: 255
Returns:
torch.Tensor: The calculated loss
"""
if pred.dim() != label.dim():
label, weight = _expand_onehot_labels(label, weight, pred.size(-1))
assert (pred.dim() == 2 and label.dim() == 1) or (
pred.dim() == 4 and label.dim() == 3), \
'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \
'H, W], label shape [N, H, W] are supported'
label, weight = _expand_onehot_labels(label, weight, pred.shape,
ignore_index)
# weighted element-wise losses
if weight is not None:
weight = weight.float()
loss = F.binary_cross_entropy_with_logits(
pred, label.float(), weight=class_weight, reduction='none')
pred, label.float(), pos_weight=class_weight, reduction='none')
# do the reduction for the weighted loss
loss = weight_reduce_loss(
loss, weight, reduction=reduction, avg_factor=avg_factor)
@ -87,7 +102,8 @@ def mask_cross_entropy(pred,
label,
reduction='mean',
avg_factor=None,
class_weight=None):
class_weight=None,
ignore_index=None):
"""Calculate the CrossEntropy loss for masks.
Args:
@ -103,10 +119,13 @@ def mask_cross_entropy(pred,
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
class_weight (list[float], optional): The weight for each class.
ignore_index (None): Placeholder, to be consistent with other loss.
Default: None.
Returns:
torch.Tensor: The calculated loss
"""
assert ignore_index is None, 'BCE loss does not support ignore_index'
# TODO: handle these two reserved arguments
assert reduction == 'mean' and avg_factor is None
num_rois = pred.size()[0]

View File

@ -71,7 +71,17 @@ def test_ce_loss():
loss_cls_cfg = dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)
loss_cls = build_loss(loss_cls_cfg)
assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(0.))
assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(100.))
fake_pred = torch.full(size=(2, 21, 8, 8), fill_value=0.5)
fake_label = torch.ones(2, 8, 8).long()
assert torch.allclose(
loss_cls(fake_pred, fake_label), torch.tensor(0.9503), atol=1e-4)
fake_label[:, 0, 0] = 255
assert torch.allclose(
loss_cls(fake_pred, fake_label, ignore_index=255),
torch.tensor(0.9354),
atol=1e-4)
# TODO test use_mask