# -*- coding: utf-8 -*-
# Helpers for editing ONNX models (deleting nodes/inputs/outputs, cutting the
# graph, changing input/output shapes, inserting no-op Conv/BN nodes, renaming
# outputs, adding an input scale/bias BN) plus thin wrappers around the Keras
# and TFLite to ONNX converters.
from typing import Dict, List
import sys
import subprocess
import os
import onnx
import logging

# NOTE(review): several imports below are not used by the functions in this
# module; presumably they are kept so callers can import the converter flows
# from here. Confirm with callers before removing any of them.
sys.path.insert(0, "/workspace/libs/ONNX_Convertor/optimizer_scripts")
from pytorch_exported_onnx_preprocess import torch_exported_onnx_flow
from onnx2onnx import onnx2onnx_flow
from tensorflow2onnx import tf2onnx_flow
from opset_8_to_9 import convert_opset_8_to_9
from opset_9_to_11 import onnx1_4to1_6, convert_opset_9_to_11
from opset_10_to_11 import convert_opset_10_to_11
from tools import modhelper, replacing, other

sys.path.insert(0, "/workspace/libs/ONNX_Convertor/keras-onnx")
import onnx_keras

sys.path.insert(0, "/workspace/libs/ONNX_Convertor/tflite-onnx/onnx_tflite")
import tflite2onnx


def preprocess(m):
    """Prepare a model for graph editing.

    Polishes the model with ``onnx.utils.polish_model``, then calls
    ``replacing.replace_initializer_with_Constant`` and
    ``other.topological_sort`` on its graph.

    Args:
        m (onnx.ModelProto): the input onnx model.

    Returns:
        onnx.ModelProto: the model returned by ``polish_model`` with the graph
        transformations above applied. Graph edits must be made on this
        returned model, not on the argument.
    """
    m = onnx.utils.polish_model(m)
    g = m.graph
    replacing.replace_initializer_with_Constant(g)
    other.topological_sort(g)
    return m


def postprocess(m):
    """Clean up a model after graph editing.

    Clears all ``value_info`` entries, runs the onnx optimizer pass
    ``extract_constant_to_initializer``, calls
    ``replacing.replace_initializer_with_Constant`` and
    ``other.topological_sort`` on the result, polishes it, and finally calls
    ``other.add_output_to_value_info``.

    Args:
        m (onnx.ModelProto): the edited onnx model.

    Returns:
        onnx.ModelProto: the cleaned-up model.
    """
    g = m.graph
    # Drop all intermediate value_info so no shape info from before the edit
    # survives; presumably polish_model below re-infers it (TODO confirm).
    del g.value_info[:]
    passes = ['extract_constant_to_initializer']
    m = onnx.optimizer.optimize(m, passes)
    g = m.graph
    replacing.replace_initializer_with_Constant(g)
    other.topological_sort(g)
    # Polish and output
    m = onnx.utils.polish_model(m)
    other.add_output_to_value_info(m.graph)
    return m


def _apply_graph_edit(model: onnx.ModelProto, edit) -> onnx.ModelProto:
    """Run ``preprocess``, apply ``edit`` to the resulting graph, then ``postprocess``.

    Centralizing this guarantees that the edit is applied to the preprocessed
    model (the one that gets postprocessed), never to the caller's model.

    Args:
        model (onnx.ModelProto): the input onnx model.
        edit (Callable[[onnx.GraphProto], None]): in-place graph edit.

    Returns:
        onnx.ModelProto: the result onnx model.
    """
    m = preprocess(model)
    edit(m.graph)
    return postprocess(m)


def _shape_mapping_to_params(shape_mapping: Dict) -> List[str]:
    """Turn ``{name: [d0, d1, ...]}`` into ``["name d0 d1 ...", ...]``.

    This space-separated string form is what this module passes to
    ``other.change_input_shape`` / ``other.change_output_shape``.
    """
    return [" ".join([name, *(str(d) for d in dims)])
            for name, dims in shape_mapping.items()]


