from .common import *
from .extract_onnx_info import extract_input_from_onnx
import json
import logging
import os


def _load_json_config(filepath):
    """Open ``filepath`` and parse its content as JSON.

    The file handle is always closed before returning.

    Raises:
        OSError: if the file cannot be opened (an error is logged first).
        json.JSONDecodeError: if the content is not valid JSON.
    """
    try:
        f = open(filepath, 'rb')
    except OSError:
        logging.error("Cannot load config: {}".format(filepath))
        raise
    with f:
        return json.load(f)


class LoadConfigException(Exception):
    """Raised when a config file is missing a required field or is otherwise not as expected."""

    def __init__(self, file_path : str, field : str, addtional_info : str = None):
        """Build the error message.

        Args:
            file_path: path of the config file being loaded.
            field: name of the missing / invalid field.
            addtional_info: optional extra text appended on a new line.
                (The keyword keeps its historical spelling so existing
                callers keep working.)
        """
        # NOTE: Exception.__init__ is intentionally not re-invoked with the
        # message: ``args`` stays (file_path, field[, ...]) so the exception
        # can still be pickled/re-constructed from its args.
        self.message = 'Error while loading {}: {} is required but not found'.format(file_path, field)
        if addtional_info is not None:
            self.message += "\n"
            self.message += addtional_info

    def __str__(self):
        return self.message


class ModelConfig:
    """Model analysis config loaded from a JSON file.

    Attributes set by ``__init__``:
        config: the raw parsed JSON.
        model_file: ONNX model path (argument, or ``model_info.input_onnx_file``).
        input_shapes: result of ``extract_input_from_onnx(model_file)``;
            iterated for the model input names.
        input_folders: model input name -> image folder (``model_info.model_inputs``).
        quantize_mode: ``model_info.quantize_mode``, default ``"default"``.
        outlier: ``model_info.outlier``, default ``0.999``.
        simulator_inputs: model input name -> image used as simulator input.
        preprocess_config: the ``preprocess`` section, with defaults filled in.
        enable_crop: True when the ``p_crop`` fields sum to a positive value.
    """
    # Class-level values only document the fields and their defaults.
    # ``__init__`` always rebinds per-instance objects, so instances never
    # share (or mutate) these dicts.
    model_file : str = None
    input_shapes = dict()
    input_folders = dict()
    simulator_inputs = dict()
    preprocess_config = dict()
    outlier = 0.999
    quantize_mode = "default"

    def __init__(self, filepath, model_file=None):
        """Load and validate the config at ``filepath``.

        Args:
            filepath: path of the JSON config file.
            model_file: optional model path overriding
                ``model_info.input_onnx_file`` from the config.

        Raises:
            OSError: if the config file (or an input image folder) cannot be read.
            LoadConfigException: if a required field is missing.
        """
        self.config = _load_json_config(filepath)

        try:
            model_info = self.config["model_info"]

            # Model file: the explicit argument wins over the config entry.
            if model_file is None:
                self.model_file = model_info["input_onnx_file"]
            else:
                self.model_file = model_file

            # Model inputs and their calibration image folders.
            self.input_shapes = extract_input_from_onnx(self.model_file)
            self.input_folders = {i["model_input_name"]: i["input_image_folder"]
                                  for i in model_info["model_inputs"]}
            for input_name in self.input_shapes:
                if input_name not in self.input_folders:
                    raise LoadConfigException(
                        filepath, "model_info.model_inputs",
                        addtional_info="Cannot find the corresponding input image folder of model input `{}`".format(
                            input_name))

            # Optional quantization settings.
            self.quantize_mode = model_info.get("quantize_mode", "default")
            self.outlier = model_info.get("outlier", 0.999)

            # Optional simulator input images.  Start from a fresh dict so a
            # config without ``simulator_img_files`` never writes into state
            # shared with other ModelConfig instances.
            self.simulator_inputs = {}
            if "simulator_img_files" in self.config:
                self.simulator_inputs = {i["model_input_name"]: i["input_image"]
                                         for i in self.config["simulator_img_files"]}
            for input_name in self.input_shapes:
                if input_name not in self.simulator_inputs:
                    self.simulator_inputs[input_name] = self._default_simulator_input(filepath, input_name)
                    logging.debug("Using {} as the simulator input for {}".format(
                        self.simulator_inputs[input_name], input_name))

            # Preprocess section: required keys first, then defaults.
            self.preprocess_config = self.config["preprocess"]
            for required in ("img_preprocess_method", "img_channel", "radix"):
                if required not in self.preprocess_config:
                    raise LoadConfigException(filepath, required)
            self.preprocess_config.setdefault("keep_aspect_ratio", True)
            self.preprocess_config.setdefault("pad_mode", 1)
            self.preprocess_config.setdefault("rotate", 0)
            self.preprocess_config.setdefault(
                "p_crop", {"crop_x": 0, "crop_y": 0, "crop_w": 0, "crop_h": 0})
            p_crop = self.preprocess_config["p_crop"]
            self.enable_crop = (p_crop["crop_x"] + p_crop["crop_y"]
                                + p_crop["crop_h"] + p_crop["crop_w"] > 0)
        except KeyError as e:
            # Any missing key in the JSON is reported by its name.
            raise LoadConfigException(filepath, e.args[0]) from e

    def _default_simulator_input(self, filepath, input_name):
        """Pick the first image (sorted by name) of ``input_name``'s input folder.

        Sorting makes the choice reproducible; ``os.listdir`` order is arbitrary.

        Raises:
            LoadConfigException: if the folder is empty.
        """
        folder = self.input_folders[input_name]
        candidates = sorted(os.listdir(folder))
        if not candidates:
            raise LoadConfigException(
                filepath, "simulator_img_files",
                addtional_info="No simulator input given for model input `{}` and input image folder `{}` is empty".format(
                    input_name, folder))
        return os.path.join(folder, candidates[0])


