From 7576467aa1bbe60bf725c5408c38940d1ace8ca8 Mon Sep 17 00:00:00 2001
From: "chingning.chen"
Date: Tue, 29 Mar 2022 17:59:50 +0800
Subject: [PATCH] refactor: remove final resize node from onnx; resize on app
 side

---
 mmseg/apis/__init__.py                     | 11 ++++++--
 mmseg/apis/inference.py                    | 32 ++++++++++++++++++++++
 mmseg/models/segmentors/base.py            | 24 ++++++++++------
 mmseg/models/segmentors/encoder_decoder.py |  3 +-
 4 files changed, 58 insertions(+), 12 deletions(-)

diff --git a/mmseg/apis/__init__.py b/mmseg/apis/__init__.py
index c688180..dffd4e1 100644
--- a/mmseg/apis/__init__.py
+++ b/mmseg/apis/__init__.py
@@ -1,11 +1,16 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .inference import inference_segmentor, init_segmentor, show_result_pyplot
+from .inference import (
+    inference_segmentor,
+    init_segmentor,
+    init_segmentor_kn,
+    show_result_pyplot,
+)
 from .test import multi_gpu_test, single_gpu_test
 from .train import (get_root_logger, init_random_seed, set_random_seed,
                     train_segmentor)
 
 __all__ = [
     'get_root_logger', 'set_random_seed', 'train_segmentor', 'init_segmentor',
-    'inference_segmentor', 'multi_gpu_test', 'single_gpu_test',
-    'show_result_pyplot', 'init_random_seed'
+    'init_segmentor_kn', 'inference_segmentor', 'multi_gpu_test',
+    'single_gpu_test', 'show_result_pyplot', 'init_random_seed'
 ]
diff --git a/mmseg/apis/inference.py b/mmseg/apis/inference.py
index 9069438..95f80b6 100644
--- a/mmseg/apis/inference.py
+++ b/mmseg/apis/inference.py
@@ -7,6 +7,7 @@ from mmcv.runner import load_checkpoint
 
 from mmseg.datasets.pipelines import Compose
 from mmseg.models import build_segmentor
+from mmseg.models.segmentors import ONNXRuntimeSegmentorKN
 
 
 def init_segmentor(config, checkpoint=None, device='cuda:0'):
@@ -40,6 +41,30 @@
     return model
 
 
+def init_segmentor_kn(config, checkpoint=None, device='cuda:0'):
+    """Initialize a segmentor from config file.
+
+    Args:
+        config (str or :obj:`mmcv.Config`): Config file path or the config
+            object.
+        checkpoint (str, optional): Checkpoint path. If left as None, the model
+            will not load any weights.
+        device (str, optional): CPU/CUDA device option. Default 'cuda:0'.
+            Use 'cpu' for loading model on CPU.
+    Returns:
+        nn.Module: The constructed segmentor.
+    """
+    if checkpoint is None or not checkpoint.endswith(".onnx"):
+        return init_segmentor(config, checkpoint, device)
+    try:
+        _, device_id = device.split(":")
+        device_id = int(device_id)
+    except Exception:
+        device_id = None if device == 'cpu' else 0
+    model = ONNXRuntimeSegmentorKN(checkpoint, cfg=config, device_id=device_id).eval()
+    return model
+
+
 class LoadImage:
     """A simple pipeline to load image."""
 
@@ -99,6 +124,13 @@ def inference_segmentor(model, img):
     return result
 
 
+def inference_segmentor_kn(model, img):
+    if model.endswith(".onnx"):
+        pass
+    else:
+        return inference_segmentor(model, img)
+
+
 def show_result_pyplot(model,
                        img,
                        result,
diff --git a/mmseg/models/segmentors/base.py b/mmseg/models/segmentors/base.py
index dc6730e..b4f3e39 100644
--- a/mmseg/models/segmentors/base.py
+++ b/mmseg/models/segmentors/base.py
@@ -291,7 +291,11 @@ class BaseSegmentor(BaseModule, metaclass=ABCMeta):
 
 
 class ONNXRuntimeSegmentorKN(BaseSegmentor):
-    def __init__(self, onnx_file: str, cfg: Any, device_id: int = 0):
+    def __init__(
+            self,
+            onnx_file: str,
+            cfg: Any,
+            device_id: Union[int, None] = 0):
         super(ONNXRuntimeSegmentorKN, self).__init__()
         import onnxruntime as ort
 
@@ -301,8 +305,9 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor):
             from mmcv.ops import get_onnxruntime_op_path
             ort_custom_op_path = get_onnxruntime_op_path()
         except (ImportError, ModuleNotFoundError):
-            warnings.warn('If input model has custom op from mmcv, \
-                you may have to build mmcv with ONNXRuntime from source.')
+            warnings.warn(
+                'If input model has custom op from mmcv, you may '
+                'have to build mmcv with ONNXRuntime from source.')
         session_options = ort.SessionOptions()
         # register custom op for onnxruntime
         if osp.exists(ort_custom_op_path):
@@ -312,6 +317,7 @@
         is_cuda_available = ort.get_device() == 'GPU' and torch.cuda.is_available()
         if is_cuda_available:
             providers.insert(0, 'CUDAExecutionProvider')
+            device_id = device_id or 0
             provider_options.insert(0, {'device_id': device_id})
         sess = ort.InferenceSession(
             onnx_file, session_options, providers, provider_options
@@ -397,17 +403,19 @@
         if self.test_mode == 'slide':
             seg_pred = self.simple_slide_inference(img, img_meta)
         else:
-            seg_pred = self.sess.run(self.output_name_list, {self.input_name: img})[0]
-            seg_pred = seg_pred.argmax(1)[:, None]
+            seg_pred = self.sess.run(
+                self.output_name_list, {self.input_name: img}
+            )[0]
         if img_meta is not None:
             ori_shape = img_meta[0]['ori_shape']
             if not (ori_shape[0] == seg_pred.shape[-2]
                     and ori_shape[1] == seg_pred.shape[-1]):
                 seg_pred = torch.from_numpy(seg_pred).float()
                 seg_pred = resize(
-                    seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
-                seg_pred = seg_pred.long().detach().cpu().numpy()
-        return list(seg_pred[0])
+                    seg_pred, size=tuple(ori_shape[:2]), mode='bilinear')
+                seg_pred = seg_pred.numpy()
+        seg_pred = seg_pred.argmax(1)
+        return list(seg_pred)
 
     def aug_test(self, imgs, img_metas, **kwargs):
         raise NotImplementedError('This method is not implemented.')
diff --git a/mmseg/models/segmentors/encoder_decoder.py b/mmseg/models/segmentors/encoder_decoder.py
index 72467b4..184ecc5 100644
--- a/mmseg/models/segmentors/encoder_decoder.py
+++ b/mmseg/models/segmentors/encoder_decoder.py
@@ -115,7 +115,8 @@ class EncoderDecoder(BaseSegmentor):
 
     def forward_dummy(self, img):
         """Dummy forward function."""
-        seg_logit = self.encode_decode(img, None)
+        seg_logit = self.extract_feat(img)
+        seg_logit = self._decode_head_forward_test(seg_logit, None)
 
         return seg_logit
 