diff --git a/mmseg/models/segmentors/base.py b/mmseg/models/segmentors/base.py index 2d5ea23..6755d9e 100644 --- a/mmseg/models/segmentors/base.py +++ b/mmseg/models/segmentors/base.py @@ -291,7 +291,7 @@ class BaseSegmentor(BaseModule, metaclass=ABCMeta): class ONNXRuntimeSegmentorKN(BaseSegmentor): - def __init__(self, onnx_file: str, cfg: Any, device_id: int): + def __init__(self, onnx_file: str, cfg: Any, device_id: int = 0): super(ONNXRuntimeSegmentorKN, self).__init__() import onnxruntime as ort @@ -308,24 +308,22 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor): if osp.exists(ort_custom_op_path): session_options.register_custom_ops_library(ort_custom_op_path) providers = ['CPUExecutionProvider'] - options = [{}] - is_cuda_available = False - # is_cuda_available = ort.get_device() == 'GPU' + provider_options = [{}] + is_cuda_available = ort.get_device() == 'GPU' and torch.cuda.is_available() if is_cuda_available: providers.insert(0, 'CUDAExecutionProvider') - options.insert(0, {'device_id': device_id}) - sess = ort.InferenceSession(onnx_file, session_options, providers=providers) - - sess.set_providers(providers, options) - + provider_options.insert(0, {'device_id': device_id}) + sess = ort.InferenceSession( + onnx_file, session_options, providers, provider_options + ) self.sess = sess - self.device_id = device_id - self.io_binding = sess.io_binding() - self.output_names = [_.name for _ in sess.get_outputs()] - for name in self.output_names: - self.io_binding.bind_output(name) - self.cfg = cfg - self.test_mode = cfg.model.test_cfg.mode + sess_inputs = sess.get_inputs() + assert len(sess_inputs) == 1, "Only onnx with 1 input is supported" + self.input_name = sess_inputs[0].name + self.output_name_list = [_.name for _ in sess.get_outputs()] + assert len(self.output_name_list) == 1, "Only onnx with 1 output is supported" + self.cfg = cfg # TODO: necessary? + self.test_mode = cfg.model.test_cfg.mode # NOTE: should be 'whole' or 'slide' self.is_cuda_available = is_cuda_available def extract_feat(self, imgs): @@ -341,23 +339,14 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor): def simple_test(self, img: torch.Tensor, img_meta: Union[Iterable, None] = None, **kwargs) -> list: - if not self.is_cuda_available: - img = img.detach().cpu() - elif self.device_id >= 0: - img = img.cuda(self.device_id) - device_type = img.device.type + img = img.cpu().numpy() + # NOTE: not using run_with_iobinding since some ort versions + # generate wrong results when inferencing with CUDA if self.test_mode == 'slide': - raise NotImplementedError('slide mode is not implemented yet') + # raise NotImplementedError('slide mode is not implemented yet') + seg_pred = self.sess.run(self.output_name_list, {self.input_name: img})[0] else: - self.io_binding.bind_input( - name='input', - device_type=device_type, - device_id=self.device_id, - element_type=np.float32, - shape=img.shape, - buffer_ptr=img.data_ptr()) - self.sess.run_with_iobinding(self.io_binding) - seg_pred = self.io_binding.copy_outputs_to_cpu()[0] + seg_pred = self.sess.run(self.output_name_list, {self.input_name: img})[0] seg_pred = seg_pred.argmax(1)[:, None] if img_meta is not None: ori_shape = img_meta[0]['ori_shape']