39 lines
1.8 KiB
Python

import os
import base.inference
def load_model(model_path, **kwargs):
    """Load an inference model, selecting the backend from the file extension.

    Supported extensions:
      - ``.onnx``                          -> ONNX Runtime session
      - ``.h5`` / ``.hdf5``                -> Keras model
      - ``.pts`` / ``.pt`` / ``.tar`` / ``.pth`` -> PyTorch model
        (requires ``model_def_path`` and ``module_name`` in *kwargs*)
      - ``.params``                        -> MXNet model
        (requires ``layer``, ``batch_size``, ``channel_num`` and
        ``image_size`` in *kwargs*; optional ``label_shape``)

    If ``USER_CONFIG`` is present in *kwargs*, no framework model is loaded
    and the emulator inference type is returned instead.

    Args:
        model_path: Path to the model file; its extension selects the backend.
        **kwargs: Backend-specific options as described above.

    Returns:
        Tuple ``(model, inference_type)``. ``model`` is backend-specific —
        for ONNX it is ``[session, input_names]``; on the emulator path it
        is ``None``.

    Raises:
        TypeError: If the extension is not one of the supported types.
        ValueError: If a backend-specific required kwarg is missing.
    """
    if "USER_CONFIG" in kwargs:
        # Emulator mode: no framework model is loaded.
        return None, base.inference.EMULATOR

    # Lower-cased file extension without the leading dot selects the backend.
    # (Renamed from `type`, which shadowed the builtin.)
    ext = os.path.splitext(model_path.lower())[1][1:]

    if ext == 'onnx':
        import onnxruntime
        sess = onnxruntime.InferenceSession(model_path)
        # Collect every graph input name, not just the first one.
        input_names = [item.name for item in sess.get_inputs()]
        model = [sess, input_names]
        _inference_type = base.inference.ONNX
    elif ext in ('h5', 'hdf5'):
        import keras
        model = keras.models.load_model(model_path)
        _inference_type = base.inference.KERAS
    elif ext in ('pts', 'pt', 'tar', 'pth'):
        from kneron_utils.torch_utils import load_torch_model, _get_device
        # Raise instead of assert: asserts are stripped under `python -O`.
        if 'model_def_path' not in kwargs or 'module_name' not in kwargs:
            raise ValueError(
                "torch models require 'model_def_path' and 'module_name' kwargs")
        model = load_torch_model(model_path, **kwargs)
        model.eval()
        model = model.to(_get_device())
        _inference_type = base.inference.TORCH
    elif ext == 'params':
        import ast
        import mxnet as mx
        from kneron_utils.mxnet_utils import get_model
        required = ('layer', 'batch_size', 'channel_num', 'image_size')
        missing = [key for key in required if key not in kwargs]
        if missing:
            raise ValueError(
                "mxnet models require kwargs: %s" % ", ".join(missing))
        # literal_eval instead of eval: 'label_shape' may come from user
        # configuration and must not be able to execute arbitrary code.
        label_shape = (ast.literal_eval(kwargs['label_shape'])
                       if 'label_shape' in kwargs else None)
        model = get_model(mx.cpu(), kwargs['image_size'], kwargs['channel_num'],
                          model_path, kwargs['layer'], kwargs['batch_size'],
                          label_shape)
        _inference_type = base.inference.MXNET
    else:
        raise TypeError("unknown model type")

    return model, _inference_type