diff --git a/mmseg/apis/inference.py b/mmseg/apis/inference.py index 95f80b6..41d508d 100644 --- a/mmseg/apis/inference.py +++ b/mmseg/apis/inference.py @@ -61,7 +61,9 @@ def init_segmentor_kn(config, checkpoint=None, device='cuda:0'): device_id = int(device_id) except Exception: device_id = None if device == 'cpu' else 0 - model = ONNXRuntimeSegmentorKN(checkpoint, cfg=config, device_id=device_id).eval() + model = ONNXRuntimeSegmentorKN( + checkpoint, cfg=config, device_id=device_id + ).eval() return model @@ -124,9 +126,17 @@ def inference_segmentor(model, img): return result +@torch.no_grad() def inference_segmentor_kn(model, img): if model.endswith(".onnx"): - pass + cfg = model.cfg + test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:] + test_pipeline = Compose(test_pipeline) + data = dict(img=img) + data = test_pipeline(data) + data = collate([data], samples_per_gpu=1) + data['img_metas'] = [i.data[0] for i in data['img_metas']] + return model(return_loss=False, rescale=True, **data) else: return inference_segmentor(model, img) diff --git a/mmseg/models/segmentors/base.py b/mmseg/models/segmentors/base.py index b4f3e39..257d4fa 100644 --- a/mmseg/models/segmentors/base.py +++ b/mmseg/models/segmentors/base.py @@ -10,6 +10,7 @@ import numpy as np import torch import torch.distributed as dist from mmcv.runner import BaseModule, auto_fp16 +from mmseg.core import get_classes, get_palette from mmseg.ops import resize @@ -335,6 +336,25 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor): self.test_mode = self.test_cfg.mode # NOTE: should be 'whole' or 'slide' self.is_cuda_available = is_cuda_available self.count_mat = None + try: + if 'test' in cfg.data: + dataset_name = cfg.data.test['type'] + else: + dataset_name = cfg.data.train['type'] + dataset_name = dataset_name.lower()[:-7] + self.CLASSES = get_classes(dataset_name) + self.PALETTE = get_palette(dataset_name) + except (AttributeError, KeyError): + warnings.warn( + "Failed to fetch dataset name from config; no CLASSES " + "and PALETTE for this ONNX model" + ) + except ValueError: + warnings.warn( + "Failed to fetch CLASSES and PALETTE from dataset " + f"{dataset_name}; no CLASSES and PALETTE for this " + "ONNX MODEL." + ) def extract_feat(self, imgs): raise NotImplementedError('This method is not implemented.') @@ -406,6 +426,7 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor): seg_pred = self.sess.run( self.output_name_list, {self.input_name: img} )[0] + print(img.shape, seg_pred.shape) if img_meta is not None: ori_shape = img_meta[0]['ori_shape'] if not (ori_shape[0] == seg_pred.shape[-2] @@ -414,6 +435,11 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor): seg_pred = resize( seg_pred, size=tuple(ori_shape[:2]), mode='bilinear') seg_pred = seg_pred.numpy() + elif img.shape[2:] != seg_pred.shape[2:]: + seg_pred = torch.from_numpy(seg_pred).float() + seg_pred = resize( + seg_pred, size=(img.shape[3], img.shape[2]), mode='bilinear') + seg_pred = seg_pred.numpy() seg_pred = seg_pred.argmax(1) return list(seg_pred)