refactor: run_with_iobinding -> run

This commit is contained in:
chingning.chen 2022-03-23 15:56:01 +08:00
parent f824495134
commit 12c840564e

View File

@@ -291,7 +291,7 @@ class BaseSegmentor(BaseModule, metaclass=ABCMeta):
 class ONNXRuntimeSegmentorKN(BaseSegmentor):
-    def __init__(self, onnx_file: str, cfg: Any, device_id: int):
+    def __init__(self, onnx_file: str, cfg: Any, device_id: int = 0):
         super(ONNXRuntimeSegmentorKN, self).__init__()
         import onnxruntime as ort
@@ -308,24 +308,22 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor):
         if osp.exists(ort_custom_op_path):
             session_options.register_custom_ops_library(ort_custom_op_path)
         providers = ['CPUExecutionProvider']
-        options = [{}]
-        is_cuda_available = False
-        # is_cuda_available = ort.get_device() == 'GPU'
+        provider_options = [{}]
+        is_cuda_available = ort.get_device() == 'GPU' and torch.cuda.is_available()
         if is_cuda_available:
             providers.insert(0, 'CUDAExecutionProvider')
-            options.insert(0, {'device_id': device_id})
-        sess = ort.InferenceSession(onnx_file, session_options, providers=providers)
-        sess.set_providers(providers, options)
+            provider_options.insert(0, {'device_id': device_id})
+        sess = ort.InferenceSession(
+            onnx_file, session_options, providers, provider_options
+        )
         self.sess = sess
-        self.device_id = device_id
-        self.io_binding = sess.io_binding()
-        self.output_names = [_.name for _ in sess.get_outputs()]
-        for name in self.output_names:
-            self.io_binding.bind_output(name)
-        self.cfg = cfg
-        self.test_mode = cfg.model.test_cfg.mode
+        sess_inputs = sess.get_inputs()
+        assert len(sess_inputs) == 1, "Only onnx with 1 input is supported"
+        self.input_name = sess_inputs[0].name
+        self.output_name_list = [_.name for _ in sess.get_outputs()]
+        assert len(self.output_name_list) == 1, "Only onnx with 1 output is supported"
+        self.cfg = cfg  # TODO: necessary?
+        self.test_mode = cfg.model.test_cfg.mode  # NOTE: should be 'whole' or 'slide'
         self.is_cuda_available = is_cuda_available

     def extract_feat(self, imgs):
@@ -341,23 +341,14 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor):
     def simple_test(self, img: torch.Tensor,
                     img_meta: Union[Iterable, None] = None,
                     **kwargs) -> list:
-        if not self.is_cuda_available:
-            img = img.detach().cpu()
-        elif self.device_id >= 0:
-            img = img.cuda(self.device_id)
-        device_type = img.device.type
+        img = img.cpu().numpy()
+        # NOTE: not using run_with_iobinding since some ort versions
+        # generate wrong results when inferencing with CUDA
         if self.test_mode == 'slide':
-            raise NotImplementedError('slide mode is not implemented yet')
+            # raise NotImplementedError('slide mode is not implemented yet')
+            seg_pred = self.sess.run(self.output_name_list, {self.input_name: img})[0]
         else:
-            self.io_binding.bind_input(
-                name='input',
-                device_type=device_type,
-                device_id=self.device_id,
-                element_type=np.float32,
-                shape=img.shape,
-                buffer_ptr=img.data_ptr())
-            self.sess.run_with_iobinding(self.io_binding)
-            seg_pred = self.io_binding.copy_outputs_to_cpu()[0]
+            seg_pred = self.sess.run(self.output_name_list, {self.input_name: img})[0]
         seg_pred = seg_pred.argmax(1)[:, None]
         if img_meta is not None:
             ori_shape = img_meta[0]['ori_shape']