150 lines
6.4 KiB
Python
150 lines
6.4 KiB
Python
from .common import *
|
|
from .extract_onnx_info import extract_input_from_onnx
|
|
import json
|
|
import logging
|
|
import os
|
|
|
|
class LoadConfigException(Exception):
    """Raised if the config is not as expected (a required field is missing).

    Args:
        file_path: Path of the config file being loaded.
        field: Name of the required field that was not found.
        addtional_info: Optional extra detail, appended to the message on a
            new line. The (misspelled) name is kept because callers pass it
            by keyword.
    """

    def __init__(self, file_path : str, field : str, addtional_info : str = None):
        # Forward the constructor arguments to Exception so ``args`` is
        # populated and the exception can be pickled/unpickled (e.g. across
        # process boundaries); unpickling re-invokes ``cls(*args)``.
        super().__init__(file_path, field, addtional_info)
        self.message = 'Error while loading {}: {} is required but not found'.format(file_path, field)
        if addtional_info is not None:
            self.message += "\n" + addtional_info

    def __str__(self):
        return self.message
|
|
|
|
class ModelConfig:
    """Single-model configuration parsed from a JSON file.

    The JSON must contain a ``model_info`` section (with ``model_inputs`` and,
    unless ``model_file`` is passed explicitly, ``input_onnx_file``) and a
    ``preprocess`` section. Optional fields are filled with defaults.
    """

    # Class-level defaults. ``__init__`` assigns a per-instance value to every
    # one of these, so instances never share the mutable containers.
    model_file : str = None
    input_shapes = dict()
    input_folders = dict()
    simulator_inputs = dict()
    preprocess_config = dict()
    outlier = 0.999
    quantize_mode = "default"

    def __init__(self, filepath, model_file=None):
        """Load and validate the config at *filepath*.

        Args:
            filepath: Path to the JSON config file.
            model_file: Optional ONNX model path; when given it is used instead
                of ``model_info.input_onnx_file``.

        Raises:
            OSError: If the config file cannot be opened/read (also logged).
            LoadConfigException: If a required field is missing, a model input
                has no image folder, or an image folder is empty.
        """
        # Load json; ``with`` guarantees the file handle is closed.
        try:
            with open(filepath, 'rb') as f:
                self.config = json.load(f)
        except OSError:
            logging.error("Cannot load config: {}".format(filepath))
            raise

        try:
            self._parse_model_info(filepath, model_file)
            self._parse_simulator_inputs(filepath)
            self._parse_preprocess(filepath)
        except KeyError as e:
            # Any missing dict key is reported as a missing config field.
            raise LoadConfigException(filepath, e.args[0]) from e

    def _parse_model_info(self, filepath, model_file):
        """Resolve the model file, its inputs, image folders and quantize options."""
        model_info = self.config["model_info"]

        # An explicit model_file overrides the one named in the config.
        if model_file is None:
            self.model_file = model_info["input_onnx_file"]
        else:
            self.model_file = model_file

        # Model input shapes come from the ONNX model itself.
        self.input_shapes = extract_input_from_onnx(self.model_file)

        # Every model input needs an input image folder.
        self.input_folders = {i["model_input_name"]: i["input_image_folder"]
                              for i in model_info["model_inputs"]}
        for input_name in self.input_shapes:
            if input_name not in self.input_folders:
                raise LoadConfigException(
                    filepath, "model_info.model_inputs",
                    addtional_info="Cannot find the corresponding input image folder of model input `{}`".format(
                        input_name))

        self.quantize_mode = model_info.get("quantize_mode", "default")
        self.outlier = model_info.get("outlier", 0.999)

    def _parse_simulator_inputs(self, filepath):
        """Choose a simulator input image for every model input.

        Entries from the optional ``simulator_img_files`` list win. Any model
        input without one falls back to the first file, by sorted name, in its
        input image folder. Sorting makes that choice deterministic, since
        ``os.listdir`` order is arbitrary.
        """
        # Always build a fresh dict; never mutate the class-level default.
        self.simulator_inputs = {i["model_input_name"]: i["input_image"]
                                 for i in self.config.get("simulator_img_files", [])}
        for input_name in self.input_shapes:
            if input_name in self.simulator_inputs:
                continue
            folder = self.input_folders[input_name]
            candidates = sorted(os.listdir(folder))
            if not candidates:
                raise LoadConfigException(
                    filepath, "input_image_folder",
                    addtional_info="Input image folder `{}` of model input `{}` is empty".format(
                        folder, input_name))
            self.simulator_inputs[input_name] = os.path.join(folder, candidates[0])
            logging.debug("Using {} as the simulator input for {}".format(
                self.simulator_inputs[input_name], input_name))

    def _parse_preprocess(self, filepath):
        """Validate required preprocess fields, fill defaults, derive ``enable_crop``.

        Note: defaults are written into ``self.config["preprocess"]`` in place
        (``preprocess_config`` is the same dict object).
        """
        self.preprocess_config = self.config["preprocess"]
        for field in ("img_preprocess_method", "img_channel", "radix"):
            if field not in self.preprocess_config:
                raise LoadConfigException(filepath, field)

        self.preprocess_config.setdefault("keep_aspect_ratio", True)
        self.preprocess_config.setdefault("pad_mode", 1)
        self.preprocess_config.setdefault("rotate", 0)
        crop = self.preprocess_config.setdefault(
            "p_crop", {"crop_x": 0, "crop_y": 0, "crop_w": 0, "crop_h": 0})

        # Cropping is enabled when any crop parameter is non-zero (summed as
        # in the config's numeric fields).
        self.enable_crop = (crop["crop_x"] + crop["crop_y"] +
                            crop["crop_h"] + crop["crop_w"] > 0)
