fix: onnx inference wrong indexing bug

This commit is contained in:
chingning.chen 2022-03-23 10:28:58 +08:00
parent 5846e789be
commit f824495134

View File

@ -307,13 +307,14 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor):
# register custom op for onnxruntime
if osp.exists(ort_custom_op_path):
session_options.register_custom_ops_library(ort_custom_op_path)
sess = ort.InferenceSession(onnx_file, session_options)
providers = ['CPUExecutionProvider']
options = [{}]
is_cuda_available = ort.get_device() == 'GPU'
is_cuda_available = False
# is_cuda_available = ort.get_device() == 'GPU'
if is_cuda_available:
providers.insert(0, 'CUDAExecutionProvider')
options.insert(0, {'device_id': device_id})
sess = ort.InferenceSession(onnx_file, session_options, providers=providers)
sess.set_providers(providers, options)
@ -336,6 +337,7 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor):
def forward_train(self, imgs, img_metas, **kwargs):
raise NotImplementedError('This method is not implemented.')
@torch.no_grad()
def simple_test(self, img: torch.Tensor,
img_meta: Union[Iterable, None] = None,
**kwargs) -> list:
@ -365,7 +367,6 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor):
seg_pred = resize(
seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
seg_pred = seg_pred.long().detach().cpu().numpy()
seg_pred = list(seg_pred[0])
return list(seg_pred[0])
def aug_test(self, imgs, img_metas, **kwargs):