refactor: not using MMDataParallel for onnx inferencing
This commit is contained in:
parent
12c840564e
commit
b2246495f5
@ -335,6 +335,9 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor):
|
|||||||
def forward_train(self, imgs, img_metas, **kwargs):
|
def forward_train(self, imgs, img_metas, **kwargs):
|
||||||
raise NotImplementedError('This method is not implemented.')
|
raise NotImplementedError('This method is not implemented.')
|
||||||
|
|
||||||
|
def forward_test(self, imgs, img_metas, **kwargs):
|
||||||
|
return super().forward_test(imgs, img_metas[0].data, **kwargs)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def simple_test(self, img: torch.Tensor,
|
def simple_test(self, img: torch.Tensor,
|
||||||
img_meta: Union[Iterable, None] = None,
|
img_meta: Union[Iterable, None] = None,
|
||||||
|
|||||||
@ -6,7 +6,6 @@ import warnings
|
|||||||
|
|
||||||
import mmcv
|
import mmcv
|
||||||
import torch
|
import torch
|
||||||
from mmcv.parallel import MMDataParallel
|
|
||||||
from mmcv.runner import get_dist_info
|
from mmcv.runner import get_dist_info
|
||||||
from mmcv.utils import DictAction
|
from mmcv.utils import DictAction
|
||||||
|
|
||||||
@ -155,7 +154,6 @@ def main():
|
|||||||
else:
|
else:
|
||||||
tmpdir = None
|
tmpdir = None
|
||||||
|
|
||||||
model = MMDataParallel(model, device_ids=[0])
|
|
||||||
results = single_gpu_test(
|
results = single_gpu_test(
|
||||||
model,
|
model,
|
||||||
data_loader,
|
data_loader,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user