|
|
|
|
class BatchModelConfig:
    """One entry of the ``models`` list of a batch config.

    ``id``, ``version`` and ``path`` are copied from the raw entry, and
    ``radix`` from its optional ``radix_json`` key. A full ``ModelConfig`` is
    built from ``input_params`` only when the path ends in ``onnx`` and no
    radix JSON was supplied.
    """

    id : int = None
    version : str = None
    path : str = None
    radix : str = None
    model_config : ModelConfig = None

    def __init__(self, raw_config):
        # Required identification fields (KeyError if missing).
        self.id = raw_config["id"]
        self.version = raw_config["version"]
        self.path = raw_config["path"]

        # Optional precomputed radix file.
        if "radix_json" in raw_config:
            self.radix = raw_config["radix_json"]

        # Nothing more to load for non-ONNX models or when a radix is given.
        if not self.path.endswith('onnx') or self.radix is not None:
            return
        self.model_config = ModelConfig(raw_config["input_params"], self.path)
|
|
|
|
class BatchConfig:
    """Config for batch analysis and compile.

    Loaded from a JSON file with a required ``models`` list and optional
    top-level ``model_section_addr``, ``dedicated_output_buffer``,
    ``encryption`` and ``weight_compress`` fields.
    """

    # Class-level defaults. ``__init__`` gives every instance its own
    # ``encryption`` dict and ``model_list`` so instances never share state.
    model_section_addr: str = None
    encryption = dict()
    dedicated_buffer = True
    weight_compress = False
    model_list : List[BatchModelConfig] = []

    def __init__(self, filepath):
        """Load the batch config at *filepath*.

        Raises:
            OSError: If the config file cannot be opened/read (also logged).
            LoadConfigException: If a required field is missing.
        """
        # Load json; ``with`` guarantees the file handle is closed.
        try:
            with open(filepath, 'rb') as f:
                self.config = json.load(f)
        except OSError:
            logging.error("Cannot load config: {}".format(filepath))
            raise

        try:
            if "model_section_addr" in self.config:
                self.model_section_addr = self.config["model_section_addr"]
            if "dedicated_output_buffer" in self.config:
                self.dedicated_buffer = self.config["dedicated_output_buffer"]
            if "weight_compress" in self.config:
                self.weight_compress = self.config["weight_compress"]

            # The "encryption disabled" default belongs to the encryption
            # check. It previously followed the weight_compress check, which
            # overrode user encryption settings whenever weight_compress was
            # omitted and mutated the shared class-level dict.
            if "encryption" in self.config:
                self.encryption = self.config["encryption"]
            else:
                self.encryption = {"whether_encryption": False}

            # Build a per-instance list; appending to the class-level default
            # would accumulate models across every BatchConfig instance.
            self.model_list = [BatchModelConfig(raw_config)
                               for raw_config in self.config["models"]]
        except KeyError as e:
            raise LoadConfigException(filepath, e.args[0]) from e
|