refactor: not using MMDataParallel for onnx inferencing
This commit is contained in:
parent
12c840564e
commit
b2246495f5
@ -335,6 +335,9 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor):
|
||||
def forward_train(self, imgs, img_metas, **kwargs):
|
||||
raise NotImplementedError('This method is not implemented.')
|
||||
|
||||
def forward_test(self, imgs, img_metas, **kwargs):
|
||||
return super().forward_test(imgs, img_metas[0].data, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
def simple_test(self, img: torch.Tensor,
|
||||
img_meta: Union[Iterable, None] = None,
|
||||
|
||||
@ -6,7 +6,6 @@ import warnings
|
||||
|
||||
import mmcv
|
||||
import torch
|
||||
from mmcv.parallel import MMDataParallel
|
||||
from mmcv.runner import get_dist_info
|
||||
from mmcv.utils import DictAction
|
||||
|
||||
@ -155,7 +154,6 @@ def main():
|
||||
else:
|
||||
tmpdir = None
|
||||
|
||||
model = MMDataParallel(model, device_ids=[0])
|
||||
results = single_gpu_test(
|
||||
model,
|
||||
data_loader,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user