feat: slide mode for kneron onnx
This commit is contained in:
parent
b2246495f5
commit
b83243c5f4
@ -320,11 +320,15 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor):
|
||||
sess_inputs = sess.get_inputs()
|
||||
assert len(sess_inputs) == 1, "Only onnx with 1 input is supported"
|
||||
self.input_name = sess_inputs[0].name
|
||||
self.output_name_list = [_.name for _ in sess.get_outputs()]
|
||||
assert len(self.output_name_list) == 1, "Only onnx with 1 output is supported"
|
||||
sess_outputs = sess.get_outputs()
|
||||
self.num_classes = sess_outputs[0].shape[1]
|
||||
assert len(sess_outputs) == 1, "Only onnx with 1 output is supported"
|
||||
self.output_name_list = [sess_outputs[0].name]
|
||||
self.cfg = cfg # TODO: necessary?
|
||||
self.test_mode = cfg.model.test_cfg.mode # NOTE: should be 'whole' or 'slide'
|
||||
self.test_cfg = cfg.model.test_cfg
|
||||
self.test_mode = self.test_cfg.mode # NOTE: should be 'whole' or 'slide'
|
||||
self.is_cuda_available = is_cuda_available
|
||||
self.count_mat = None
|
||||
|
||||
def extract_feat(self, imgs):
|
||||
raise NotImplementedError('This method is not implemented.')
|
||||
@ -338,16 +342,60 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor):
|
||||
def forward_test(self, imgs, img_metas, **kwargs):
|
||||
return super().forward_test(imgs, img_metas[0].data, **kwargs)
|
||||
|
||||
def simple_slide_inference(
|
||||
self,
|
||||
img: np.ndarray,
|
||||
img_meta: Union[Iterable, None] = None):
|
||||
h_stride, w_stride = self.test_cfg.stride
|
||||
h_crop, w_crop = self.test_cfg.crop_size
|
||||
_, _, h_img, w_img = img.shape
|
||||
num_classes = self.num_classes
|
||||
h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
|
||||
w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
|
||||
preds = np.zeros((1, num_classes, h_img, w_img), dtype=np.float32)
|
||||
# NOTE: count_mat should be invariant since
|
||||
# input shape of kneron's onnx is fixed
|
||||
if self.count_mat is None:
|
||||
count_mat = np.zeros((1, 1, h_img, w_img), dtype=np.float32)
|
||||
for h_idx in range(h_grids):
|
||||
for w_idx in range(w_grids):
|
||||
y1 = h_idx * h_stride
|
||||
x1 = w_idx * w_stride
|
||||
y2 = min(y1 + h_crop, h_img)
|
||||
x2 = min(x1 + w_crop, w_img)
|
||||
y1 = max(y2 - h_crop, 0)
|
||||
x1 = max(x2 - w_crop, 0)
|
||||
crop_img = img[:, :, y1:y2, x1:x2]
|
||||
crop_seg_logit = self.sess.run(
|
||||
self.output_name_list,
|
||||
{self.input_name: crop_img}
|
||||
)[0]
|
||||
preds += np.pad(
|
||||
crop_seg_logit,
|
||||
([0, 0],
|
||||
[0, 0],
|
||||
[int(y1), int(preds.shape[2] - y2)],
|
||||
[int(x1), int(preds.shape[3] - x2)]),
|
||||
)
|
||||
if self.count_mat is None:
|
||||
count_mat[:, :, y1:y2, x1:x2] += 1
|
||||
if self.count_mat is None:
|
||||
assert (count_mat == 0).sum() == 0
|
||||
self.count_mat = count_mat
|
||||
preds /= self.count_mat
|
||||
return preds
|
||||
|
||||
@torch.no_grad()
|
||||
def simple_test(self, img: torch.Tensor,
|
||||
img_meta: Union[Iterable, None] = None,
|
||||
**kwargs) -> list:
|
||||
def simple_test(
|
||||
self,
|
||||
img: torch.Tensor,
|
||||
img_meta: Union[Iterable, None] = None,
|
||||
**kwargs) -> list:
|
||||
img = img.cpu().numpy()
|
||||
# NOTE: not using run_with_iobinding since some ort versions
|
||||
# generate wrong results when inferencing with CUDA
|
||||
if self.test_mode == 'slide':
|
||||
# raise NotImplementedError('slide mode is not implemented yet')
|
||||
seg_pred = self.sess.run(self.output_name_list, {self.input_name: img})[0]
|
||||
seg_pred = self.simple_slide_inference(img, img_meta)
|
||||
else:
|
||||
seg_pred = self.sess.run(self.output_name_list, {self.input_name: img})[0]
|
||||
seg_pred = seg_pred.argmax(1)[:, None]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user