diff --git a/mmseg/models/segmentors/base.py b/mmseg/models/segmentors/base.py
index 32f8987..dc6730e 100644
--- a/mmseg/models/segmentors/base.py
+++ b/mmseg/models/segmentors/base.py
@@ -320,11 +320,15 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor):
         sess_inputs = sess.get_inputs()
         assert len(sess_inputs) == 1, "Only onnx with 1 input is supported"
         self.input_name = sess_inputs[0].name
-        self.output_name_list = [_.name for _ in sess.get_outputs()]
-        assert len(self.output_name_list) == 1, "Only onnx with 1 output is supported"
+        sess_outputs = sess.get_outputs()
+        self.num_classes = sess_outputs[0].shape[1]
+        assert len(sess_outputs) == 1, "Only onnx with 1 output is supported"
+        self.output_name_list = [sess_outputs[0].name]
         self.cfg = cfg  # TODO: necessary?
-        self.test_mode = cfg.model.test_cfg.mode  # NOTE: should be 'whole' or 'slide'
+        self.test_cfg = cfg.model.test_cfg
+        self.test_mode = self.test_cfg.mode  # NOTE: should be 'whole' or 'slide'
         self.is_cuda_available = is_cuda_available
+        self.count_mat = None
 
     def extract_feat(self, imgs):
         raise NotImplementedError('This method is not implemented.')
@@ -338,16 +342,78 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor):
     def forward_test(self, imgs, img_metas, **kwargs):
         return super().forward_test(imgs, img_metas[0].data, **kwargs)
 
+    def simple_slide_inference(
+            self,
+            img: np.ndarray,
+            img_meta: Union[Iterable, None] = None) -> np.ndarray:
+        """Run sliding-window inference and average overlapping logits.
+
+        Windows of ``test_cfg.crop_size`` are placed every
+        ``test_cfg.stride`` pixels; a window that would cross the
+        bottom/right border is shifted back so it ends at the border.
+        Each window is run through the ONNX session, the logits are
+        summed per pixel and divided by the number of covering windows.
+
+        Args:
+            img: 4-D array whose last two axes are height and width.
+            img_meta: Unused; kept for signature compatibility.
+
+        Returns:
+            float32 array of shape
+            ``(img.shape[0], self.num_classes, h_img, w_img)``.
+        """
+        h_stride, w_stride = self.test_cfg.stride
+        h_crop, w_crop = self.test_cfg.crop_size
+        batch_size, _, h_img, w_img = img.shape
+        h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
+        w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
+        preds = np.zeros(
+            (batch_size, self.num_classes, h_img, w_img), dtype=np.float32)
+        # The coverage count depends only on image size, crop and stride,
+        # so it is cached; rebuild it when the image size differs from the
+        # cached one instead of dividing by a stale matrix.
+        count_mat = self.count_mat
+        build_count = (
+            count_mat is None or count_mat.shape[2:] != (h_img, w_img))
+        if build_count:
+            count_mat = np.zeros((1, 1, h_img, w_img), dtype=np.float32)
+        for h_idx in range(h_grids):
+            for w_idx in range(w_grids):
+                y2 = min(h_idx * h_stride + h_crop, h_img)
+                x2 = min(w_idx * w_stride + w_crop, w_img)
+                y1 = max(y2 - h_crop, 0)
+                x1 = max(x2 - w_crop, 0)
+                # NOTE(review): an image smaller than the crop yields a
+                # smaller window; a fixed-shape ONNX input would presumably
+                # reject it -- confirm upstream resizing guarantees this
+                # cannot happen.
+                crop_seg_logit = self.sess.run(
+                    self.output_name_list,
+                    {self.input_name: img[:, :, y1:y2, x1:x2]}
+                )[0]
+                # Accumulate in place rather than np.pad-ing every crop to
+                # the full image size.
+                preds[:, :, y1:y2, x1:x2] += crop_seg_logit
+                if build_count:
+                    count_mat[:, :, y1:y2, x1:x2] += 1
+        if build_count:
+            assert (count_mat == 0).sum() == 0, \
+                "Sliding windows do not cover the whole image"
+            self.count_mat = count_mat
+        preds /= count_mat
+        return preds
+
     @torch.no_grad()
-    def simple_test(self, img: torch.Tensor,
-                    img_meta: Union[Iterable, None] = None,
-                    **kwargs) -> list:
+    def simple_test(
+            self,
+            img: torch.Tensor,
+            img_meta: Union[Iterable, None] = None,
+            **kwargs) -> list:
         img = img.cpu().numpy()
         # NOTE: not using run_with_iobinding since some ort versions
         # generate wrong results when inferencing with CUDA
         if self.test_mode == 'slide':
-            # raise NotImplementedError('slide mode is not implemented yet')
-            seg_pred = self.sess.run(self.output_name_list, {self.input_name: img})[0]
+            seg_pred = self.simple_slide_inference(img, img_meta)
         else:
             seg_pred = self.sess.run(self.output_name_list, {self.input_name: img})[0]
         seg_pred = seg_pred.argmax(1)[:, None]