88 lines
3.2 KiB
Python

import numpy as np

# Back-end identifiers accepted as the ``_type`` argument of ``inference_``.
ONNX = 0
TORCH = 1
KERAS = 2
MXNET = 3
EMULATOR = 4

# Axis permutations for 4-D arrays: move the last axis to position 1 and back
# (channel-last NHWC <-> channel-first NCHW, assuming channel-last input).
_NHWC_TO_NCHW = (0, 3, 1, 2)
_NCHW_TO_NHWC = (0, 2, 3, 1)


def _to_nhwc(outputs):
    """Transpose every 4-D array in ``outputs`` from NCHW back to NHWC.

    ``outputs`` is either a list of arrays or a single array; arrays whose rank
    is not 4 are returned unchanged.
    """
    if isinstance(outputs, list):
        return [out.transpose(_NCHW_TO_NHWC) if np.ndim(out) == 4 else out
                for out in outputs]
    if np.ndim(outputs) == 4:
        return outputs.transpose(_NCHW_TO_NHWC)
    return outputs


def _infer_onnx(pre_results, model):
    """Feed ``pre_results`` to an ONNX session and return NHWC outputs.

    ``model`` is a ``(session, input_names)`` pair; ``input_names`` may be a
    sequence of names or a single name string.
    """
    sess, input_name = model
    # A bare string would otherwise be iterated as one-character input names.
    if isinstance(input_name, str):
        input_name = [input_name]
    if len(input_name) == 1:
        # Single input: 4-D data is moved to NCHW; everything is fed as float32.
        if np.ndim(pre_results) == 4:
            feeds = [np.transpose(pre_results, _NHWC_TO_NCHW).astype(np.float32)]
        else:
            feeds = [np.asarray(pre_results, dtype=np.float32)]
    else:
        # Multiple inputs: only 4-D items are transposed and cast to float32;
        # other items are passed through with their original dtype.
        # NOTE(review): presumably deliberate (e.g. integer side inputs) — confirm.
        feeds = [np.transpose(item, _NHWC_TO_NCHW).astype(np.float32)
                 if np.ndim(item) == 4 else item
                 for item in pre_results]
    return _to_nhwc(sess.run(None, dict(zip(input_name, feeds))))


def _infer_torch(pre_results, model):
    """Run a torch module on a 4-D numpy batch and return NHWC numpy outputs."""
    import torch
    from kneron_utils.torch_utils import _get_device

    device = _get_device()
    # Transpose to NCHW, convert to a float tensor and move it to the device.
    img_tensor = torch.from_numpy(
        np.transpose(pre_results, _NHWC_TO_NCHW)).float().to(device)
    with torch.no_grad():
        outputs = model(img_tensor)
    if isinstance(outputs, (tuple, list)):
        outputs = [out.cpu().numpy() for out in outputs]
    else:
        outputs = outputs.cpu().numpy()
    return _to_nhwc(outputs)


def _infer_mxnet(pre_results, model, kwargs):
    """Run an MXNet module forward pass and return its outputs as numpy arrays.

    Requires ``batch_size``, ``channel_num`` and ``image_size`` in ``kwargs``.
    """
    import mxnet as mx

    assert 'batch_size' in kwargs and 'channel_num' in kwargs and 'image_size' in kwargs, \
        "MXNet inference needs 'batch_size', 'channel_num' and 'image_size' kwargs"
    image_size = kwargs['image_size']
    data = mx.ndarray.zeros((kwargs['batch_size'], kwargs['channel_num'],
                             image_size[0], image_size[1]))
    # Only the first batch slot is filled; any remaining slots stay zero.
    data[0][:] = pre_results
    model.forward(mx.io.DataBatch(data=(data,)), is_train=False)
    return [out.asnumpy() for out in model.get_outputs()]


def _infer_emulator(pre_results, kwargs):
    """Run ``python_flow.emulator`` in the mode set by ``USER_CONFIG['emu']['emu_mode']``.

    The mode defaults to ``"bypass"``, which only prints a message. A ``None``
    result is normalised to an empty list.
    """
    import python_flow.emulator as emu

    assert "USER_CONFIG" in kwargs, "emulator inference needs a USER_CONFIG kwarg"
    config = kwargs.get("USER_CONFIG")
    assert isinstance(config, dict), "USER_CONFIG must be a dict"
    emu_mode = "bypass"
    if "emu" in config and "emu_mode" in config["emu"]:
        emu_mode = config["emu"]["emu_mode"]
    runners = {
        "csim": emu.emulator_csim,
        "float": emu.emulator_float,
        "fixed": emu.emulator_fixed,
        "dongle": emu.emulator_dongle,
        "bypass": lambda x, y: print("Bypassing inference..."),
    }
    assert emu_mode in runners, f"unknown emu_mode {emu_mode!r}"
    outputs = runners[emu_mode](config, pre_results)
    return [] if outputs is None else outputs


def inference_(pre_results, model, _type, **kwargs):
    """Run ``model`` on preprocessed input with the back-end selected by ``_type``.

    Args:
        pre_results: Preprocessed input. For ONNX and TORCH, 4-D arrays are
            transposed with (0, 3, 1, 2) before inference. ONNX also accepts a
            list of inputs when the session has several input names.
        model: Back-end handle:
            - ONNX: a ``(session, input_names)`` pair.
            - TORCH: a callable module.
            - KERAS: an object with ``predict``.
            - MXNET: an object with ``forward``/``get_outputs``.
            - EMULATOR: unused.
        _type: One of ``ONNX``, ``TORCH``, ``KERAS``, ``MXNET``, ``EMULATOR``.
        **kwargs: MXNET needs ``batch_size``, ``channel_num`` and
            ``image_size``; EMULATOR needs ``USER_CONFIG`` (a dict).

    Returns:
        The inference results. 4-D outputs from ONNX and TORCH are transposed
        back with (0, 2, 3, 1). KERAS, MXNET and EMULATOR outputs are returned
        as produced.

    Raises:
        ValueError: If ``_type`` is not a supported back-end.
        AssertionError: If kwargs required by MXNET or EMULATOR are missing.
    """
    # TODO: align output conventions (layout, list vs. array) across back-ends.
    if _type == ONNX:
        return _infer_onnx(pre_results, model)
    if _type == TORCH:
        return _infer_torch(pre_results, model)
    if _type == KERAS:
        return model.predict(pre_results)
    if _type == MXNET:
        return _infer_mxnet(pre_results, model, kwargs)
    if _type == EMULATOR:
        return _infer_emulator(pre_results, kwargs)
    raise ValueError(f"missing inference implement for model type {_type!r}")