fix: onnx inference wrong indexing bug
This commit is contained in:
parent
5846e789be
commit
f824495134
@ -307,13 +307,14 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor):
|
||||
# register custom op for onnxruntime
|
||||
if osp.exists(ort_custom_op_path):
|
||||
session_options.register_custom_ops_library(ort_custom_op_path)
|
||||
sess = ort.InferenceSession(onnx_file, session_options)
|
||||
providers = ['CPUExecutionProvider']
|
||||
options = [{}]
|
||||
is_cuda_available = ort.get_device() == 'GPU'
|
||||
is_cuda_available = False
|
||||
# is_cuda_available = ort.get_device() == 'GPU'
|
||||
if is_cuda_available:
|
||||
providers.insert(0, 'CUDAExecutionProvider')
|
||||
options.insert(0, {'device_id': device_id})
|
||||
sess = ort.InferenceSession(onnx_file, session_options, providers=providers)
|
||||
|
||||
sess.set_providers(providers, options)
|
||||
|
||||
@ -336,6 +337,7 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor):
|
||||
def forward_train(self, imgs, img_metas, **kwargs):
|
||||
raise NotImplementedError('This method is not implemented.')
|
||||
|
||||
@torch.no_grad()
|
||||
def simple_test(self, img: torch.Tensor,
|
||||
img_meta: Union[Iterable, None] = None,
|
||||
**kwargs) -> list:
|
||||
@ -365,7 +367,6 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor):
|
||||
seg_pred = resize(
|
||||
seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
|
||||
seg_pred = seg_pred.long().detach().cpu().numpy()
|
||||
seg_pred = list(seg_pred[0])
|
||||
return list(seg_pred[0])
|
||||
|
||||
def aug_test(self, imgs, img_metas, **kwargs):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user