tmp
This commit is contained in:
parent
7576467aa1
commit
0563dd8847
@ -61,7 +61,9 @@ def init_segmentor_kn(config, checkpoint=None, device='cuda:0'):
|
|||||||
device_id = int(device_id)
|
device_id = int(device_id)
|
||||||
except Exception:
|
except Exception:
|
||||||
device_id = None if device == 'cpu' else 0
|
device_id = None if device == 'cpu' else 0
|
||||||
model = ONNXRuntimeSegmentorKN(checkpoint, cfg=config, device_id=device_id).eval()
|
model = ONNXRuntimeSegmentorKN(
|
||||||
|
checkpoint, cfg=config, device_id=device_id
|
||||||
|
).eval()
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@ -124,9 +126,17 @@ def inference_segmentor(model, img):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def inference_segmentor_kn(model, img):
|
def inference_segmentor_kn(model, img):
|
||||||
if model.endswith(".onnx"):
|
if model.endswith(".onnx"):
|
||||||
pass
|
cfg = model.cfg
|
||||||
|
test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
|
||||||
|
test_pipeline = Compose(test_pipeline)
|
||||||
|
data = dict(img=img)
|
||||||
|
data = test_pipeline(data)
|
||||||
|
data = collate([data], samples_per_gpu=1)
|
||||||
|
data['img_metas'] = [i.data[0] for i in data['img_metas']]
|
||||||
|
return model(return_loss=False, rescale=True, **data)
|
||||||
else:
|
else:
|
||||||
return inference_segmentor(model, img)
|
return inference_segmentor(model, img)
|
||||||
|
|
||||||
|
|||||||
@ -10,6 +10,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from mmcv.runner import BaseModule, auto_fp16
|
from mmcv.runner import BaseModule, auto_fp16
|
||||||
|
from mmseg.core import get_classes, get_palette
|
||||||
from mmseg.ops import resize
|
from mmseg.ops import resize
|
||||||
|
|
||||||
|
|
||||||
@ -335,6 +336,25 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor):
|
|||||||
self.test_mode = self.test_cfg.mode # NOTE: should be 'whole' or 'slide'
|
self.test_mode = self.test_cfg.mode # NOTE: should be 'whole' or 'slide'
|
||||||
self.is_cuda_available = is_cuda_available
|
self.is_cuda_available = is_cuda_available
|
||||||
self.count_mat = None
|
self.count_mat = None
|
||||||
|
try:
|
||||||
|
if 'test' in cfg.data:
|
||||||
|
dataset_name = cfg.data.test['type']
|
||||||
|
else:
|
||||||
|
dataset_name = cfg.data.train['type']
|
||||||
|
dataset_name = dataset_name.lower()[:-7]
|
||||||
|
self.CLASSES = get_classes(dataset_name)
|
||||||
|
self.PALETTE = get_palette(dataset_name)
|
||||||
|
except (AttributeError, KeyError):
|
||||||
|
warnings.warn(
|
||||||
|
"Failed to fetch dataset name from config; no CLASSES "
|
||||||
|
"and PALETTE for this ONNX model"
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
warnings.warn(
|
||||||
|
"Failed to fetch CLASSES and PALETTE from dataset "
|
||||||
|
f"{dataset_name}; no CLASSES and PALETTE for this "
|
||||||
|
"ONNX MODEL."
|
||||||
|
)
|
||||||
|
|
||||||
def extract_feat(self, imgs):
|
def extract_feat(self, imgs):
|
||||||
raise NotImplementedError('This method is not implemented.')
|
raise NotImplementedError('This method is not implemented.')
|
||||||
@ -406,6 +426,7 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor):
|
|||||||
seg_pred = self.sess.run(
|
seg_pred = self.sess.run(
|
||||||
self.output_name_list, {self.input_name: img}
|
self.output_name_list, {self.input_name: img}
|
||||||
)[0]
|
)[0]
|
||||||
|
print(img.shape, seg_pred.shape)
|
||||||
if img_meta is not None:
|
if img_meta is not None:
|
||||||
ori_shape = img_meta[0]['ori_shape']
|
ori_shape = img_meta[0]['ori_shape']
|
||||||
if not (ori_shape[0] == seg_pred.shape[-2]
|
if not (ori_shape[0] == seg_pred.shape[-2]
|
||||||
@ -414,6 +435,11 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor):
|
|||||||
seg_pred = resize(
|
seg_pred = resize(
|
||||||
seg_pred, size=tuple(ori_shape[:2]), mode='bilinear')
|
seg_pred, size=tuple(ori_shape[:2]), mode='bilinear')
|
||||||
seg_pred = seg_pred.numpy()
|
seg_pred = seg_pred.numpy()
|
||||||
|
elif img.shape[2:] != seg_pred.shape[2:]:
|
||||||
|
seg_pred = torch.from_numpy(seg_pred).float()
|
||||||
|
seg_pred = resize(
|
||||||
|
seg_pred, size=(img.shape[3], img.shape[2]), mode='bilinear')
|
||||||
|
seg_pred = seg_pred.numpy()
|
||||||
seg_pred = seg_pred.argmax(1)
|
seg_pred = seg_pred.argmax(1)
|
||||||
return list(seg_pred)
|
return list(seg_pred)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user