39 lines
1.8 KiB
Python
39 lines
1.8 KiB
Python
import os
|
|
import base.inference
|
|
def load_model(model_path, **kwargs):
    """Load a model from ``model_path`` and report which inference backend runs it.

    If ``"USER_CONFIG"`` is a key of ``kwargs`` (its value is not inspected),
    nothing is loaded and the emulator backend is selected. Otherwise the
    backend is chosen from the file extension of ``model_path``, compared
    case-insensitively:

    * ``onnx`` -- ``onnxruntime.InferenceSession``.
    * ``h5`` / ``hdf5`` -- ``keras.models.load_model``.
    * ``pts`` / ``pt`` / ``tar`` / ``pth`` --
      ``kneron_utils.torch_utils.load_torch_model``. Requires the
      ``model_def_path`` and ``module_name`` keyword arguments; all
      ``kwargs`` are forwarded to the loader. The loaded model is switched
      to eval mode and moved to the device returned by ``_get_device()``.
    * ``params`` -- ``kneron_utils.mxnet_utils.get_model`` on ``mx.cpu()``.
      Requires ``layer``, ``batch_size``, ``channel_num`` and
      ``image_size``; ``label_shape`` is optional.

    Each backend library is imported inside its own branch, so only the
    library for the selected backend has to be importable.

    Args:
        model_path: Path to the model file.
        **kwargs: Backend-specific options, as listed above.

    Returns:
        A ``(model, inference_type)`` tuple, where ``inference_type`` is one of
        ``base.inference.ONNX``, ``KERAS``, ``TORCH``, ``MXNET`` or
        ``EMULATOR``. For ONNX, ``model`` is the list
        ``[session, input_names]``. For the emulator, ``model`` is ``None``.

    Raises:
        AssertionError: A keyword argument required by the selected backend
            is missing.
        TypeError: The file extension is not one of the supported ones.
    """
    if "USER_CONFIG" in kwargs:
        # No model file is loaded in this mode; the emulator backend is used.
        return None, base.inference.EMULATOR

    # "Model.ONNX" -> "onnx"; a path without an extension yields "".
    model_type = os.path.splitext(model_path.lower())[1][1:]

    if model_type == 'onnx':
        import onnxruntime

        sess = onnxruntime.InferenceSession(model_path)
        # Collect the names of all model inputs, not only the first one.
        input_names = [item.name for item in sess.get_inputs()]
        return [sess, input_names], base.inference.ONNX

    if model_type in ('h5', 'hdf5'):
        import keras

        return keras.models.load_model(model_path), base.inference.KERAS

    if model_type in ('pts', 'pt', 'tar', 'pth'):
        from kneron_utils.torch_utils import load_torch_model, _get_device

        _require_kwargs(kwargs, ('model_def_path', 'module_name'))
        model = load_torch_model(model_path, **kwargs)
        device = _get_device()
        model.eval()
        model = model.to(device)
        return model, base.inference.TORCH

    if model_type == 'params':
        import mxnet as mx

        from kneron_utils.mxnet_utils import get_model

        _require_kwargs(kwargs, ('layer', 'batch_size', 'channel_num', 'image_size'))
        label_shape = kwargs.get('label_shape')
        if isinstance(label_shape, (str, bytes)):
            # NOTE(review): eval() runs arbitrary code. If label_shape can come
            # from untrusted input, prefer ast.literal_eval here.
            label_shape = eval(label_shape)
        model = get_model(
            mx.cpu(),
            kwargs['image_size'],
            kwargs['channel_num'],
            model_path,
            kwargs['layer'],
            kwargs['batch_size'],
            label_shape,
        )
        return model, base.inference.MXNET

    raise TypeError("unknown model type")


def _require_kwargs(kwargs, required):
    """Raise AssertionError naming every key in ``required`` that is absent from ``kwargs``."""
    missing = [key for key in required if key not in kwargs]
    if missing:
        # Raised explicitly (not via ``assert``) so the check survives ``python -O``;
        # AssertionError is kept for compatibility with callers of the old asserts.
        raise AssertionError(
            "missing required keyword argument(s): " + ", ".join(missing)
        )