feat: ONNXRuntimeSegmentorKN
This commit is contained in:
parent
57bb510d5e
commit
3f2c7b18be
@ -2,12 +2,15 @@
|
|||||||
import warnings
|
import warnings
|
||||||
from abc import ABCMeta, abstractmethod
|
from abc import ABCMeta, abstractmethod
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
from typing import Any, Iterable, Union
|
||||||
|
from os import path as osp
|
||||||
|
|
||||||
import mmcv
|
import mmcv
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from mmcv.runner import BaseModule, auto_fp16
|
from mmcv.runner import BaseModule, auto_fp16
|
||||||
|
from mmseg.ops import resize
|
||||||
|
|
||||||
|
|
||||||
class BaseSegmentor(BaseModule, metaclass=ABCMeta):
|
class BaseSegmentor(BaseModule, metaclass=ABCMeta):
|
||||||
@ -284,3 +287,86 @@ class BaseSegmentor(BaseModule, metaclass=ABCMeta):
|
|||||||
warnings.warn('show==False and out_file is not specified, only '
|
warnings.warn('show==False and out_file is not specified, only '
|
||||||
'result image will be returned')
|
'result image will be returned')
|
||||||
return img
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
class ONNXRuntimeSegmentorKN(BaseSegmentor):
|
||||||
|
|
||||||
|
def __init__(self, onnx_file: str, cfg: Any, device_id: int):
|
||||||
|
super(ONNXRuntimeSegmentorKN, self).__init__()
|
||||||
|
import onnxruntime as ort
|
||||||
|
|
||||||
|
# get the custom op path
|
||||||
|
ort_custom_op_path = ''
|
||||||
|
try:
|
||||||
|
from mmcv.ops import get_onnxruntime_op_path
|
||||||
|
ort_custom_op_path = get_onnxruntime_op_path()
|
||||||
|
except (ImportError, ModuleNotFoundError):
|
||||||
|
warnings.warn('If input model has custom op from mmcv, \
|
||||||
|
you may have to build mmcv with ONNXRuntime from source.')
|
||||||
|
session_options = ort.SessionOptions()
|
||||||
|
# register custom op for onnxruntime
|
||||||
|
if osp.exists(ort_custom_op_path):
|
||||||
|
session_options.register_custom_ops_library(ort_custom_op_path)
|
||||||
|
sess = ort.InferenceSession(onnx_file, session_options)
|
||||||
|
providers = ['CPUExecutionProvider']
|
||||||
|
options = [{}]
|
||||||
|
is_cuda_available = ort.get_device() == 'GPU'
|
||||||
|
if is_cuda_available:
|
||||||
|
providers.insert(0, 'CUDAExecutionProvider')
|
||||||
|
options.insert(0, {'device_id': device_id})
|
||||||
|
|
||||||
|
sess.set_providers(providers, options)
|
||||||
|
|
||||||
|
self.sess = sess
|
||||||
|
self.device_id = device_id
|
||||||
|
self.io_binding = sess.io_binding()
|
||||||
|
self.output_names = [_.name for _ in sess.get_outputs()]
|
||||||
|
for name in self.output_names:
|
||||||
|
self.io_binding.bind_output(name)
|
||||||
|
self.cfg = cfg
|
||||||
|
self.test_mode = cfg.model.test_cfg.mode
|
||||||
|
self.is_cuda_available = is_cuda_available
|
||||||
|
|
||||||
|
def extract_feat(self, imgs):
|
||||||
|
raise NotImplementedError('This method is not implemented.')
|
||||||
|
|
||||||
|
def encode_decode(self, img, img_metas):
|
||||||
|
raise NotImplementedError('This method is not implemented.')
|
||||||
|
|
||||||
|
def forward_train(self, imgs, img_metas, **kwargs):
|
||||||
|
raise NotImplementedError('This method is not implemented.')
|
||||||
|
|
||||||
|
def simple_test(self, img: torch.Tensor,
|
||||||
|
img_meta: Union[Iterable, None] = None,
|
||||||
|
**kwargs) -> list:
|
||||||
|
if not self.is_cuda_available:
|
||||||
|
img = img.detach().cpu()
|
||||||
|
elif self.device_id >= 0:
|
||||||
|
img = img.cuda(self.device_id)
|
||||||
|
device_type = img.device.type
|
||||||
|
if self.test_mode == 'slide':
|
||||||
|
raise NotImplementedError('slide mode is not implemented yet')
|
||||||
|
else:
|
||||||
|
self.io_binding.bind_input(
|
||||||
|
name='input',
|
||||||
|
device_type=device_type,
|
||||||
|
device_id=self.device_id,
|
||||||
|
element_type=np.float32,
|
||||||
|
shape=img.shape,
|
||||||
|
buffer_ptr=img.data_ptr())
|
||||||
|
self.sess.run_with_iobinding(self.io_binding)
|
||||||
|
seg_pred = self.io_binding.copy_outputs_to_cpu()[0]
|
||||||
|
seg_pred = seg_pred.argmax(1)[:, None]
|
||||||
|
if img_meta is not None:
|
||||||
|
ori_shape = img_meta[0]['ori_shape']
|
||||||
|
if not (ori_shape[0] == seg_pred.shape[-2]
|
||||||
|
and ori_shape[1] == seg_pred.shape[-1]):
|
||||||
|
seg_pred = torch.from_numpy(seg_pred).float()
|
||||||
|
seg_pred = resize(
|
||||||
|
seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
|
||||||
|
seg_pred = seg_pred.long().detach().cpu().numpy()
|
||||||
|
seg_pred = list(seg_pred[0])
|
||||||
|
return list(seg_pred[0])
|
||||||
|
|
||||||
|
def aug_test(self, imgs, img_metas, **kwargs):
|
||||||
|
raise NotImplementedError('This method is not implemented.')
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user