From b2246495f503fd3878517f74e163b3f11497819d Mon Sep 17 00:00:00 2001 From: "chingning.chen" Date: Wed, 23 Mar 2022 17:02:03 +0800 Subject: [PATCH] refactor: not using MMDataParallel for onnx inferencing --- mmseg/models/segmentors/base.py | 3 +++ tools/deploy_test_kneron.py | 2 -- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/mmseg/models/segmentors/base.py b/mmseg/models/segmentors/base.py index 6755d9e..32f8987 100644 --- a/mmseg/models/segmentors/base.py +++ b/mmseg/models/segmentors/base.py @@ -335,6 +335,9 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor): def forward_train(self, imgs, img_metas, **kwargs): raise NotImplementedError('This method is not implemented.') + def forward_test(self, imgs, img_metas, **kwargs): + return super().forward_test(imgs, img_metas[0].data, **kwargs) + @torch.no_grad() def simple_test(self, img: torch.Tensor, img_meta: Union[Iterable, None] = None, diff --git a/tools/deploy_test_kneron.py b/tools/deploy_test_kneron.py index 1dd74e3..d1aacdf 100644 --- a/tools/deploy_test_kneron.py +++ b/tools/deploy_test_kneron.py @@ -6,7 +6,6 @@ import warnings import mmcv import torch -from mmcv.parallel import MMDataParallel from mmcv.runner import get_dist_info from mmcv.utils import DictAction @@ -155,7 +154,6 @@ def main(): else: tmpdir = None - model = MMDataParallel(model, device_ids=[0]) results = single_gpu_test( model, data_loader,