From 3f2c7b18beae6fb1180c659b28a2bc2ce8830449 Mon Sep 17 00:00:00 2001 From: "chingning.chen" Date: Mon, 21 Mar 2022 17:00:04 +0800 Subject: [PATCH] feat: ONNXRuntimeSegmentorKN --- mmseg/models/segmentors/base.py | 86 +++++++++++++++++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/mmseg/models/segmentors/base.py b/mmseg/models/segmentors/base.py index 9b22a7c..3ce25ac 100644 --- a/mmseg/models/segmentors/base.py +++ b/mmseg/models/segmentors/base.py @@ -2,12 +2,15 @@ import warnings from abc import ABCMeta, abstractmethod from collections import OrderedDict +from typing import Any, Iterable, Union +from os import path as osp import mmcv import numpy as np import torch import torch.distributed as dist from mmcv.runner import BaseModule, auto_fp16 +from mmseg.ops import resize class BaseSegmentor(BaseModule, metaclass=ABCMeta): @@ -284,3 +287,86 @@ class BaseSegmentor(BaseModule, metaclass=ABCMeta): warnings.warn('show==False and out_file is not specified, only ' 'result image will be returned') return img + + +class ONNXRuntimeSegmentorKN(BaseSegmentor): + + def __init__(self, onnx_file: str, cfg: Any, device_id: int): + super(ONNXRuntimeSegmentorKN, self).__init__() + import onnxruntime as ort + + # get the custom op path + ort_custom_op_path = '' + try: + from mmcv.ops import get_onnxruntime_op_path + ort_custom_op_path = get_onnxruntime_op_path() + except (ImportError, ModuleNotFoundError): + warnings.warn('If input model has custom op from mmcv, \ + you may have to build mmcv with ONNXRuntime from source.') + session_options = ort.SessionOptions() + # register custom op for onnxruntime + if osp.exists(ort_custom_op_path): + session_options.register_custom_ops_library(ort_custom_op_path) + sess = ort.InferenceSession(onnx_file, session_options) + providers = ['CPUExecutionProvider'] + options = [{}] + is_cuda_available = ort.get_device() == 'GPU' + if is_cuda_available: + providers.insert(0, 'CUDAExecutionProvider') + options.insert(0, {'device_id': device_id}) + + sess.set_providers(providers, options) + + self.sess = sess + self.device_id = device_id + self.io_binding = sess.io_binding() + self.output_names = [_.name for _ in sess.get_outputs()] + for name in self.output_names: + self.io_binding.bind_output(name) + self.cfg = cfg + self.test_mode = cfg.model.test_cfg.mode + self.is_cuda_available = is_cuda_available + + def extract_feat(self, imgs): + raise NotImplementedError('This method is not implemented.') + + def encode_decode(self, img, img_metas): + raise NotImplementedError('This method is not implemented.') + + def forward_train(self, imgs, img_metas, **kwargs): + raise NotImplementedError('This method is not implemented.') + + def simple_test(self, img: torch.Tensor, + img_meta: Union[Iterable, None] = None, + **kwargs) -> list: + if not self.is_cuda_available: + img = img.detach().cpu() + elif self.device_id >= 0: + img = img.cuda(self.device_id) + device_type = img.device.type + if self.test_mode == 'slide': + raise NotImplementedError('slide mode is not implemented yet') + else: + self.io_binding.bind_input( + name='input', + device_type=device_type, + device_id=self.device_id, + element_type=np.float32, + shape=img.shape, + buffer_ptr=img.data_ptr()) + self.sess.run_with_iobinding(self.io_binding) + seg_pred = self.io_binding.copy_outputs_to_cpu()[0] + seg_pred = seg_pred.argmax(1)[:, None] + if img_meta is not None: + ori_shape = img_meta[0]['ori_shape'] + if not (ori_shape[0] == seg_pred.shape[-2] + and ori_shape[1] == seg_pred.shape[-1]): + seg_pred = torch.from_numpy(seg_pred).float() + seg_pred = resize( + seg_pred, size=tuple(ori_shape[:2]), mode='nearest') + seg_pred = seg_pred.long().detach().cpu().numpy() + seg_pred = list(seg_pred[0]) + return list(seg_pred[0]) + + def aug_test(self, imgs, img_metas, **kwargs): + raise NotImplementedError('This method is not implemented.')