"""Backend-agnostic inference dispatcher.

``inference_`` selects an inference path from the integer backend id
``_type`` (one of the module constants below) and returns the model outputs.
"""

import numpy as np

# Backend identifiers accepted by the ``_type`` argument of ``inference_``.
ONNX = 0
TORCH = 1
KERAS = 2
MXNET = 3
EMULATOR = 4

# Axis permutations applied to 4-D arrays: inputs are transposed with
# _TO_CHANNELS_FIRST before being fed to ONNX/torch models, and 4-D outputs
# are transposed back with _TO_CHANNELS_LAST.
# NOTE(review): this presumes callers pass channels-last (NHWC) data — confirm.
_TO_CHANNELS_FIRST = (0, 3, 1, 2)
_TO_CHANNELS_LAST = (0, 2, 3, 1)


def _maybe_channels_last(item):
    """Return *item* transposed with axes (0, 2, 3, 1) if it is 4-D, else *item* unchanged."""
    return item.transpose(_TO_CHANNELS_LAST) if np.ndim(item) == 4 else item


def _to_channels_last(results):
    """Apply ``_maybe_channels_last`` to a single array or to each array of a list."""
    if isinstance(results, list):
        return [_maybe_channels_last(item) for item in results]
    return _maybe_channels_last(results)


def _run_onnx(pre_results, model):
    """Run an ONNX session; *model* is a ``(session, input_names)`` pair."""
    sess, input_name = model
    if len(input_name) == 1:
        if np.ndim(pre_results) == 4:
            img_data = [np.transpose(pre_results, _TO_CHANNELS_FIRST).astype(np.float32)]
        else:
            img_data = [pre_results.astype(np.float32)]
    else:
        # One entry of pre_results per input name; 4-D entries are transposed
        # and cast to float32.
        # NOTE(review): unlike the single-input path, non-4-D entries are passed
        # through without a float32 cast — confirm this is intended.
        img_data = [
            np.transpose(item, _TO_CHANNELS_FIRST).astype(np.float32)
            if np.ndim(item) == 4
            else item
            for item in pre_results
        ]
    # Passing None as output names requests every model output.
    inf_results = sess.run(None, dict(zip(input_name, img_data)))
    return _to_channels_last(inf_results)


def _run_torch(pre_results, model):
    """Run a torch module on *pre_results* (a 4-D numpy array) without gradients."""
    import torch
    from kneron_utils.torch_utils import _get_device

    device = _get_device()
    img_tensor = torch.from_numpy(np.transpose(pre_results, _TO_CHANNELS_FIRST)).float().to(device)
    with torch.no_grad():
        inf_results = model(img_tensor)
    # isinstance (rather than an exact type check) also accepts tuple/list
    # subclasses such as namedtuple-style outputs.
    if isinstance(inf_results, (tuple, list)):
        inf_results = [item.cpu().numpy() for item in inf_results]
    else:
        inf_results = inf_results.cpu().numpy()
    return _to_channels_last(inf_results)


def _run_keras(pre_results, model):
    """Run a Keras model via ``model.predict``; inputs and outputs are not transposed."""
    return model.predict(pre_results)


def _run_mxnet(pre_results, model, kwargs):
    """Run an MXNet module; needs ``batch_size``, ``channel_num`` and ``image_size`` in *kwargs*."""
    import mxnet as mx

    assert 'batch_size' in kwargs and 'channel_num' in kwargs and 'image_size' in kwargs, \
        "MXNet inference requires 'batch_size', 'channel_num' and 'image_size' kwargs"
    image_size = kwargs['image_size']
    data = mx.ndarray.zeros(
        (kwargs['batch_size'], kwargs['channel_num'], image_size[0], image_size[1])
    )
    # NOTE(review): only batch slot 0 is filled; any further slots stay zero.
    data[0][:] = pre_results
    model.forward(mx.io.DataBatch(data=(data,)), is_train=False)
    return [out.asnumpy() for out in model.get_outputs()]


def _run_emulator(pre_results, kwargs):
    """Run the python_flow emulator selected by ``USER_CONFIG["emu"]["emu_mode"]``."""
    import python_flow.emulator as emu

    assert "USER_CONFIG" in kwargs, "emulator inference requires a 'USER_CONFIG' kwarg"
    config = kwargs.get("USER_CONFIG")
    assert isinstance(config, dict), "USER_CONFIG must be a dict"

    # Falls back to "bypass" when the config does not name a mode.
    emu_mode = "bypass"
    if "emu" in config and "emu_mode" in config["emu"]:
        emu_mode = config["emu"]["emu_mode"]

    emu_dict = {
        "csim": emu.emulator_csim,
        "float": emu.emulator_float,
        "fixed": emu.emulator_fixed,
        "dongle": emu.emulator_dongle,
        "bypass": lambda x, y: print("Bypassing inference..."),
    }
    assert emu_mode in emu_dict, "unknown emu_mode: %r" % (emu_mode,)
    inf_results = emu_dict[emu_mode](config, pre_results)
    # The bypass handler (and possibly others) returns None; normalise to [].
    return [] if inf_results is None else inf_results


def inference_(pre_results, model, _type, **kwargs):
    """Run *model* on *pre_results* using the backend selected by *_type*.

    Args:
        pre_results: Preprocessed input. For ONNX with several input names,
            a sequence with one array per input; otherwise an array.
        model: Backend-specific model object. For ONNX, a
            ``(session, input_names)`` pair.
        _type: One of ``ONNX``, ``TORCH``, ``KERAS``, ``MXNET``, ``EMULATOR``.
        **kwargs: ``batch_size``, ``channel_num``, ``image_size`` (MXNET);
            ``USER_CONFIG`` dict (EMULATOR).

    Returns:
        The backend's outputs: a numpy array or a list of them (``[]`` when
        the emulator returns nothing). For ONNX and TORCH, 4-D outputs are
        transposed with axes (0, 2, 3, 1).

    Raises:
        AssertionError: if backend-required kwargs are missing or invalid.
        ValueError: if *_type* is not a known backend id.
    """
    # TODO: align multiple output situation
    if _type == ONNX:
        return _run_onnx(pre_results, model)
    if _type == TORCH:
        return _run_torch(pre_results, model)
    if _type == KERAS:
        return _run_keras(pre_results, model)
    if _type == MXNET:
        return _run_mxnet(pre_results, model, kwargs)
    if _type == EMULATOR:
        return _run_emulator(pre_results, kwargs)
    raise ValueError(f"missing inference implement for _type={_type!r}")