# 2026-01-28 06:16:04 +00:00
#
# 150 lines
# 6.4 KiB
# Python

from .common import *
from .extract_onnx_info import extract_input_from_onnx
import json
import logging
import os
class LoadConfigException(Exception):
    """Raise this exception if the config is not as expected.

    Args:
        file_path: path of the config file being loaded.
        field: name of the required field that was not found (or the key that
            raised ``KeyError`` while reading the config).
        addtional_info: optional extra text appended on a new line. The
            misspelled name is kept because callers pass it by keyword.
    """

    def __init__(self, file_path: str, field: str, addtional_info: str = None):
        self.file_path = file_path
        self.field = field
        self.addtional_info = addtional_info
        self.message = 'Error while loading {}: {} is required but not found'.format(file_path, field)
        if addtional_info is not None:
            self.message += "\n"
            self.message += addtional_info
        # Hand the formatted message to Exception so ``args`` and ``repr``
        # carry it instead of the raw constructor arguments.
        super().__init__(self.message)

    def __reduce__(self):
        # ``args`` now holds only the formatted message, which does not match
        # this constructor's signature; rebuild from the original arguments so
        # the exception round-trips through pickle (e.g. multiprocessing).
        return (type(self), (self.file_path, self.field, self.addtional_info))

    def __str__(self):
        return self.message
class ModelConfig:
    """Model-level configuration loaded from a JSON config file.

    The file must contain a ``model_info`` section and a ``preprocess``
    section:

    * ``model_info.input_onnx_file`` - ONNX model path (ignored when
      ``model_file`` is passed to the constructor).
    * ``model_info.model_inputs`` - list of ``{"model_input_name",
      "input_image_folder"}`` entries; every ONNX input needs one.
    * ``model_info.quantize_mode`` (default ``"default"``) and
      ``model_info.outlier`` (default ``0.999``) are optional.
    * ``preprocess`` must hold ``img_preprocess_method``, ``img_channel`` and
      ``radix``; ``keep_aspect_ratio``, ``pad_mode``, ``rotate`` and
      ``p_crop`` are filled with defaults when absent.

    An optional top-level ``simulator_img_files`` list (``{"model_input_name",
    "input_image"}`` entries) selects the simulator image per input.
    """

    # Class-level defaults are kept for code that reads them off the class.
    # ``__init__`` always assigns per-instance values, so these shared mutable
    # objects are never modified.
    model_file: str = None
    input_shapes = dict()
    input_folders = dict()
    simulator_inputs = dict()
    preprocess_config = dict()
    outlier = 0.999
    quantize_mode = "default"

    def __init__(self, filepath, model_file=None):
        """Load and validate the config stored at ``filepath``.

        Args:
            filepath: path to the JSON config file.
            model_file: optional ONNX model path overriding
                ``model_info.input_onnx_file``.

        Raises:
            OSError: if the config file cannot be opened (also logged).
            LoadConfigException: if a required field is missing, a model
                input has no image folder, or a fallback image folder is empty.
        """
        try:
            f = open(filepath, 'rb')
        except OSError:
            logging.error("Cannot load config: {}".format(filepath))
            raise
        # Close the handle as soon as the JSON has been parsed.
        with f:
            self.config = json.load(f)

        try:
            self._load_model_info(filepath, model_file)
            self._load_simulator_inputs(filepath)
            self._load_preprocess(filepath)
        except KeyError as e:
            # Any missing mandatory key is reported by its name.
            raise LoadConfigException(filepath, e.args[0]) from e

    def _load_model_info(self, filepath, model_file):
        """Resolve the model file, its inputs, image folders and quantize options."""
        if model_file is None:
            self.model_file = self.config["model_info"]["input_onnx_file"]
        else:
            self.model_file = model_file
        # Iterating the result yields the model input names (presumably a dict
        # keyed by input name - see extract_input_from_onnx).
        self.input_shapes = extract_input_from_onnx(self.model_file)

        model_info = self.config["model_info"]
        self.input_folders = {
            entry["model_input_name"]: entry["input_image_folder"]
            for entry in model_info["model_inputs"]
        }
        for input_name in self.input_shapes:
            if input_name not in self.input_folders:
                raise LoadConfigException(filepath, "model_info.model_inputs",
                    addtional_info="Cannot find the corresponding input image folder of model input `{}`".format(
                        input_name))

        self.quantize_mode = model_info.get("quantize_mode", "default")
        self.outlier = model_info.get("outlier", 0.999)

    def _load_simulator_inputs(self, filepath):
        """Choose one simulator image per model input.

        Explicit ``simulator_img_files`` entries win; any input without one
        falls back to the first file (by sorted name, so the choice is
        deterministic) in its ``input_image_folder``.
        """
        # Always a fresh per-instance dict so defaults never leak between
        # instances through the class attribute.
        self.simulator_inputs = {
            entry["model_input_name"]: entry["input_image"]
            for entry in self.config.get("simulator_img_files", [])
        }
        for input_name in self.input_shapes:
            if input_name in self.simulator_inputs:
                continue
            folder = self.input_folders[input_name]
            candidates = sorted(os.listdir(folder))
            if not candidates:
                raise LoadConfigException(filepath, "model_info.model_inputs",
                    addtional_info="Input image folder `{}` of model input `{}` is empty".format(
                        folder, input_name))
            self.simulator_inputs[input_name] = os.path.join(folder, candidates[0])
            logging.debug("Using {} as the simulator input for {}".format(
                self.simulator_inputs[input_name], input_name))

    def _load_preprocess(self, filepath):
        """Validate the ``preprocess`` section, fill defaults and derive ``enable_crop``.

        The section dict from ``self.config`` is updated in place.
        """
        self.preprocess_config = self.config["preprocess"]
        for required in ("img_preprocess_method", "img_channel", "radix"):
            if required not in self.preprocess_config:
                raise LoadConfigException(filepath, required)
        self.preprocess_config.setdefault("keep_aspect_ratio", True)
        self.preprocess_config.setdefault("pad_mode", 1)
        self.preprocess_config.setdefault("rotate", 0)
        if "p_crop" not in self.preprocess_config:
            self.preprocess_config["p_crop"] = {
                "crop_x": 0,
                "crop_y": 0,
                "crop_w": 0,
                "crop_h": 0,
            }
        crop = self.preprocess_config["p_crop"]
        # Cropping is enabled when any crop coordinate/size is non-zero.
        self.enable_crop = (crop["crop_x"] + crop["crop_y"] +
                            crop["crop_h"] + crop["crop_w"] > 0)