def delete_nodes(model: onnx.ModelProto, node_names: List[str]) -> onnx.ModelProto:
    """Delete nodes with the given names.

    Args:
        model (onnx.ModelProto): the input onnx model.
        node_names (List[str]): a list of node names.

    Returns:
        onnx.ModelProto: the result onnx model
    """
    return _apply_graph_edit(
        model, lambda g: modhelper.delete_nodes(g, node_names))


def delete_inputs(model: onnx.ModelProto, value_names: List[str]) -> onnx.ModelProto:
    """Delete specific inputs.

    Args:
        model (onnx.ModelProto): input onnx model.
        value_names (List[str]): inputs to delete.

    Returns:
        onnx.ModelProto: result onnx model.
    """
    return _apply_graph_edit(
        model, lambda g: modhelper.delete_input(g, value_names))


def delete_outputs(model: onnx.ModelProto, value_names: List[str]) -> onnx.ModelProto:
    """Delete specific outputs.

    Args:
        model (onnx.ModelProto): input onnx model.
        value_names (List[str]): outputs to delete.

    Returns:
        onnx.ModelProto: result onnx model.
    """
    return _apply_graph_edit(
        model, lambda g: modhelper.delete_output(g, value_names))


def cut_graph_from_nodes(model: onnx.ModelProto, node_names: List[str]) -> onnx.ModelProto:
    """Cut the graph from the given nodes.

    Unlike ``delete_nodes``, this function also deletes all the following
    nodes.

    Args:
        model (onnx.ModelProto): the input onnx model.
        node_names (List[str]): names of nodes to cut from.

    Returns:
        onnx.ModelProto: the result onnx model.
    """
    return _apply_graph_edit(
        model, lambda g: other.remove_nodes(g, cut_nodes=node_names))


def remove_nodes_with_types(model: onnx.ModelProto, type_names: List[str]) -> onnx.ModelProto:
    """Cut the graph from the nodes with specific operation types.

    Behaves like ``cut_graph_from_nodes``, but selects the nodes by op type.

    Args:
        model (onnx.ModelProto): the input onnx model.
        type_names (List[str]): operator types to cut from.

    Returns:
        onnx.ModelProto: the result onnx model.
    """
    return _apply_graph_edit(
        model, lambda g: other.remove_nodes(g, cut_types=type_names))


def change_input_output_shapes(model: onnx.ModelProto, input_shape_mapping: Dict=None, output_shape_mapping: Dict=None) -> onnx.ModelProto:
    """Change input shapes and output shapes.

    Args:
        model (onnx.ModelProto): input onnx model.
        input_shape_mapping (Dict, optional): mapping from input names to the
            shapes to change. Defaults to None.
        output_shape_mapping (Dict, optional): mapping from output names to
            the shapes to change. Defaults to None.

    Returns:
        onnx.ModelProto: result onnx model.
    """
    def _edit(g):
        if input_shape_mapping is not None:
            other.change_input_shape(
                g, _shape_mapping_to_params(input_shape_mapping))
        if output_shape_mapping is not None:
            other.change_output_shape(
                g, _shape_mapping_to_params(output_shape_mapping))

    return _apply_graph_edit(model, _edit)


def add_conv_after(model: onnx.ModelProto, value_names: List[str]) -> onnx.ModelProto:
    """Add a do-nothing Conv node after each specified value.

    Args:
        model (onnx.ModelProto): input onnx model.
        value_names (List[str]): values after which we add Conv.

    Returns:
        onnx.ModelProto: result onnx model.
    """
    def _edit(g):
        other.add_nop_conv_after(g, value_names)
        other.topological_sort(g)

    return _apply_graph_edit(model, _edit)


def add_bn_after(model: onnx.ModelProto, value_names: List[str]) -> onnx.ModelProto:
    """Add a do-nothing BN node after each specified value.

    Args:
        model (onnx.ModelProto): input onnx model.
        value_names (List[str]): values after which we add BN.

    Returns:
        onnx.ModelProto: result onnx model.
    """
    def _edit(g):
        other.add_nop_bn_after(g, value_names)
        other.topological_sort(g)

    return _apply_graph_edit(model, _edit)


