refactor: remove final resize node from onnx; resize on app side
This commit is contained in:
parent
dcac233a60
commit
7576467aa1
@ -1,11 +1,16 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .inference import inference_segmentor, init_segmentor, show_result_pyplot
|
||||
from .inference import (
|
||||
inference_segmentor,
|
||||
init_segmentor,
|
||||
init_segmentor_kn,
|
||||
show_result_pyplot,
|
||||
)
|
||||
from .test import multi_gpu_test, single_gpu_test
|
||||
from .train import (get_root_logger, init_random_seed, set_random_seed,
|
||||
train_segmentor)
|
||||
|
||||
__all__ = [
|
||||
'get_root_logger', 'set_random_seed', 'train_segmentor', 'init_segmentor',
|
||||
'inference_segmentor', 'multi_gpu_test', 'single_gpu_test',
|
||||
'show_result_pyplot', 'init_random_seed'
|
||||
'init_segmentor_kn', 'inference_segmentor', 'multi_gpu_test',
|
||||
'single_gpu_test', 'show_result_pyplot', 'init_random_seed'
|
||||
]
|
||||
|
||||
@ -7,6 +7,7 @@ from mmcv.runner import load_checkpoint
|
||||
|
||||
from mmseg.datasets.pipelines import Compose
|
||||
from mmseg.models import build_segmentor
|
||||
from mmseg.models.segmentors import ONNXRuntimeSegmentorKN
|
||||
|
||||
|
||||
def init_segmentor(config, checkpoint=None, device='cuda:0'):
|
||||
@ -40,6 +41,30 @@ def init_segmentor(config, checkpoint=None, device='cuda:0'):
|
||||
return model
|
||||
|
||||
|
||||
def init_segmentor_kn(config, checkpoint=None, device='cuda:0'):
|
||||
"""Initialize a segmentor from config file.
|
||||
|
||||
Args:
|
||||
config (str or :obj:`mmcv.Config`): Config file path or the config
|
||||
object.
|
||||
checkpoint (str, optional): Checkpoint path. If left as None, the model
|
||||
will not load any weights.
|
||||
device (str, optional) CPU/CUDA device option. Default 'cuda:0'.
|
||||
Use 'cpu' for loading model on CPU.
|
||||
Returns:
|
||||
nn.Module: The constructed segmentor.
|
||||
"""
|
||||
if checkpoint is None or not checkpoint.endswith(".onnx"):
|
||||
return init_segmentor(config, checkpoint, device)
|
||||
try:
|
||||
_, device_id = device.split(":")
|
||||
device_id = int(device_id)
|
||||
except Exception:
|
||||
device_id = None if device == 'cpu' else 0
|
||||
model = ONNXRuntimeSegmentorKN(checkpoint, cfg=config, device_id=device_id).eval()
|
||||
return model
|
||||
|
||||
|
||||
class LoadImage:
|
||||
"""A simple pipeline to load image."""
|
||||
|
||||
@ -99,6 +124,13 @@ def inference_segmentor(model, img):
|
||||
return result
|
||||
|
||||
|
||||
def inference_segmentor_kn(model, img):
|
||||
if model.endswith(".onnx"):
|
||||
pass
|
||||
else:
|
||||
return inference_segmentor(model, img)
|
||||
|
||||
|
||||
def show_result_pyplot(model,
|
||||
img,
|
||||
result,
|
||||
|
||||
@ -291,7 +291,11 @@ class BaseSegmentor(BaseModule, metaclass=ABCMeta):
|
||||
|
||||
class ONNXRuntimeSegmentorKN(BaseSegmentor):
|
||||
|
||||
def __init__(self, onnx_file: str, cfg: Any, device_id: int = 0):
|
||||
def __init__(
|
||||
self,
|
||||
onnx_file: str,
|
||||
cfg: Any,
|
||||
device_id: Union[int, None] = 0):
|
||||
super(ONNXRuntimeSegmentorKN, self).__init__()
|
||||
import onnxruntime as ort
|
||||
|
||||
@ -301,8 +305,9 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor):
|
||||
from mmcv.ops import get_onnxruntime_op_path
|
||||
ort_custom_op_path = get_onnxruntime_op_path()
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
warnings.warn('If input model has custom op from mmcv, \
|
||||
you may have to build mmcv with ONNXRuntime from source.')
|
||||
warnings.warn(
|
||||
'If input model has custom op from mmcv, you may '
|
||||
'have to build mmcv with ONNXRuntime from source.')
|
||||
session_options = ort.SessionOptions()
|
||||
# register custom op for onnxruntime
|
||||
if osp.exists(ort_custom_op_path):
|
||||
@ -312,6 +317,7 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor):
|
||||
is_cuda_available = ort.get_device() == 'GPU' and torch.cuda.is_available()
|
||||
if is_cuda_available:
|
||||
providers.insert(0, 'CUDAExecutionProvider')
|
||||
device_id = device_id or 0
|
||||
provider_options.insert(0, {'device_id': device_id})
|
||||
sess = ort.InferenceSession(
|
||||
onnx_file, session_options, providers, provider_options
|
||||
@ -397,17 +403,19 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor):
|
||||
if self.test_mode == 'slide':
|
||||
seg_pred = self.simple_slide_inference(img, img_meta)
|
||||
else:
|
||||
seg_pred = self.sess.run(self.output_name_list, {self.input_name: img})[0]
|
||||
seg_pred = seg_pred.argmax(1)[:, None]
|
||||
seg_pred = self.sess.run(
|
||||
self.output_name_list, {self.input_name: img}
|
||||
)[0]
|
||||
if img_meta is not None:
|
||||
ori_shape = img_meta[0]['ori_shape']
|
||||
if not (ori_shape[0] == seg_pred.shape[-2]
|
||||
and ori_shape[1] == seg_pred.shape[-1]):
|
||||
seg_pred = torch.from_numpy(seg_pred).float()
|
||||
seg_pred = resize(
|
||||
seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
|
||||
seg_pred = seg_pred.long().detach().cpu().numpy()
|
||||
return list(seg_pred[0])
|
||||
seg_pred, size=tuple(ori_shape[:2]), mode='bilinear')
|
||||
seg_pred = seg_pred.numpy()
|
||||
seg_pred = seg_pred.argmax(1)
|
||||
return list(seg_pred)
|
||||
|
||||
def aug_test(self, imgs, img_metas, **kwargs):
|
||||
raise NotImplementedError('This method is not implemented.')
|
||||
|
||||
@ -115,7 +115,8 @@ class EncoderDecoder(BaseSegmentor):
|
||||
|
||||
def forward_dummy(self, img):
|
||||
"""Dummy forward function."""
|
||||
seg_logit = self.encode_decode(img, None)
|
||||
seg_logit = self.extract_feat(img)
|
||||
seg_logit = self._decode_head_forward_test(seg_logit, None)
|
||||
|
||||
return seg_logit
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user