diff --git a/mmseg/apis/__init__.py b/mmseg/apis/__init__.py index dffd4e1..a3a2933 100644 --- a/mmseg/apis/__init__.py +++ b/mmseg/apis/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from .inference import ( inference_segmentor, + inference_segmentor_kn, init_segmentor, init_segmentor_kn, show_result_pyplot, @@ -10,7 +11,8 @@ from .train import (get_root_logger, init_random_seed, set_random_seed, train_segmentor) __all__ = [ - 'get_root_logger', 'set_random_seed', 'train_segmentor', 'init_segmentor', - 'init_segmentor_kn', 'inference_segmentor', 'multi_gpu_test', - 'single_gpu_test', 'show_result_pyplot', 'init_random_seed' + 'get_root_logger', 'set_random_seed', 'train_segmentor', + 'init_segmentor', 'init_segmentor_kn', 'inference_segmentor', + 'inference_segmentor_kn', 'multi_gpu_test', 'single_gpu_test', + 'show_result_pyplot', 'init_random_seed' ] diff --git a/mmseg/apis/inference.py b/mmseg/apis/inference.py index 41d508d..648f255 100644 --- a/mmseg/apis/inference.py +++ b/mmseg/apis/inference.py @@ -128,14 +128,13 @@ def inference_segmentor(model, img): @torch.no_grad() def inference_segmentor_kn(model, img): - if model.endswith(".onnx"): + if isinstance(model, ONNXRuntimeSegmentorKN): cfg = model.cfg test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:] test_pipeline = Compose(test_pipeline) data = dict(img=img) data = test_pipeline(data) data = collate([data], samples_per_gpu=1) - data['img_metas'] = [i.data[0] for i in data['img_metas']] return model(return_loss=False, rescale=True, **data) else: return inference_segmentor(model, img) diff --git a/mmseg/models/segmentors/base.py b/mmseg/models/segmentors/base.py index 257d4fa..339b6b3 100644 --- a/mmseg/models/segmentors/base.py +++ b/mmseg/models/segmentors/base.py @@ -411,6 +411,10 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor): preds /= self.count_mat return preds + @property + def module(self): + return self + @torch.no_grad() def simple_test( self, @@ -426,7 +430,6 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor): seg_pred = self.sess.run( self.output_name_list, {self.input_name: img} )[0] - print(img.shape, seg_pred.shape) if img_meta is not None: ori_shape = img_meta[0]['ori_shape'] if not (ori_shape[0] == seg_pred.shape[-2]