52 lines
1.8 KiB
Python
52 lines
1.8 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from mmcv.cnn.bricks.wrappers import NewEmptyTensorOp, obsolete_torch_version
|
|
|
|
if torch.__version__ == 'parrots':
|
|
TORCH_VERSION = torch.__version__
|
|
else:
|
|
# torch.__version__ could be 1.3.1+cu92, we only need the first two
|
|
# for comparison
|
|
TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2])
|
|
|
|
|
|
def adaptive_avg_pool2d(input, output_size):
|
|
"""Handle empty batch dimension to adaptive_avg_pool2d.
|
|
|
|
Args:
|
|
input (tensor): 4D tensor.
|
|
output_size (int, tuple[int,int]): the target output size.
|
|
"""
|
|
if input.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
|
|
if isinstance(output_size, int):
|
|
output_size = [output_size, output_size]
|
|
output_size = [*input.shape[:2], *output_size]
|
|
empty = NewEmptyTensorOp.apply(input, output_size)
|
|
return empty
|
|
else:
|
|
return F.adaptive_avg_pool2d(input, output_size)
|
|
|
|
|
|
class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d):
|
|
"""Handle empty batch dimension to AdaptiveAvgPool2d."""
|
|
|
|
def forward(self, x):
|
|
# PyTorch 1.9 does not support empty tensor inference yet
|
|
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
|
|
output_size = self.output_size
|
|
if isinstance(output_size, int):
|
|
output_size = [output_size, output_size]
|
|
else:
|
|
output_size = [
|
|
v if v is not None else d
|
|
for v, d in zip(output_size,
|
|
x.size()[-2:])
|
|
]
|
|
output_size = [*x.shape[:2], *output_size]
|
|
empty = NewEmptyTensorOp.apply(x, output_size)
|
|
return empty
|
|
|
|
return super().forward(x)
|