# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) 2019 Western Digital Corporation or its affiliates.

import torch
|
|
|
|
from ..builder import DETECTORS
|
|
from .single_stage import SingleStageDetector
|
|
|
|
|
|
@DETECTORS.register_module()
class YOLOV3(SingleStageDetector):
    """YOLOv3 detector built on ``SingleStageDetector``.

    Construction is delegated unchanged to the parent class; this subclass
    adds an ONNX-export entry point (``onnx_export``).
    """

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        """Forward every argument to ``SingleStageDetector.__init__``.

        Arguments are passed positionally, in the same order as they are
        declared here; no additional setup is performed.
        """
        super(YOLOV3, self).__init__(backbone, neck, bbox_head, train_cfg,
                                     test_cfg, pretrained, init_cfg)

    def onnx_export(self, img, img_metas):
        """Test function for exporting to ONNX, without test time augmentation.

        Args:
            img (torch.Tensor): input images. Everything after the first two
                dimensions is treated as the image size (see the ``[2:]``
                slice below); presumably (N, C, H, W) -- TODO confirm.
            img_metas (list[dict]): List of image information. NOTE: the
                first dict is modified in place -- its
                ``'img_shape_for_onnx'`` key is set before the head's
                ``onnx_export`` is called.

        Returns:
            tuple[Tensor, Tensor]: dets of shape [N, num_det, 5]
                and class labels of shape [N, num_det].
        """
        x = self.extract_feat(img)
        outs = self.bbox_head.forward(x)
        # Get the shape as a tensor rather than a Python tuple; presumably
        # so the traced ONNX graph keeps the spatial size dynamic instead
        # of baking it in as constants -- confirm against the exporter.
        img_shape = torch._shape_as_tensor(img)[2:]
        # Stashed on the caller's first meta dict, which is handed to the
        # head's onnx_export right below (presumably read there -- verify).
        img_metas[0]['img_shape_for_onnx'] = img_shape

        det_bboxes, det_labels = self.bbox_head.onnx_export(*outs, img_metas)

        return det_bboxes, det_labels