def rename_output(model: onnx.ModelProto, old_name: str, new_name: str) -> onnx.ModelProto:
    """Rename the specified output.

    Args:
        model (onnx.ModelProto): input onnx model.
        old_name (str): old output name.
        new_name (str): new output name.

    Returns:
        onnx.ModelProto: result onnx model.
    """
    return _apply_graph_edit(
        model, lambda g: other.rename_output_name(g, old_name, new_name))


def pixel_modify(model: onnx.ModelProto, scale: List[float] = None, bias: List[float] = None) -> onnx.ModelProto:
    """Add a special BN node after the model input to adjust the input range.

    Only single-input models are supported currently. The caller's ``model``
    is not modified; the BN node is added to the preprocessed copy.

    Args:
        model (onnx.ModelProto): input onnx model.
        scale (List[float], optional): the scale of the BN node, one value per
            entry of input dimension 1. Defaults to all ones.
        bias (List[float], optional): the bias of the BN node, one value per
            entry of input dimension 1. Defaults to all zeros.

    Returns:
        onnx.ModelProto: result onnx model.

    Raises:
        ValueError: if the model does not have exactly one input, or if
            ``scale``/``bias`` length differs from input dimension 1.
    """
    g = model.graph
    # `!= 1` (not `> 1`) so a model without inputs gets a clear ValueError
    # instead of an IndexError on g.input[0].
    if len(g.input) != 1:
        raise ValueError("`pixel_modify` only support single input model currently.")
    i_n = g.input[0]
    input_name = i_n.name
    # Dimension 1 of the input is treated as the channel axis (presumably
    # NCHW layout — TODO confirm for all supported models).
    channels = i_n.type.tensor_type.shape.dim[1].dim_value
    if scale is None:
        scale = [1] * channels
    if bias is None:
        bias = [0] * channels
    if channels != len(bias) or channels != len(scale):
        raise ValueError("scale (" + str(scale) + ") and bias (" + str(bias) + ") should be same as input dimension:" + str(channels))

    # The edit is applied to the preprocessed graph. Previously it was applied
    # to `model.graph`, so the BN node was lost (and the caller's model mutated).
    return _apply_graph_edit(
        model,
        lambda graph: other.add_bias_scale_bn_after(graph, input_name, bias, scale))


def keras2onnx_flow(keras_model_path: str, optimize: int=0, input_shape: List=None) -> onnx.ModelProto:
    """Convert keras model to onnx object.

    Args:
        keras_model_path (str): the input hdf5/h5 model path.
        optimize (int, optional): optimization level. Defaults to 0.
        input_shape (List, optional): change the input shape if set. Only
            single input model is supported. Defaults to None.

    Returns:
        onnx.ModelProto: the converted onnx.
    """
    onnx_keras.set_duplicate_weights(True)
    converter = onnx_keras.frontend.KerasFrontend()
    converter.loadFromFile(keras_model_path)
    return converter.convertToOnnx(optimize, input_shape)


def tflite2onnx_flow(tflite_path: str, release_mode: bool=True, bottom_nodes: List=None) -> onnx.ModelProto:
    """Convert tflite model to onnx object.

    Args:
        tflite_path (str): the input tflite model path.
        release_mode (bool, optional): whether to eliminate the transposes for
            channel-first layout. Defaults to True.
        bottom_nodes (List, optional): names of nodes in the tflite model that
            are the bottom nodes of the sub-graph. Defaults to an empty list.

    Returns:
        onnx.ModelProto: the converted onnx.
    """
    # Fresh list per call instead of a shared mutable default argument.
    if bottom_nodes is None:
        bottom_nodes = []
    # The converter is given the fixed path '/tmp/tflite_converted.onnx' as its
    # output location; presumably it writes the model there (TODO confirm), so
    # concurrent calls may collide on that file.
    return tflite2onnx.main(tflite_path, '/tmp/tflite_converted.onnx', not release_mode, bottom_nodes)