refactor: not using MMDataParallel for onnx inferencing

This commit is contained in:
chingning.chen 2022-03-23 17:02:03 +08:00
parent 12c840564e
commit b2246495f5
2 changed files with 3 additions and 2 deletions

View File

@ -335,6 +335,9 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor):
def forward_train(self, imgs, img_metas, **kwargs):
raise NotImplementedError('This method is not implemented.')
def forward_test(self, imgs, img_metas, **kwargs):
return super().forward_test(imgs, img_metas[0].data, **kwargs)
@torch.no_grad()
def simple_test(self, img: torch.Tensor,
img_meta: Union[Iterable, None] = None,

View File

@ -6,7 +6,6 @@ import warnings
import mmcv
import torch
from mmcv.parallel import MMDataParallel
from mmcv.runner import get_dist_info
from mmcv.utils import DictAction
@ -155,7 +154,6 @@ def main():
else:
tmpdir = None
model = MMDataParallel(model, device_ids=[0])
results = single_gpu_test(
model,
data_loader,