refactor: remove final resize node from onnx; resize on app side
This commit is contained in:
parent
dcac233a60
commit
7576467aa1
@ -1,11 +1,16 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from .inference import inference_segmentor, init_segmentor, show_result_pyplot
|
from .inference import (
|
||||||
|
inference_segmentor,
|
||||||
|
init_segmentor,
|
||||||
|
init_segmentor_kn,
|
||||||
|
show_result_pyplot,
|
||||||
|
)
|
||||||
from .test import multi_gpu_test, single_gpu_test
|
from .test import multi_gpu_test, single_gpu_test
|
||||||
from .train import (get_root_logger, init_random_seed, set_random_seed,
|
from .train import (get_root_logger, init_random_seed, set_random_seed,
|
||||||
train_segmentor)
|
train_segmentor)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'get_root_logger', 'set_random_seed', 'train_segmentor', 'init_segmentor',
|
'get_root_logger', 'set_random_seed', 'train_segmentor', 'init_segmentor',
|
||||||
'inference_segmentor', 'multi_gpu_test', 'single_gpu_test',
|
'init_segmentor_kn', 'inference_segmentor', 'multi_gpu_test',
|
||||||
'show_result_pyplot', 'init_random_seed'
|
'single_gpu_test', 'show_result_pyplot', 'init_random_seed'
|
||||||
]
|
]
|
||||||
|
|||||||
@ -7,6 +7,7 @@ from mmcv.runner import load_checkpoint
|
|||||||
|
|
||||||
from mmseg.datasets.pipelines import Compose
|
from mmseg.datasets.pipelines import Compose
|
||||||
from mmseg.models import build_segmentor
|
from mmseg.models import build_segmentor
|
||||||
|
from mmseg.models.segmentors import ONNXRuntimeSegmentorKN
|
||||||
|
|
||||||
|
|
||||||
def init_segmentor(config, checkpoint=None, device='cuda:0'):
|
def init_segmentor(config, checkpoint=None, device='cuda:0'):
|
||||||
@ -40,6 +41,30 @@ def init_segmentor(config, checkpoint=None, device='cuda:0'):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def init_segmentor_kn(config, checkpoint=None, device='cuda:0'):
|
||||||
|
"""Initialize a segmentor from config file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (str or :obj:`mmcv.Config`): Config file path or the config
|
||||||
|
object.
|
||||||
|
checkpoint (str, optional): Checkpoint path. If left as None, the model
|
||||||
|
will not load any weights.
|
||||||
|
device (str, optional) CPU/CUDA device option. Default 'cuda:0'.
|
||||||
|
Use 'cpu' for loading model on CPU.
|
||||||
|
Returns:
|
||||||
|
nn.Module: The constructed segmentor.
|
||||||
|
"""
|
||||||
|
if checkpoint is None or not checkpoint.endswith(".onnx"):
|
||||||
|
return init_segmentor(config, checkpoint, device)
|
||||||
|
try:
|
||||||
|
_, device_id = device.split(":")
|
||||||
|
device_id = int(device_id)
|
||||||
|
except Exception:
|
||||||
|
device_id = None if device == 'cpu' else 0
|
||||||
|
model = ONNXRuntimeSegmentorKN(checkpoint, cfg=config, device_id=device_id).eval()
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
class LoadImage:
|
class LoadImage:
|
||||||
"""A simple pipeline to load image."""
|
"""A simple pipeline to load image."""
|
||||||
|
|
||||||
@ -99,6 +124,13 @@ def inference_segmentor(model, img):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def inference_segmentor_kn(model, img):
|
||||||
|
if model.endswith(".onnx"):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
return inference_segmentor(model, img)
|
||||||
|
|
||||||
|
|
||||||
def show_result_pyplot(model,
|
def show_result_pyplot(model,
|
||||||
img,
|
img,
|
||||||
result,
|
result,
|
||||||
|
|||||||
@ -291,7 +291,11 @@ class BaseSegmentor(BaseModule, metaclass=ABCMeta):
|
|||||||
|
|
||||||
class ONNXRuntimeSegmentorKN(BaseSegmentor):
|
class ONNXRuntimeSegmentorKN(BaseSegmentor):
|
||||||
|
|
||||||
def __init__(self, onnx_file: str, cfg: Any, device_id: int = 0):
|
def __init__(
|
||||||
|
self,
|
||||||
|
onnx_file: str,
|
||||||
|
cfg: Any,
|
||||||
|
device_id: Union[int, None] = 0):
|
||||||
super(ONNXRuntimeSegmentorKN, self).__init__()
|
super(ONNXRuntimeSegmentorKN, self).__init__()
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
|
|
||||||
@ -301,8 +305,9 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor):
|
|||||||
from mmcv.ops import get_onnxruntime_op_path
|
from mmcv.ops import get_onnxruntime_op_path
|
||||||
ort_custom_op_path = get_onnxruntime_op_path()
|
ort_custom_op_path = get_onnxruntime_op_path()
|
||||||
except (ImportError, ModuleNotFoundError):
|
except (ImportError, ModuleNotFoundError):
|
||||||
warnings.warn('If input model has custom op from mmcv, \
|
warnings.warn(
|
||||||
you may have to build mmcv with ONNXRuntime from source.')
|
'If input model has custom op from mmcv, you may '
|
||||||
|
'have to build mmcv with ONNXRuntime from source.')
|
||||||
session_options = ort.SessionOptions()
|
session_options = ort.SessionOptions()
|
||||||
# register custom op for onnxruntime
|
# register custom op for onnxruntime
|
||||||
if osp.exists(ort_custom_op_path):
|
if osp.exists(ort_custom_op_path):
|
||||||
@ -312,6 +317,7 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor):
|
|||||||
is_cuda_available = ort.get_device() == 'GPU' and torch.cuda.is_available()
|
is_cuda_available = ort.get_device() == 'GPU' and torch.cuda.is_available()
|
||||||
if is_cuda_available:
|
if is_cuda_available:
|
||||||
providers.insert(0, 'CUDAExecutionProvider')
|
providers.insert(0, 'CUDAExecutionProvider')
|
||||||
|
device_id = device_id or 0
|
||||||
provider_options.insert(0, {'device_id': device_id})
|
provider_options.insert(0, {'device_id': device_id})
|
||||||
sess = ort.InferenceSession(
|
sess = ort.InferenceSession(
|
||||||
onnx_file, session_options, providers, provider_options
|
onnx_file, session_options, providers, provider_options
|
||||||
@ -397,17 +403,19 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor):
|
|||||||
if self.test_mode == 'slide':
|
if self.test_mode == 'slide':
|
||||||
seg_pred = self.simple_slide_inference(img, img_meta)
|
seg_pred = self.simple_slide_inference(img, img_meta)
|
||||||
else:
|
else:
|
||||||
seg_pred = self.sess.run(self.output_name_list, {self.input_name: img})[0]
|
seg_pred = self.sess.run(
|
||||||
seg_pred = seg_pred.argmax(1)[:, None]
|
self.output_name_list, {self.input_name: img}
|
||||||
|
)[0]
|
||||||
if img_meta is not None:
|
if img_meta is not None:
|
||||||
ori_shape = img_meta[0]['ori_shape']
|
ori_shape = img_meta[0]['ori_shape']
|
||||||
if not (ori_shape[0] == seg_pred.shape[-2]
|
if not (ori_shape[0] == seg_pred.shape[-2]
|
||||||
and ori_shape[1] == seg_pred.shape[-1]):
|
and ori_shape[1] == seg_pred.shape[-1]):
|
||||||
seg_pred = torch.from_numpy(seg_pred).float()
|
seg_pred = torch.from_numpy(seg_pred).float()
|
||||||
seg_pred = resize(
|
seg_pred = resize(
|
||||||
seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
|
seg_pred, size=tuple(ori_shape[:2]), mode='bilinear')
|
||||||
seg_pred = seg_pred.long().detach().cpu().numpy()
|
seg_pred = seg_pred.numpy()
|
||||||
return list(seg_pred[0])
|
seg_pred = seg_pred.argmax(1)
|
||||||
|
return list(seg_pred)
|
||||||
|
|
||||||
def aug_test(self, imgs, img_metas, **kwargs):
|
def aug_test(self, imgs, img_metas, **kwargs):
|
||||||
raise NotImplementedError('This method is not implemented.')
|
raise NotImplementedError('This method is not implemented.')
|
||||||
|
|||||||
@ -115,7 +115,8 @@ class EncoderDecoder(BaseSegmentor):
|
|||||||
|
|
||||||
def forward_dummy(self, img):
|
def forward_dummy(self, img):
|
||||||
"""Dummy forward function."""
|
"""Dummy forward function."""
|
||||||
seg_logit = self.encode_decode(img, None)
|
seg_logit = self.extract_feat(img)
|
||||||
|
seg_logit = self._decode_head_forward_test(seg_logit, None)
|
||||||
|
|
||||||
return seg_logit
|
return seg_logit
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user