class BatchModelConfig:
    """A single entry of the ``models`` list in a batch config.

    Stores the entry's ``id``, ``version`` and ``path``, plus the optional
    ``radix_json`` path as ``radix``. For a model whose path ends in ``onnx``
    and that has no radix JSON, ``model_config`` is built from the config
    file named by the entry's ``input_params``.
    """

    id : int = None
    version : str = None
    path : str = None
    radix : str = None
    model_config : ModelConfig = None

    def __init__(self, raw_config):
        """Populate the fields from one raw ``models`` entry.

        Raises:
            KeyError: when ``id``, ``version`` or ``path`` is missing, or when
                ``input_params`` is missing for an ONNX model that needs it.
        """
        for attr_name in ("id", "version", "path"):
            setattr(self, attr_name, raw_config[attr_name])

        # Only set an instance-level radix when the entry provides one.
        try:
            self.radix = raw_config["radix_json"]
        except KeyError:
            pass

        is_onnx = self.path.endswith("onnx")
        if is_onnx and self.radix is None:
            self.model_config = ModelConfig(raw_config["input_params"], self.path)
class BatchConfig:
    """Config for batch analysis and compile.

    Loaded from a JSON file with a required ``models`` list (one
    ``BatchModelConfig`` per entry) and optional ``model_section_addr``,
    ``dedicated_output_buffer``, ``encryption`` and ``weight_compress`` fields.
    """

    # Class-level defaults are kept for code that reads them off the class.
    # The mutable ones are never modified: each instance gets its own
    # ``encryption`` dict and ``model_list`` in ``__init__``.
    model_section_addr: str = None
    encryption = dict()
    dedicated_buffer = True
    weight_compress = False
    model_list : List[BatchModelConfig] = []

    def __init__(self, filepath):
        """Load the batch config stored at ``filepath``.

        Raises:
            OSError: if the config file cannot be opened (also logged).
            LoadConfigException: if a required field (``models`` or a field
                required by a model entry) is missing.
        """
        try:
            f = open(filepath, 'rb')
        except OSError:
            logging.error("Cannot load config: {}".format(filepath))
            raise
        # Close the handle as soon as the JSON has been parsed.
        with f:
            self.config = json.load(f)

        try:
            if "model_section_addr" in self.config:
                self.model_section_addr = self.config["model_section_addr"]
            if "dedicated_output_buffer" in self.config:
                self.dedicated_buffer = self.config["dedicated_output_buffer"]
            if "encryption" in self.config:
                self.encryption = self.config["encryption"]
            else:
                # Per-instance default; the shared class-level dict is untouched.
                self.encryption = {"whether_encryption": False}
            # Read independently of ``encryption`` so it is honoured whenever set.
            if "weight_compress" in self.config:
                self.weight_compress = self.config["weight_compress"]
            # Fresh list per instance: models must not accumulate across configs.
            self.model_list = [BatchModelConfig(raw_config)
                               for raw_config in self.config["models"]]
        except KeyError as e:
            raise LoadConfigException(filepath, e.args[0]) from e