diff --git a/tools/deploy_test.py b/tools/deploy_test.py index 51f16b4..56fd61c 100644 --- a/tools/deploy_test.py +++ b/tools/deploy_test.py @@ -53,6 +53,7 @@ class ONNXRuntimeSegmentor(BaseSegmentor): self.io_binding.bind_output(name) self.cfg = cfg self.test_mode = cfg.model.test_cfg.mode + self.is_cuda_available = is_cuda_available def extract_feat(self, imgs): raise NotImplementedError('This method is not implemented.') @@ -65,6 +66,10 @@ class ONNXRuntimeSegmentor(BaseSegmentor): def simple_test(self, img: torch.Tensor, img_meta: Iterable, **kwargs) -> list: + if not self.is_cuda_available: + img = img.detach().cpu() + elif self.device_id >= 0: + img = img.cuda(self.device_id) device_type = img.device.type self.io_binding.bind_input( name='input',