feat: init_segmentor_kn, inference_segmentor_kn, show_result
This commit is contained in:
parent
0563dd8847
commit
fbfe81c815
@ -1,6 +1,7 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from .inference import (
|
from .inference import (
|
||||||
inference_segmentor,
|
inference_segmentor,
|
||||||
|
inference_segmentor_kn,
|
||||||
init_segmentor,
|
init_segmentor,
|
||||||
init_segmentor_kn,
|
init_segmentor_kn,
|
||||||
show_result_pyplot,
|
show_result_pyplot,
|
||||||
@ -10,7 +11,8 @@ from .train import (get_root_logger, init_random_seed, set_random_seed,
|
|||||||
train_segmentor)
|
train_segmentor)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'get_root_logger', 'set_random_seed', 'train_segmentor', 'init_segmentor',
|
'get_root_logger', 'set_random_seed', 'train_segmentor',
|
||||||
'init_segmentor_kn', 'inference_segmentor', 'multi_gpu_test',
|
'init_segmentor', 'init_segmentor_kn', 'inference_segmentor',
|
||||||
'single_gpu_test', 'show_result_pyplot', 'init_random_seed'
|
'inference_segmentor_kn', 'multi_gpu_test', 'single_gpu_test',
|
||||||
|
'show_result_pyplot', 'init_random_seed'
|
||||||
]
|
]
|
||||||
|
|||||||
@ -128,14 +128,13 @@ def inference_segmentor(model, img):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def inference_segmentor_kn(model, img):
|
def inference_segmentor_kn(model, img):
|
||||||
if model.endswith(".onnx"):
|
if isinstance(model, ONNXRuntimeSegmentorKN):
|
||||||
cfg = model.cfg
|
cfg = model.cfg
|
||||||
test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
|
test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
|
||||||
test_pipeline = Compose(test_pipeline)
|
test_pipeline = Compose(test_pipeline)
|
||||||
data = dict(img=img)
|
data = dict(img=img)
|
||||||
data = test_pipeline(data)
|
data = test_pipeline(data)
|
||||||
data = collate([data], samples_per_gpu=1)
|
data = collate([data], samples_per_gpu=1)
|
||||||
data['img_metas'] = [i.data[0] for i in data['img_metas']]
|
|
||||||
return model(return_loss=False, rescale=True, **data)
|
return model(return_loss=False, rescale=True, **data)
|
||||||
else:
|
else:
|
||||||
return inference_segmentor(model, img)
|
return inference_segmentor(model, img)
|
||||||
|
|||||||
@ -411,6 +411,10 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor):
|
|||||||
preds /= self.count_mat
|
preds /= self.count_mat
|
||||||
return preds
|
return preds
|
||||||
|
|
||||||
|
@property
|
||||||
|
def module(self):
|
||||||
|
return self
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def simple_test(
|
def simple_test(
|
||||||
self,
|
self,
|
||||||
@ -426,7 +430,6 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor):
|
|||||||
seg_pred = self.sess.run(
|
seg_pred = self.sess.run(
|
||||||
self.output_name_list, {self.input_name: img}
|
self.output_name_list, {self.input_name: img}
|
||||||
)[0]
|
)[0]
|
||||||
print(img.shape, seg_pred.shape)
|
|
||||||
if img_meta is not None:
|
if img_meta is not None:
|
||||||
ori_shape = img_meta[0]['ori_shape']
|
ori_shape = img_meta[0]['ori_shape']
|
||||||
if not (ori_shape[0] == seg_pred.shape[-2]
|
if not (ori_shape[0] == seg_pred.shape[-2]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user