feat: slide mode for kneron onnx

This commit is contained in:
chingning.chen 2022-03-23 17:58:23 +08:00
parent b2246495f5
commit b83243c5f4

View File

@ -320,11 +320,15 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor):
sess_inputs = sess.get_inputs() sess_inputs = sess.get_inputs()
assert len(sess_inputs) == 1, "Only onnx with 1 input is supported" assert len(sess_inputs) == 1, "Only onnx with 1 input is supported"
self.input_name = sess_inputs[0].name self.input_name = sess_inputs[0].name
self.output_name_list = [_.name for _ in sess.get_outputs()] sess_outputs = sess.get_outputs()
assert len(self.output_name_list) == 1, "Only onnx with 1 output is supported" self.num_classes = sess_outputs[0].shape[1]
assert len(sess_outputs) == 1, "Only onnx with 1 output is supported"
self.output_name_list = [sess_outputs[0].name]
self.cfg = cfg # TODO: necessary? self.cfg = cfg # TODO: necessary?
self.test_mode = cfg.model.test_cfg.mode # NOTE: should be 'whole' or 'slide' self.test_cfg = cfg.model.test_cfg
self.test_mode = self.test_cfg.mode # NOTE: should be 'whole' or 'slide'
self.is_cuda_available = is_cuda_available self.is_cuda_available = is_cuda_available
self.count_mat = None
# Feature extraction is not provided by this segmentor: the method
# unconditionally raises NotImplementedError, whatever `imgs` is.
def extract_feat(self, imgs): def extract_feat(self, imgs):
raise NotImplementedError('This method is not implemented.') raise NotImplementedError('This method is not implemented.')
@ -338,16 +342,60 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor):
# Delegates to the parent class's forward_test, passing `img_metas[0].data`
# in place of `img_metas`; `imgs` and any extra keyword arguments are
# forwarded unchanged and the parent's result is returned as-is.
def forward_test(self, imgs, img_metas, **kwargs): def forward_test(self, imgs, img_metas, **kwargs):
return super().forward_test(imgs, img_metas[0].data, **kwargs) return super().forward_test(imgs, img_metas[0].data, **kwargs)
def simple_slide_inference(
        self,
        img: np.ndarray,
        img_meta: Union[Iterable, None] = None) -> np.ndarray:
    """Run sliding-window inference and average the overlapping logits.

    The image is tiled with windows of ``self.test_cfg.crop_size`` placed
    every ``self.test_cfg.stride`` pixels. Windows that would run past the
    bottom/right border are shifted back inside the image. Each window is
    fed to the ONNX session and its output is accumulated into a
    full-size buffer, which is finally divided by the per-pixel number of
    windows that covered it.

    Args:
        img: Input array; its shape is unpacked as ``(_, _, h_img, w_img)``.
        img_meta: Unused; accepted only to keep the call signature.

    Returns:
        A float32 array of shape ``(1, self.num_classes, h_img, w_img)``
        holding the averaged per-window outputs.

    Raises:
        AssertionError: If some pixel is covered by no window (e.g. the
            stride is larger than the crop size). Raised explicitly so
            the check also runs under ``python -O``.
    """
    h_stride, w_stride = self.test_cfg.stride
    h_crop, w_crop = self.test_cfg.crop_size
    _, _, h_img, w_img = img.shape
    h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
    w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1

    # Pre-compute every window as (y1, y2, x1, x2). A window that would
    # overshoot the border is clamped and then moved back so it keeps the
    # full crop size whenever the image is large enough.
    windows = []
    for h_idx in range(h_grids):
        for w_idx in range(w_grids):
            y2 = min(h_idx * h_stride + h_crop, h_img)
            x2 = min(w_idx * w_stride + w_crop, w_img)
            y1 = max(y2 - h_crop, 0)
            x1 = max(x2 - w_crop, 0)
            windows.append((y1, y2, x1, x2))

    # The coverage count depends only on the image size (crop/stride come
    # from the fixed test config), so it is cached on the instance. It is
    # rebuilt if an input with a different spatial size shows up, instead
    # of dividing by a stale, mismatched matrix.
    count_mat = self.count_mat
    if count_mat is None or tuple(count_mat.shape[2:]) != (h_img, w_img):
        count_mat = np.zeros((1, 1, h_img, w_img), dtype=np.float32)
        for y1, y2, x1, x2 in windows:
            count_mat[:, :, y1:y2, x1:x2] += 1
        # Fail before running any inference: an uncovered pixel would
        # otherwise turn into a silent division by zero below.
        if not count_mat.all():
            raise AssertionError(
                'sliding windows do not cover the whole image; '
                'check test_cfg.stride against test_cfg.crop_size')
        self.count_mat = count_mat

    preds = np.zeros(
        (1, self.num_classes, h_img, w_img), dtype=np.float32)
    for y1, y2, x1, x2 in windows:
        crop_seg_logit = self.sess.run(
            self.output_name_list,
            {self.input_name: img[:, :, y1:y2, x1:x2]},
        )[0]
        # Accumulate straight into the window's region; this avoids
        # allocating a zero-padded full-size copy for every window.
        preds[:, :, y1:y2, x1:x2] += crop_seg_logit

    preds /= count_mat
    return preds
@torch.no_grad() @torch.no_grad()
def simple_test(self, img: torch.Tensor, def simple_test(
img_meta: Union[Iterable, None] = None, self,
**kwargs) -> list: img: torch.Tensor,
img_meta: Union[Iterable, None] = None,
**kwargs) -> list:
img = img.cpu().numpy() img = img.cpu().numpy()
# NOTE: not using run_with_iobinding since some ort versions # NOTE: not using run_with_iobinding since some ort versions
# generate wrong results when inferencing with CUDA # generate wrong results when inferencing with CUDA
if self.test_mode == 'slide': if self.test_mode == 'slide':
# raise NotImplementedError('slide mode is not implemented yet') seg_pred = self.simple_slide_inference(img, img_meta)
seg_pred = self.sess.run(self.output_name_list, {self.input_name: img})[0]
else: else:
seg_pred = self.sess.run(self.output_name_list, {self.input_name: img})[0] seg_pred = self.sess.run(self.output_name_list, {self.input_name: img})[0]
seg_pred = seg_pred.argmax(1)[:, None] seg_pred = seg_pred.argmax(1)[:, None]