support cpu deploy_test (#769)
This commit is contained in:
parent
9155d9e9ed
commit
58f5dbce7d
@ -53,6 +53,7 @@ class ONNXRuntimeSegmentor(BaseSegmentor):
|
|||||||
self.io_binding.bind_output(name)
|
self.io_binding.bind_output(name)
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self.test_mode = cfg.model.test_cfg.mode
|
self.test_mode = cfg.model.test_cfg.mode
|
||||||
|
self.is_cuda_available = is_cuda_available
|
||||||
|
|
||||||
def extract_feat(self, imgs):
|
def extract_feat(self, imgs):
|
||||||
raise NotImplementedError('This method is not implemented.')
|
raise NotImplementedError('This method is not implemented.')
|
||||||
@ -65,6 +66,10 @@ class ONNXRuntimeSegmentor(BaseSegmentor):
|
|||||||
|
|
||||||
def simple_test(self, img: torch.Tensor, img_meta: Iterable,
|
def simple_test(self, img: torch.Tensor, img_meta: Iterable,
|
||||||
**kwargs) -> list:
|
**kwargs) -> list:
|
||||||
|
if not self.is_cuda_available:
|
||||||
|
img = img.detach().cpu()
|
||||||
|
elif self.device_id >= 0:
|
||||||
|
img = img.cuda(self.device_id)
|
||||||
device_type = img.device.type
|
device_type = img.device.type
|
||||||
self.io_binding.bind_input(
|
self.io_binding.bind_input(
|
||||||
name='input',
|
name='input',
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user