diff --git a/mmseg/models/segmentors/base.py b/mmseg/models/segmentors/base.py index 3ce25ac..2d5ea23 100644 --- a/mmseg/models/segmentors/base.py +++ b/mmseg/models/segmentors/base.py @@ -307,13 +307,14 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor): # register custom op for onnxruntime if osp.exists(ort_custom_op_path): session_options.register_custom_ops_library(ort_custom_op_path) - sess = ort.InferenceSession(onnx_file, session_options) providers = ['CPUExecutionProvider'] options = [{}] - is_cuda_available = ort.get_device() == 'GPU' + is_cuda_available = False + # is_cuda_available = ort.get_device() == 'GPU' if is_cuda_available: providers.insert(0, 'CUDAExecutionProvider') options.insert(0, {'device_id': device_id}) + sess = ort.InferenceSession(onnx_file, session_options, providers=providers) sess.set_providers(providers, options) @@ -336,6 +337,7 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor): def forward_train(self, imgs, img_metas, **kwargs): raise NotImplementedError('This method is not implemented.') + @torch.no_grad() def simple_test(self, img: torch.Tensor, img_meta: Union[Iterable, None] = None, **kwargs) -> list: @@ -365,7 +367,6 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor): seg_pred = resize( seg_pred, size=tuple(ori_shape[:2]), mode='nearest') seg_pred = seg_pred.long().detach().cpu().numpy() - seg_pred = list(seg_pred[0]) return list(seg_pred[0]) def aug_test(self, imgs, img_metas, **kwargs):