diff --git a/mmseg/models/segmentors/base.py b/mmseg/models/segmentors/base.py index 6755d9e..32f8987 100644 --- a/mmseg/models/segmentors/base.py +++ b/mmseg/models/segmentors/base.py @@ -335,6 +335,9 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor): def forward_train(self, imgs, img_metas, **kwargs): raise NotImplementedError('This method is not implemented.') + def forward_test(self, imgs, img_metas, **kwargs): + return super().forward_test(imgs, img_metas[0].data, **kwargs) + @torch.no_grad() def simple_test(self, img: torch.Tensor, img_meta: Union[Iterable, None] = None, diff --git a/tools/deploy_test_kneron.py b/tools/deploy_test_kneron.py index 1dd74e3..d1aacdf 100644 --- a/tools/deploy_test_kneron.py +++ b/tools/deploy_test_kneron.py @@ -6,7 +6,6 @@ import warnings import mmcv import torch -from mmcv.parallel import MMDataParallel from mmcv.runner import get_dist_info from mmcv.utils import DictAction @@ -155,7 +154,6 @@ def main(): else: tmpdir = None - model = MMDataParallel(model, device_ids=[0]) results = single_gpu_test( model, data_loader,