refactor: remove final resize node from onnx; resize on app side

This commit is contained in:
chingning.chen 2022-03-29 17:59:50 +08:00
parent dcac233a60
commit 7576467aa1
4 changed files with 58 additions and 12 deletions

View File

@@ -1,11 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .inference import inference_segmentor, init_segmentor, show_result_pyplot
from .inference import (
inference_segmentor,
init_segmentor,
init_segmentor_kn,
show_result_pyplot,
)
from .test import multi_gpu_test, single_gpu_test
from .train import (get_root_logger, init_random_seed, set_random_seed,
train_segmentor)
__all__ = [
'get_root_logger', 'set_random_seed', 'train_segmentor', 'init_segmentor',
'inference_segmentor', 'multi_gpu_test', 'single_gpu_test',
'show_result_pyplot', 'init_random_seed'
'init_segmentor_kn', 'inference_segmentor', 'multi_gpu_test',
'single_gpu_test', 'show_result_pyplot', 'init_random_seed'
]

View File

@@ -7,6 +7,7 @@ from mmcv.runner import load_checkpoint
from mmseg.datasets.pipelines import Compose
from mmseg.models import build_segmentor
from mmseg.models.segmentors import ONNXRuntimeSegmentorKN
def init_segmentor(config, checkpoint=None, device='cuda:0'):
@@ -40,6 +41,30 @@ def init_segmentor(config, checkpoint=None, device='cuda:0'):
return model
def init_segmentor_kn(config, checkpoint=None, device='cuda:0'):
    """Initialize a segmentor from a config file or an ONNX checkpoint.

    Falls back to the regular ``init_segmentor`` path unless ``checkpoint``
    is an ``.onnx`` file, in which case an ONNX Runtime backed segmentor
    (``ONNXRuntimeSegmentorKN``) is constructed instead.

    Args:
        config (str or :obj:`mmcv.Config`): Config file path or the config
            object.
        checkpoint (str, optional): Checkpoint path. If left as None, the
            model will not load any weights.
        device (str, optional): CPU/CUDA device option. Default 'cuda:0'.
            Use 'cpu' for loading model on CPU.

    Returns:
        nn.Module: The constructed segmentor, in eval mode for the ONNX path.
    """
    if checkpoint is None or not checkpoint.endswith(".onnx"):
        return init_segmentor(config, checkpoint, device)

    # Parse the numeric device id out of strings such as 'cuda:0'.
    # Both a failed tuple unpack of str.split and a failed int() raise
    # ValueError, so catch only that — a broader except would also hide
    # genuine bugs (e.g. a non-string ``device``).
    try:
        _, device_id = device.split(":")
        device_id = int(device_id)
    except ValueError:
        # 'cpu' -> None (no CUDA provider); bare 'cuda' -> default device 0.
        device_id = None if device == 'cpu' else 0
    model = ONNXRuntimeSegmentorKN(
        checkpoint, cfg=config, device_id=device_id).eval()
    return model
class LoadImage:
"""A simple pipeline to load image."""
@@ -99,6 +124,13 @@ def inference_segmentor(model, img):
return result
def inference_segmentor_kn(model, img):
    """Run inference, dispatching on whether ``model`` is an ONNX path.

    NOTE(review): this looks like an unfinished stub — the ``.onnx`` branch
    is ``pass`` and therefore implicitly returns None; presumably it should
    run inference through the ONNX runtime segmentor — confirm intent.
    Also, ``model`` is treated as a path string here (``model.endswith``)
    while ``inference_segmentor`` expects a model object — verify the
    intended argument type against the callers.
    """
    if model.endswith(".onnx"):
        # TODO(review): ONNX inference path not implemented yet; falls
        # through and returns None.
        pass
    else:
        return inference_segmentor(model, img)
def show_result_pyplot(model,
img,
result,

View File

@@ -291,7 +291,11 @@ class BaseSegmentor(BaseModule, metaclass=ABCMeta):
class ONNXRuntimeSegmentorKN(BaseSegmentor):
def __init__(self, onnx_file: str, cfg: Any, device_id: int = 0):
def __init__(
self,
onnx_file: str,
cfg: Any,
device_id: Union[int, None] = 0):
super(ONNXRuntimeSegmentorKN, self).__init__()
import onnxruntime as ort
@@ -301,8 +305,9 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor):
from mmcv.ops import get_onnxruntime_op_path
ort_custom_op_path = get_onnxruntime_op_path()
except (ImportError, ModuleNotFoundError):
warnings.warn('If input model has custom op from mmcv, \
you may have to build mmcv with ONNXRuntime from source.')
warnings.warn(
'If input model has custom op from mmcv, you may '
'have to build mmcv with ONNXRuntime from source.')
session_options = ort.SessionOptions()
# register custom op for onnxruntime
if osp.exists(ort_custom_op_path):
@@ -312,6 +317,7 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor):
is_cuda_available = ort.get_device() == 'GPU' and torch.cuda.is_available()
if is_cuda_available:
providers.insert(0, 'CUDAExecutionProvider')
device_id = device_id or 0
provider_options.insert(0, {'device_id': device_id})
sess = ort.InferenceSession(
onnx_file, session_options, providers, provider_options
@@ -397,17 +403,19 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor):
if self.test_mode == 'slide':
seg_pred = self.simple_slide_inference(img, img_meta)
else:
seg_pred = self.sess.run(self.output_name_list, {self.input_name: img})[0]
seg_pred = seg_pred.argmax(1)[:, None]
seg_pred = self.sess.run(
self.output_name_list, {self.input_name: img}
)[0]
if img_meta is not None:
ori_shape = img_meta[0]['ori_shape']
if not (ori_shape[0] == seg_pred.shape[-2]
and ori_shape[1] == seg_pred.shape[-1]):
seg_pred = torch.from_numpy(seg_pred).float()
seg_pred = resize(
seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
seg_pred = seg_pred.long().detach().cpu().numpy()
return list(seg_pred[0])
seg_pred, size=tuple(ori_shape[:2]), mode='bilinear')
seg_pred = seg_pred.numpy()
seg_pred = seg_pred.argmax(1)
return list(seg_pred)
def aug_test(self, imgs, img_metas, **kwargs):
raise NotImplementedError('This method is not implemented.')

View File

@@ -115,7 +115,8 @@ class EncoderDecoder(BaseSegmentor):
def forward_dummy(self, img):
"""Dummy forward function."""
seg_logit = self.encode_decode(img, None)
seg_logit = self.extract_feat(img)
seg_logit = self._decode_head_forward_test(seg_logit, None)
return seg_logit