73 lines
2.5 KiB
Python
73 lines
2.5 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn import functional as F
|
|
|
|
|
|
class SigmoidGeometricMean(Function):
    """Forward and backward function of geometric mean of two sigmoid
    functions.

    This implementation with analytical gradient function substitutes
    the autograd function of (x.sigmoid() * y.sigmoid()).sqrt(). The
    original implementation incurs NaN during gradient backpropagation
    if both x and y are very small values.

    With ``z = sqrt(sigmoid(x) * sigmoid(y))`` the gradients are
    ``dz/dx = z * (1 - sigmoid(x)) / 2`` and
    ``dz/dy = z * (1 - sigmoid(y)) / 2``. They contain no division by
    ``z``, so they stay finite (for a finite ``grad_output``) even when
    both sigmoids underflow to zero.

    Only first-order gradients are supported. ``backward`` reuses the
    sigmoids computed in ``forward``, which carry no autograd history, so
    differentiating it again would silently miss the terms flowing through
    them. It is therefore wrapped with ``once_differentiable`` and runs
    untracked.
    """

    @staticmethod
    def forward(ctx, x, y):
        x_sigmoid = x.sigmoid()
        y_sigmoid = y.sigmoid()
        z = (x_sigmoid * y_sigmoid).sqrt()
        # Keep the sigmoids and the output so backward needs no recompute.
        ctx.save_for_backward(x_sigmoid, y_sigmoid, z)
        return z

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        x_sigmoid, y_sigmoid, z = ctx.saved_tensors
        # d sigmoid(x)/dx = sigmoid(x) * (1 - sigmoid(x)); combined with the
        # chain rule through sqrt this simplifies to z * (1 - sigmoid(x)) / 2.
        grad_x = grad_output * z * (1 - x_sigmoid) / 2
        grad_y = grad_output * z * (1 - y_sigmoid) / 2
        return grad_x, grad_y


# Functional alias: sigmoid_geometric_mean(x, y).
sigmoid_geometric_mean = SigmoidGeometricMean.apply
|
|
|
|
|
|
def interpolate_as(source, target, mode='bilinear', align_corners=False):
    """Interpolate the `source` to the shape of the `target`.

    The `source` must be a Tensor, but the `target` can be a Tensor or a
    np.ndarray with the shape (..., target_h, target_w).

    Args:
        source (Tensor): A 3D/4D Tensor with the shape (N, H, W) or
            (N, C, H, W).
        target (Tensor | np.ndarray): The interpolation target with the shape
            (..., target_h, target_w).
        mode (str): Algorithm used for interpolation. The options are the
            same as those in F.interpolate(). Default: ``'bilinear'``.
        align_corners (bool): The same as the argument in F.interpolate().
            F.interpolate() only accepts it for the linear-family modes; for
            ``'nearest'``, ``'area'`` and ``'nearest-exact'`` a falsy value
            is not forwarded (an explicit ``True`` still is, so PyTorch
            reports the invalid combination). Default: False.

    Returns:
        Tensor: The interpolated source Tensor, with the same number of
        dimensions as ``source``. No interpolation is performed when the
        spatial size of ``source`` already matches the target.
    """
    assert len(target.shape) >= 2, (
        'target must have the shape (..., target_h, target_w)')

    # F.interpolate() raises ValueError whenever align_corners is not None
    # for these modes, which made the default ``False`` unusable with them.
    if mode in ('nearest', 'area', 'nearest-exact') and not align_corners:
        align_corners = None

    def _interpolate_as(source, target, mode='bilinear', align_corners=False):
        """Interpolate the `source` (4D) to the shape of the `target`."""
        target_h, target_w = target.shape[-2:]
        source_h, source_w = source.shape[-2:]
        if target_h != source_h or target_w != source_w:
            source = F.interpolate(
                source,
                size=(target_h, target_w),
                mode=mode,
                align_corners=align_corners)
        return source

    if len(source.shape) == 3:
        # F.interpolate() resizes (N, C, H, W) inputs spatially, so insert a
        # singleton channel dim and remove it again afterwards.
        source = source[:, None, :, :]
        source = _interpolate_as(source, target, mode, align_corners)
        return source[:, 0, :, :]
    else:
        return _interpolate_as(source, target, mode, align_corners)
|