[Enhancement] Support ignore_index for sigmoid BCE (#210)
* [Enhancement] Add args check for ignore_index
* Support ignore_index
This commit is contained in:
parent
c2608b212a
commit
61e1d5c814
@ -25,7 +25,7 @@ model = dict(
|
||||
norm_cfg=norm_cfg,
|
||||
align_corners=False,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.)),
|
||||
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)),
|
||||
auxiliary_head=[
|
||||
dict(
|
||||
type='FCNHead',
|
||||
@ -38,7 +38,7 @@ model = dict(
|
||||
concat_input=False,
|
||||
align_corners=False,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
|
||||
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)),
|
||||
dict(
|
||||
type='FCNHead',
|
||||
in_channels=64,
|
||||
@ -50,7 +50,7 @@ model = dict(
|
||||
concat_input=False,
|
||||
align_corners=False,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
|
||||
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)),
|
||||
])
|
||||
|
||||
# model training and testing settings
|
||||
|
||||
@ -4,7 +4,7 @@ _base_ = [
|
||||
]
|
||||
|
||||
# Re-config the data sampler.
|
||||
data = dict(samples_per_gpu=8, workers_per_gpu=4)
|
||||
data = dict(samples_per_gpu=2, workers_per_gpu=4)
|
||||
|
||||
# Re-config the optimizer.
|
||||
optimizer = dict(type='SGD', lr=0.12, momentum=0.9, weight_decay=4e-5)
|
||||
|
||||
@ -35,7 +35,8 @@ class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
|
||||
Default: None.
|
||||
loss_decode (dict): Config of decode loss.
|
||||
Default: dict(type='CrossEntropyLoss').
|
||||
ignore_index (int): The label index to be ignored. Default: 255
|
||||
ignore_index (int | None): The label index to be ignored. When using
|
||||
masked BCE loss, ignore_index should be set to None. Default: 255
|
||||
sampler (dict|None): The config of segmentation map sampler.
|
||||
Default: None.
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
|
||||
@ -32,17 +32,25 @@ def cross_entropy(pred,
|
||||
return loss
|
||||
|
||||
|
||||
def _expand_onehot_labels(labels, label_weights, label_channels):
|
||||
def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
|
||||
"""Expand onehot labels to match the size of prediction."""
|
||||
bin_labels = labels.new_full((labels.size(0), label_channels), 0)
|
||||
inds = torch.nonzero(labels >= 1, as_tuple=False).squeeze()
|
||||
if inds.numel() > 0:
|
||||
bin_labels[inds, labels[inds] - 1] = 1
|
||||
bin_labels = labels.new_zeros(target_shape)
|
||||
valid_mask = (labels >= 0) & (labels != ignore_index)
|
||||
inds = torch.nonzero(valid_mask, as_tuple=True)
|
||||
|
||||
if inds[0].numel() > 0:
|
||||
if labels.dim() == 3:
|
||||
bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1
|
||||
else:
|
||||
bin_labels[inds[0], labels[valid_mask]] = 1
|
||||
|
||||
valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()
|
||||
if label_weights is None:
|
||||
bin_label_weights = None
|
||||
bin_label_weights = valid_mask
|
||||
else:
|
||||
bin_label_weights = label_weights.view(-1, 1).expand(
|
||||
label_weights.size(0), label_channels)
|
||||
bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
|
||||
bin_label_weights *= valid_mask
|
||||
|
||||
return bin_labels, bin_label_weights
|
||||
|
||||
|
||||
@ -51,7 +59,8 @@ def binary_cross_entropy(pred,
|
||||
weight=None,
|
||||
reduction='mean',
|
||||
avg_factor=None,
|
||||
class_weight=None):
|
||||
class_weight=None,
|
||||
ignore_index=255):
|
||||
"""Calculate the binary CrossEntropy loss.
|
||||
|
||||
Args:
|
||||
@ -63,18 +72,24 @@ def binary_cross_entropy(pred,
|
||||
avg_factor (int, optional): Average factor that is used to average
|
||||
the loss. Defaults to None.
|
||||
class_weight (list[float], optional): The weight for each class.
|
||||
ignore_index (int | None): The label index to be ignored. Default: 255
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The calculated loss
|
||||
"""
|
||||
if pred.dim() != label.dim():
|
||||
label, weight = _expand_onehot_labels(label, weight, pred.size(-1))
|
||||
assert (pred.dim() == 2 and label.dim() == 1) or (
|
||||
pred.dim() == 4 and label.dim() == 3), \
|
||||
'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \
|
||||
'H, W], label shape [N, H, W] are supported'
|
||||
label, weight = _expand_onehot_labels(label, weight, pred.shape,
|
||||
ignore_index)
|
||||
|
||||
# weighted element-wise losses
|
||||
if weight is not None:
|
||||
weight = weight.float()
|
||||
loss = F.binary_cross_entropy_with_logits(
|
||||
pred, label.float(), weight=class_weight, reduction='none')
|
||||
pred, label.float(), pos_weight=class_weight, reduction='none')
|
||||
# do the reduction for the weighted loss
|
||||
loss = weight_reduce_loss(
|
||||
loss, weight, reduction=reduction, avg_factor=avg_factor)
|
||||
@ -87,7 +102,8 @@ def mask_cross_entropy(pred,
|
||||
label,
|
||||
reduction='mean',
|
||||
avg_factor=None,
|
||||
class_weight=None):
|
||||
class_weight=None,
|
||||
ignore_index=None):
|
||||
"""Calculate the CrossEntropy loss for masks.
|
||||
|
||||
Args:
|
||||
@ -103,10 +119,13 @@ def mask_cross_entropy(pred,
|
||||
avg_factor (int, optional): Average factor that is used to average
|
||||
the loss. Defaults to None.
|
||||
class_weight (list[float], optional): The weight for each class.
|
||||
ignore_index (None): Placeholder, to be consistent with other loss.
|
||||
Default: None.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The calculated loss
|
||||
"""
|
||||
assert ignore_index is None, 'BCE loss does not support ignore_index'
|
||||
# TODO: handle these two reserved arguments
|
||||
assert reduction == 'mean' and avg_factor is None
|
||||
num_rois = pred.size()[0]
|
||||
|
||||
@ -71,7 +71,17 @@ def test_ce_loss():
|
||||
loss_cls_cfg = dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)
|
||||
loss_cls = build_loss(loss_cls_cfg)
|
||||
assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(0.))
|
||||
assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(100.))
|
||||
|
||||
fake_pred = torch.full(size=(2, 21, 8, 8), fill_value=0.5)
|
||||
fake_label = torch.ones(2, 8, 8).long()
|
||||
assert torch.allclose(
|
||||
loss_cls(fake_pred, fake_label), torch.tensor(0.9503), atol=1e-4)
|
||||
fake_label[:, 0, 0] = 255
|
||||
assert torch.allclose(
|
||||
loss_cls(fake_pred, fake_label, ignore_index=255),
|
||||
torch.tensor(0.9354),
|
||||
atol=1e-4)
|
||||
|
||||
# TODO test use_mask
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user