From b8f42c70faa0a214df79f88fce9cd4943db76aaf Mon Sep 17 00:00:00 2001 From: Jerry Jiarui XU Date: Thu, 3 Sep 2020 19:56:36 +0800 Subject: [PATCH] Add Semantic FPN (#94) * Add Semantic FPN * remove HRFPN --- configs/_base_/models/fpn_r50.py | 36 +++ configs/sem_fpn/README.md | 30 +++ .../fpn_r101_512x1024_80k_cityscapes.py | 2 + .../sem_fpn/fpn_r101_512x512_160k_ade20k.py | 2 + .../fpn_r50_512x1024_80k_cityscapes.py | 4 + .../sem_fpn/fpn_r50_512x512_160k_ade20k.py | 5 + mmseg/models/__init__.py | 1 + mmseg/models/decode_heads/__init__.py | 3 +- mmseg/models/decode_heads/fpn_head.py | 68 ++++++ mmseg/models/necks/__init__.py | 3 + mmseg/models/necks/fpn.py | 212 ++++++++++++++++++ tests/test_models/test_forward.py | 4 + tests/test_models/test_heads.py | 37 +-- tests/test_models/test_necks.py | 18 ++ 14 files changed, 388 insertions(+), 37 deletions(-) create mode 100644 configs/_base_/models/fpn_r50.py create mode 100644 configs/sem_fpn/README.md create mode 100644 configs/sem_fpn/fpn_r101_512x1024_80k_cityscapes.py create mode 100644 configs/sem_fpn/fpn_r101_512x512_160k_ade20k.py create mode 100644 configs/sem_fpn/fpn_r50_512x1024_80k_cityscapes.py create mode 100644 configs/sem_fpn/fpn_r50_512x512_160k_ade20k.py create mode 100644 mmseg/models/decode_heads/fpn_head.py create mode 100644 mmseg/models/necks/__init__.py create mode 100644 mmseg/models/necks/fpn.py create mode 100644 tests/test_models/test_necks.py diff --git a/configs/_base_/models/fpn_r50.py b/configs/_base_/models/fpn_r50.py new file mode 100644 index 0000000..ec11717 --- /dev/null +++ b/configs/_base_/models/fpn_r50.py @@ -0,0 +1,36 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 1, 1), + strides=(1, 2, 2, 2), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=4), + decode_head=dict( + type='FPNHead', + in_channels=[256, 256, 256, 256], + in_index=[0, 1, 2, 3], + feature_strides=[4, 8, 16, 32], + channels=128, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))) +# model training and testing settings +train_cfg = dict() +test_cfg = dict(mode='whole') diff --git a/configs/sem_fpn/README.md b/configs/sem_fpn/README.md new file mode 100644 index 0000000..5315d36 --- /dev/null +++ b/configs/sem_fpn/README.md @@ -0,0 +1,30 @@ +# Panoptic Feature Pyramid Networks + +## Introduction +``` +@article{Kirillov_2019, + title={Panoptic Feature Pyramid Networks}, + ISBN={9781728132938}, + url={http://dx.doi.org/10.1109/CVPR.2019.00656}, + DOI={10.1109/cvpr.2019.00656}, + journal={2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, + publisher={IEEE}, + author={Kirillov, Alexander and Girshick, Ross and He, Kaiming and Dollar, Piotr}, + year={2019}, + month={Jun} +} +``` + +## Results and models + +### Cityscapes +| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download | +|--------|----------|-----------|--------:|---------:|----------------|------:|---------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| FPN | R-50 | 512x1024 | 80000 | 2.8 | 13.54 | 74.52 | 76.08 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/sem_fpn/fpn_r50_512x1024_80k_cityscapes/fpn_r50_512x1024_80k_cityscapes_20200717_021437-94018a0d.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/sem_fpn/fpn_r50_512x1024_80k_cityscapes/fpn_r50_512x1024_80k_cityscapes-20200717_021437.log.json) | +| FPN | R-101 | 512x1024 | 80000 | 3.9 | 10.29 | 75.80 | 77.40 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/sem_fpn/fpn_r101_512x1024_80k_cityscapes/fpn_r101_512x1024_80k_cityscapes_20200717_012416-c5800d4c.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/sem_fpn/fpn_r101_512x1024_80k_cityscapes/fpn_r101_512x1024_80k_cityscapes-20200717_012416.log.json) | + +### ADE20K +| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download | +|--------|----------|-----------|--------:|---------:|----------------|------:|---------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| FPN | R-50 | 512x512 | 160000 | 4.9 | 55.77 | 37.49 | 39.09 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/sem_fpn/fpn_r50_512x512_160k_ade20k/fpn_r50_512x512_160k_ade20k_20200718_131734-5b5a6ab9.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/sem_fpn/fpn_r50_512x512_160k_ade20k/fpn_r50_512x512_160k_ade20k-20200718_131734.log.json) | +| FPN | R-101 | 512x512 | 160000 | 5.9 | 40.58 | 39.35 | 40.72 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/sem_fpn/fpn_r101_512x512_160k_ade20k/fpn_r101_512x512_160k_ade20k_20200718_131734-306b5004.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/sem_fpn/fpn_r101_512x512_160k_ade20k/fpn_r101_512x512_160k_ade20k-20200718_131734.log.json) | diff --git a/configs/sem_fpn/fpn_r101_512x1024_80k_cityscapes.py b/configs/sem_fpn/fpn_r101_512x1024_80k_cityscapes.py new file mode 100644 index 0000000..7f8710d --- /dev/null +++ b/configs/sem_fpn/fpn_r101_512x1024_80k_cityscapes.py @@ -0,0 +1,2 @@ +_base_ = './fpn_r50_512x1024_80k_cityscapes.py' +model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101)) diff --git a/configs/sem_fpn/fpn_r101_512x512_160k_ade20k.py b/configs/sem_fpn/fpn_r101_512x512_160k_ade20k.py new file mode 100644 index 0000000..2654096 --- /dev/null +++ b/configs/sem_fpn/fpn_r101_512x512_160k_ade20k.py @@ -0,0 +1,2 @@ +_base_ = './fpn_r50_512x512_160k_ade20k.py' +model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101)) diff --git a/configs/sem_fpn/fpn_r50_512x1024_80k_cityscapes.py b/configs/sem_fpn/fpn_r50_512x1024_80k_cityscapes.py new file mode 100644 index 0000000..4bf3edd --- /dev/null +++ b/configs/sem_fpn/fpn_r50_512x1024_80k_cityscapes.py @@ -0,0 +1,4 @@ +_base_ = [ + '../_base_/models/fpn_r50.py', '../_base_/datasets/cityscapes.py', + '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py' +] diff --git a/configs/sem_fpn/fpn_r50_512x512_160k_ade20k.py b/configs/sem_fpn/fpn_r50_512x512_160k_ade20k.py new file mode 100644 index 0000000..5cdfc8c --- /dev/null +++ b/configs/sem_fpn/fpn_r50_512x512_160k_ade20k.py @@ -0,0 +1,5 @@ +_base_ = [ + '../_base_/models/fpn_r50.py', '../_base_/datasets/ade20k.py', + '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py' +] +model = dict(decode_head=dict(num_classes=150)) diff --git a/mmseg/models/__init__.py b/mmseg/models/__init__.py index d492a23..3cf93f8 100644 --- a/mmseg/models/__init__.py +++ b/mmseg/models/__init__.py @@ -3,6 +3,7 @@ from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, build_backbone, build_head, build_loss, build_segmentor) from .decode_heads import * # noqa: F401,F403 from .losses import * # noqa: F401,F403 +from .necks import * # noqa: F401,F403 from .segmentors import * # noqa: F401,F403 __all__ = [ diff --git a/mmseg/models/decode_heads/__init__.py b/mmseg/models/decode_heads/__init__.py index a6ead50..5828034 100644 --- a/mmseg/models/decode_heads/__init__.py +++ b/mmseg/models/decode_heads/__init__.py @@ -4,6 +4,7 @@ from .cc_head import CCHead from .da_head import DAHead from .enc_head import EncHead from .fcn_head import FCNHead +from .fpn_head import FPNHead from .gc_head import GCHead from .nl_head import NLHead from .ocr_head import OCRHead @@ -16,5 +17,5 @@ from .uper_head import UPerHead __all__ = [ 'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead', 'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead', - 'EncHead', 'DepthwiseSeparableFCNHead' + 'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead' ] diff --git a/mmseg/models/decode_heads/fpn_head.py b/mmseg/models/decode_heads/fpn_head.py new file mode 100644 index 0000000..9b6ada0 --- /dev/null +++ b/mmseg/models/decode_heads/fpn_head.py @@ -0,0 +1,68 @@ +import numpy as np +import torch.nn as nn +from mmcv.cnn import ConvModule + +from mmseg.ops import resize +from ..builder import HEADS +from .decode_head import BaseDecodeHead + + +@HEADS.register_module() +class FPNHead(BaseDecodeHead): + """Panoptic Feature Pyramid Networks. + + This head is the implementation of `Semantic FPN + `_. + + Args: + feature_strides (tuple[int]): The strides for input feature maps. + stack_lateral. All strides suppose to be power of 2. The first + one is of largest resolution. + """ + + def __init__(self, feature_strides, **kwargs): + super(FPNHead, self).__init__( + input_transform='multiple_select', **kwargs) + assert len(feature_strides) == len(self.in_channels) + assert min(feature_strides) == feature_strides[0] + self.feature_strides = feature_strides + + self.scale_heads = nn.ModuleList() + for i in range(len(feature_strides)): + head_length = max( + 1, + int(np.log2(feature_strides[i]) - np.log2(feature_strides[0]))) + scale_head = [] + for k in range(head_length): + scale_head.append( + ConvModule( + self.in_channels[i] if k == 0 else self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + if feature_strides[i] != feature_strides[0]: + scale_head.append( + nn.Upsample( + scale_factor=2, + mode='bilinear', + align_corners=self.align_corners)) + self.scale_heads.append(nn.Sequential(*scale_head)) + + def forward(self, inputs): + + x = self._transform_inputs(inputs) + + output = self.scale_heads[0](x[0]) + for i in range(1, len(self.feature_strides)): + # non inplace + output = output + resize( + self.scale_heads[i](x[i]), + size=output.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + + output = self.cls_seg(output) + return output diff --git a/mmseg/models/necks/__init__.py b/mmseg/models/necks/__init__.py new file mode 100644 index 0000000..0093021 --- /dev/null +++ b/mmseg/models/necks/__init__.py @@ -0,0 +1,3 @@ +from .fpn import FPN + +__all__ = ['FPN'] diff --git a/mmseg/models/necks/fpn.py b/mmseg/models/necks/fpn.py new file mode 100644 index 0000000..f43d1e6 --- /dev/null +++ b/mmseg/models/necks/fpn.py @@ -0,0 +1,212 @@ +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule, xavier_init + +from ..builder import NECKS + + +@NECKS.register_module() +class FPN(nn.Module): + """Feature Pyramid Network. + + This is an implementation of - Feature Pyramid Networks for Object + Detection (https://arxiv.org/abs/1612.03144) + + Args: + in_channels (List[int]): Number of input channels per scale. + out_channels (int): Number of output channels (used at each scale) + num_outs (int): Number of output scales. + start_level (int): Index of the start input backbone level used to + build the feature pyramid. Default: 0. + end_level (int): Index of the end input backbone level (exclusive) to + build the feature pyramid. Default: -1, which means the last level. + add_extra_convs (bool | str): If bool, it decides whether to add conv + layers on top of the original feature maps. Default to False. + If True, its actual mode is specified by `extra_convs_on_inputs`. + If str, it specifies the source feature map of the extra convs. + Only the following options are allowed + + - 'on_input': Last feat map of neck inputs (i.e. backbone feature). + - 'on_lateral': Last feature map after lateral convs. + - 'on_output': The last output feature map after fpn convs. + extra_convs_on_inputs (bool, deprecated): Whether to apply extra convs + on the original feature from the backbone. If True, + it is equivalent to `add_extra_convs='on_input'`. If False, it is + equivalent to set `add_extra_convs='on_output'`. Default to True. + relu_before_extra_convs (bool): Whether to apply relu before the extra + conv. Default: False. + no_norm_on_lateral (bool): Whether to apply norm on lateral. + Default: False. + conv_cfg (dict): Config dict for convolution layer. Default: None. + norm_cfg (dict): Config dict for normalization layer. Default: None. + act_cfg (str): Config dict for activation layer in ConvModule. + Default: None. + upsample_cfg (dict): Config dict for interpolate layer. + Default: `dict(mode='nearest')` + + Example: + >>> import torch + >>> in_channels = [2, 3, 5, 7] + >>> scales = [340, 170, 84, 43] + >>> inputs = [torch.rand(1, c, s, s) + ... for c, s in zip(in_channels, scales)] + >>> self = FPN(in_channels, 11, len(in_channels)).eval() + >>> outputs = self.forward(inputs) + >>> for i in range(len(outputs)): + ... print(f'outputs[{i}].shape = {outputs[i].shape}') + outputs[0].shape = torch.Size([1, 11, 340, 340]) + outputs[1].shape = torch.Size([1, 11, 170, 170]) + outputs[2].shape = torch.Size([1, 11, 84, 84]) + outputs[3].shape = torch.Size([1, 11, 43, 43]) + """ + + def __init__(self, + in_channels, + out_channels, + num_outs, + start_level=0, + end_level=-1, + add_extra_convs=False, + extra_convs_on_inputs=False, + relu_before_extra_convs=False, + no_norm_on_lateral=False, + conv_cfg=None, + norm_cfg=None, + act_cfg=None, + upsample_cfg=dict(mode='nearest')): + super(FPN, self).__init__() + assert isinstance(in_channels, list) + self.in_channels = in_channels + self.out_channels = out_channels + self.num_ins = len(in_channels) + self.num_outs = num_outs + self.relu_before_extra_convs = relu_before_extra_convs + self.no_norm_on_lateral = no_norm_on_lateral + self.fp16_enabled = False + self.upsample_cfg = upsample_cfg.copy() + + if end_level == -1: + self.backbone_end_level = self.num_ins + assert num_outs >= self.num_ins - start_level + else: + # if end_level < inputs, no extra level is allowed + self.backbone_end_level = end_level + assert end_level <= len(in_channels) + assert num_outs == end_level - start_level + self.start_level = start_level + self.end_level = end_level + self.add_extra_convs = add_extra_convs + assert isinstance(add_extra_convs, (str, bool)) + if isinstance(add_extra_convs, str): + # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output' + assert add_extra_convs in ('on_input', 'on_lateral', 'on_output') + elif add_extra_convs: # True + if extra_convs_on_inputs: + # For compatibility with previous release + # TODO: deprecate `extra_convs_on_inputs` + self.add_extra_convs = 'on_input' + else: + self.add_extra_convs = 'on_output' + + self.lateral_convs = nn.ModuleList() + self.fpn_convs = nn.ModuleList() + + for i in range(self.start_level, self.backbone_end_level): + l_conv = ConvModule( + in_channels[i], + out_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg if not self.no_norm_on_lateral else None, + act_cfg=act_cfg, + inplace=False) + fpn_conv = ConvModule( + out_channels, + out_channels, + 3, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + + self.lateral_convs.append(l_conv) + self.fpn_convs.append(fpn_conv) + + # add extra conv layers (e.g., RetinaNet) + extra_levels = num_outs - self.backbone_end_level + self.start_level + if self.add_extra_convs and extra_levels >= 1: + for i in range(extra_levels): + if i == 0 and self.add_extra_convs == 'on_input': + in_channels = self.in_channels[self.backbone_end_level - 1] + else: + in_channels = out_channels + extra_fpn_conv = ConvModule( + in_channels, + out_channels, + 3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + self.fpn_convs.append(extra_fpn_conv) + + # default init_weights for conv(msra) and norm in ConvModule + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + xavier_init(m, distribution='uniform') + + def forward(self, inputs): + assert len(inputs) == len(self.in_channels) + + # build laterals + laterals = [ + lateral_conv(inputs[i + self.start_level]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + + # build top-down path + used_backbone_levels = len(laterals) + for i in range(used_backbone_levels - 1, 0, -1): + # In some cases, fixing `scale factor` (e.g. 2) is preferred, but + # it cannot co-exist with `size` in `F.interpolate`. + if 'scale_factor' in self.upsample_cfg: + laterals[i - 1] += F.interpolate(laterals[i], + **self.upsample_cfg) + else: + prev_shape = laterals[i - 1].shape[2:] + laterals[i - 1] += F.interpolate( + laterals[i], size=prev_shape, **self.upsample_cfg) + + # build outputs + # part 1: from original levels + outs = [ + self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels) + ] + # part 2: add extra levels + if self.num_outs > len(outs): + # use max pool to get more levels on top of outputs + # (e.g., Faster R-CNN, Mask R-CNN) + if not self.add_extra_convs: + for i in range(self.num_outs - used_backbone_levels): + outs.append(F.max_pool2d(outs[-1], 1, stride=2)) + # add conv layers on top of original feature maps (RetinaNet) + else: + if self.add_extra_convs == 'on_input': + extra_source = inputs[self.backbone_end_level - 1] + elif self.add_extra_convs == 'on_lateral': + extra_source = laterals[-1] + elif self.add_extra_convs == 'on_output': + extra_source = outs[-1] + else: + raise NotImplementedError + outs.append(self.fpn_convs[used_backbone_levels](extra_source)) + for i in range(used_backbone_levels + 1, self.num_outs): + if self.relu_before_extra_convs: + outs.append(self.fpn_convs[i](F.relu(outs[-1]))) + else: + outs.append(self.fpn_convs[i](outs[-1])) + return tuple(outs) diff --git a/tests/test_models/test_forward.py b/tests/test_models/test_forward.py index 620b82e..fffe23e 100644 --- a/tests/test_models/test_forward.py +++ b/tests/test_models/test_forward.py @@ -153,6 +153,10 @@ def test_encnet_forward(): 'encnet/encnet_r50-d8_512x1024_40k_cityscapes.py') +def test_sem_fpn_forward(): + _test_encoder_decoder_forward('sem_fpn/fpn_r50_512x1024_80k_cityscapes.py') + + def get_world_size(process_group): return 1 diff --git a/tests/test_models/test_heads.py b/tests/test_models/test_heads.py index 8feb0e6..3ac6bb0 100644 --- a/tests/test_models/test_heads.py +++ b/tests/test_models/test_heads.py @@ -6,8 +6,7 @@ from mmcv.cnn import ConvModule from mmcv.utils.parrots_wrapper import SyncBatchNorm from mmseg.models.decode_heads import (ANNHead, ASPPHead, CCHead, DAHead, - DepthwiseSeparableASPPHead, - DepthwiseSeparableFCNHead, EncHead, + DepthwiseSeparableASPPHead, EncHead, FCNHead, GCHead, NLHead, OCRHead, PSAHead, PSPHead, UPerHead) from mmseg.models.decode_heads.decode_head import BaseDecodeHead @@ -540,37 +539,3 @@ def test_dw_aspp_head(): assert head.aspp_modules[2].depthwise_conv.dilation == (24, 24) outputs = head(inputs) assert outputs.shape == (1, head.num_classes, 45, 45) - - -def test_sep_fcn_head(): - # test sep_fcn_head with concat_input=False - head = DepthwiseSeparableFCNHead( - in_channels=128, - channels=128, - concat_input=False, - num_classes=19, - in_index=-1, - norm_cfg=dict(type='BN', requires_grad=True, momentum=0.01)) - x = [torch.rand(2, 128, 32, 32)] - output = head(x) - assert output.shape == (2, head.num_classes, 32, 32) - assert not head.concat_input - from mmseg.ops.separable_conv_module import DepthwiseSeparableConvModule - assert isinstance(head.convs[0], DepthwiseSeparableConvModule) - assert isinstance(head.convs[1], DepthwiseSeparableConvModule) - assert head.conv_seg.kernel_size == (1, 1) - - head = DepthwiseSeparableFCNHead( - in_channels=64, - channels=64, - concat_input=True, - num_classes=19, - in_index=-1, - norm_cfg=dict(type='BN', requires_grad=True, momentum=0.01)) - x = [torch.rand(3, 64, 32, 32)] - output = head(x) - assert output.shape == (3, head.num_classes, 32, 32) - assert head.concat_input - from mmseg.ops.separable_conv_module import DepthwiseSeparableConvModule - assert isinstance(head.convs[0], DepthwiseSeparableConvModule) - assert isinstance(head.convs[1], DepthwiseSeparableConvModule) diff --git a/tests/test_models/test_necks.py b/tests/test_models/test_necks.py new file mode 100644 index 0000000..8fc9684 --- /dev/null +++ b/tests/test_models/test_necks.py @@ -0,0 +1,18 @@ +import torch + +from mmseg.models import FPN + + +def test_fpn(): + in_channels = [256, 512, 1024, 2048] + inputs = [ + torch.randn(1, c, 56 // 2**i, 56 // 2**i) + for i, c in enumerate(in_channels) + ] + + fpn = FPN(in_channels, 256, len(in_channels)) + outputs = fpn(inputs) + assert outputs[0].shape == torch.Size([1, 256, 56, 56]) + assert outputs[1].shape == torch.Size([1, 256, 28, 28]) + assert outputs[2].shape == torch.Size([1, 256, 14, 14]) + assert outputs[3].shape == torch.Size([1, 256, 7, 7])