class BatchModelConfig:
    """One entry of the ``models`` list of a batch config.

    ``model_config`` is loaded only for ONNX models that have no
    ``radix_json``; otherwise it stays ``None``.
    """
    id : int = None
    version : str = None
    path : str = None
    radix : str = None
    model_config : ModelConfig = None

    def __init__(self, raw_config):
        """Read one model entry.

        Args:
            raw_config: dict with ``id``, ``version``, ``path`` and optionally
                ``radix_json`` / ``input_params`` (the latter is a ModelConfig
                JSON path, required for ONNX models without ``radix_json``).

        Raises:
            KeyError: if a required key is missing (BatchConfig converts it).
        """
        self.id = raw_config["id"]
        self.version = raw_config["version"]
        self.path = raw_config["path"]
        self.radix = raw_config.get("radix_json")
        if self.path.endswith('onnx') and self.radix is None:
            self.model_config = ModelConfig(raw_config["input_params"], self.path)


class BatchConfig:
    """Config for batch analysis and compile."""
    # Class-level values only document the defaults. ``__init__`` gives each
    # instance its own ``model_list`` and ``encryption`` objects.
    model_section_addr: str = None
    encryption = dict()
    dedicated_buffer = True
    weight_compress = False
    model_list : List[BatchModelConfig] = []

    def __init__(self, filepath):
        """Load the batch config at ``filepath``.

        Raises:
            OSError: if the file cannot be opened.
            LoadConfigException: if a required field is missing.
        """
        self.config = _load_json_config(filepath)

        try:
            if "model_section_addr" in self.config:
                self.model_section_addr = self.config["model_section_addr"]
            if "dedicated_output_buffer" in self.config:
                self.dedicated_buffer = self.config["dedicated_output_buffer"]

            # Encryption defaults to disabled when the section is absent.
            # Use a fresh dict rather than mutating shared class state.
            if "encryption" in self.config:
                self.encryption = self.config["encryption"]
            else:
                self.encryption = {"whether_encryption": False}

            if "weight_compress" in self.config:
                self.weight_compress = self.config["weight_compress"]

            # Per-instance list: appending to the class-level list made every
            # BatchConfig accumulate the models of all earlier instances.
            self.model_list = [BatchModelConfig(raw_config) for raw_config in self.config["models"]]
        except KeyError as e:
            raise LoadConfigException(filepath, e.args[0]) from e