feat: ONNXRuntimeSegmentorKN

This commit is contained in:
chingning.chen 2022-03-21 17:00:04 +08:00
parent 57bb510d5e
commit 3f2c7b18be

View File

@ -2,12 +2,15 @@
import warnings
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from typing import Any, Iterable, Union
from os import path as osp
import mmcv
import numpy as np
import torch
import torch.distributed as dist
from mmcv.runner import BaseModule, auto_fp16
from mmseg.ops import resize
class BaseSegmentor(BaseModule, metaclass=ABCMeta):
@ -284,3 +287,86 @@ class BaseSegmentor(BaseModule, metaclass=ABCMeta):
warnings.warn('show==False and out_file is not specified, only '
              'result image will be returned')
return img
class ONNXRuntimeSegmentorKN(BaseSegmentor):
    """Segmentor that runs inference through an ONNX Runtime session.

    Wraps an exported ONNX model so it can be driven through the mmseg
    ``BaseSegmentor`` test interface. Only whole-image inference is
    supported; ``slide`` test mode raises ``NotImplementedError``.

    Args:
        onnx_file (str): Path to the exported ONNX model file.
        cfg (Any): mmseg config object; ``cfg.model.test_cfg.mode``
            selects the test mode.
        device_id (int): CUDA device id used for input binding when the
            onnxruntime build reports GPU support.
    """

    def __init__(self, onnx_file: str, cfg: Any, device_id: int):
        super(ONNXRuntimeSegmentorKN, self).__init__()
        import onnxruntime as ort

        # Locate mmcv's custom-op shared library so models exported with
        # mmcv custom ops can be loaded; best-effort, warn if unavailable.
        ort_custom_op_path = ''
        try:
            from mmcv.ops import get_onnxruntime_op_path
            ort_custom_op_path = get_onnxruntime_op_path()
        except (ImportError, ModuleNotFoundError):
            warnings.warn('If input model has custom op from mmcv, \
you may have to build mmcv with ONNXRuntime from source.')
        session_options = ort.SessionOptions()
        # register custom op for onnxruntime
        if osp.exists(ort_custom_op_path):
            session_options.register_custom_ops_library(ort_custom_op_path)
        sess = ort.InferenceSession(onnx_file, session_options)
        providers = ['CPUExecutionProvider']
        options = [{}]
        is_cuda_available = ort.get_device() == 'GPU'
        if is_cuda_available:
            # Prefer CUDA when this onnxruntime build supports it.
            providers.insert(0, 'CUDAExecutionProvider')
            options.insert(0, {'device_id': device_id})
        sess.set_providers(providers, options)

        self.sess = sess
        self.device_id = device_id
        self.io_binding = sess.io_binding()
        self.output_names = [_.name for _ in sess.get_outputs()]
        # Pre-bind every output once; results are copied back to CPU
        # after each run via ``copy_outputs_to_cpu``.
        for name in self.output_names:
            self.io_binding.bind_output(name)
        self.cfg = cfg
        self.test_mode = cfg.model.test_cfg.mode
        self.is_cuda_available = is_cuda_available

    def extract_feat(self, imgs):
        """Not supported for ONNX inference."""
        raise NotImplementedError('This method is not implemented.')

    def encode_decode(self, img, img_metas):
        """Not supported for ONNX inference."""
        raise NotImplementedError('This method is not implemented.')

    def forward_train(self, imgs, img_metas, **kwargs):
        """Training is not supported on an exported ONNX model."""
        raise NotImplementedError('This method is not implemented.')

    def simple_test(self,
                    img: torch.Tensor,
                    img_meta: Union[Iterable, None] = None,
                    **kwargs) -> list:
        """Run whole-image inference on one batch.

        Args:
            img (torch.Tensor): Input batch; presumably (N, C, H, W)
                float32 — the binding uses ``np.float32``.
            img_meta (Iterable | None): Per-image meta dicts. When given,
                ``img_meta[0]['ori_shape']`` is used to resize the
                prediction back to the original resolution.

        Returns:
            list: ``list(seg_pred[0])`` — the per-channel (H, W) integer
            maps of the first image (a single-element list, since argmax
            leaves one channel).

        Raises:
            NotImplementedError: If ``test_cfg.mode`` is ``'slide'``.
        """
        # Place the tensor where the session will read it from.
        if not self.is_cuda_available:
            img = img.detach().cpu()
        elif self.device_id >= 0:
            img = img.cuda(self.device_id)
        device_type = img.device.type
        if self.test_mode == 'slide':
            raise NotImplementedError('slide mode is not implemented yet')
        else:
            # Zero-copy bind: onnxruntime reads directly from the
            # tensor's memory via its data pointer.
            self.io_binding.bind_input(
                name='input',
                device_type=device_type,
                device_id=self.device_id,
                element_type=np.float32,
                shape=img.shape,
                buffer_ptr=img.data_ptr())
            self.sess.run_with_iobinding(self.io_binding)
        # First output holds class scores (N, num_classes, H, W);
        # argmax over classes, keep a channel dim -> (N, 1, H, W).
        seg_pred = self.io_binding.copy_outputs_to_cpu()[0]
        seg_pred = seg_pred.argmax(1)[:, None]
        if img_meta is not None:
            ori_shape = img_meta[0]['ori_shape']
            if not (ori_shape[0] == seg_pred.shape[-2]
                    and ori_shape[1] == seg_pred.shape[-1]):
                # Nearest-neighbour resize back to the original size.
                seg_pred = torch.from_numpy(seg_pred).float()
                seg_pred = resize(
                    seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
                seg_pred = seg_pred.long().detach().cpu().numpy()
        # BUGFIX: the original did ``seg_pred = list(seg_pred[0])`` and
        # then ``return list(seg_pred[0])`` — indexing the already-listed
        # result a second time, which returned the rows of the first map
        # instead of the list of (H, W) maps.
        seg_pred = list(seg_pred[0])
        return seg_pred

    def aug_test(self, imgs, img_metas, **kwargs):
        """Augmented testing is not supported for ONNX inference."""
        raise NotImplementedError('This method is not implemented.')