kneron_model_converter/ktc/onnx_optimizer_1_7.py
# -*- coding: utf-8 -*-
# This file mainly contains helper functions for editing ONNX graphs and converting Keras/TFLite models to ONNX.
from typing import Dict, List
import sys
import subprocess
import os
import onnx
import logging
sys.path.insert(0, "/workspace/libs/ONNX_Convertor/optimizer_scripts")
from pytorch_exported_onnx_preprocess import torch_exported_onnx_flow
from onnx2onnx import onnx2onnx_flow
from tensorflow2onnx import tf2onnx_flow
from opset_8_to_9 import convert_opset_8_to_9
from opset_9_to_11 import onnx1_4to1_6, convert_opset_9_to_11
from opset_10_to_11 import convert_opset_10_to_11
from tools import modhelper, replacing, other
sys.path.insert(0, "/workspace/libs/ONNX_Convertor/keras-onnx")
import onnx_keras
sys.path.insert(0, "/workspace/libs/ONNX_Convertor/tflite-onnx/onnx_tflite")
import tflite2onnx


def preprocess(m):
    """Common preprocessing before a graph edit: polish the model, replace
    initializers with Constant nodes and topologically sort the graph."""
    m = onnx.utils.polish_model(m)
    g = m.graph
    replacing.replace_initializer_with_Constant(g)
    other.topological_sort(g)
    return m


def postprocess(m):
    """Common postprocessing after a graph edit: clear stale value_info,
    normalize Constant nodes, sort and polish the graph, then register the
    outputs in value_info."""
    g = m.graph
    while len(g.value_info) > 0:
        g.value_info.pop()
    passes = ['extract_constant_to_initializer']
    m = onnx.optimizer.optimize(m, passes)
    g = m.graph
    replacing.replace_initializer_with_Constant(g)
    other.topological_sort(g)
    # Polish and output
    m = onnx.utils.polish_model(m)
    other.add_output_to_value_info(m.graph)
    return m


def delete_nodes(model: onnx.ModelProto, node_names: List[str]) -> onnx.ModelProto:
    """Delete nodes with the given names.

    Args:
        model (onnx.ModelProto): the input onnx model.
        node_names (List[str]): a list of node names.

    Returns:
        onnx.ModelProto: the result onnx model.
    """
    # Preprocess
    m = preprocess(model)
    # Function
    modhelper.delete_nodes(m.graph, node_names)
    # Post process
    m = postprocess(m)
    # Done
    return m
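

# Illustrative usage of delete_nodes (a sketch, not part of the original file;
# the model path and the node name "relu_1" are hypothetical placeholders):
#
#   model = onnx.load("model.onnx")
#   model = delete_nodes(model, ["relu_1"])
#   onnx.save(model, "model_edited.onnx")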


def delete_inputs(model: onnx.ModelProto, value_names: List[str]) -> onnx.ModelProto:
    """Delete specific inputs.

    Args:
        model (onnx.ModelProto): input onnx model.
        value_names (List[str]): inputs to delete.

    Returns:
        onnx.ModelProto: result onnx model.
    """
    # Preprocess
    m = preprocess(model)
    # Function
    modhelper.delete_input(m.graph, value_names)
    # Post process
    m = postprocess(m)
    # Done
    return m


def delete_outputs(model: onnx.ModelProto, value_names: List[str]) -> onnx.ModelProto:
    """Delete specific outputs.

    Args:
        model (onnx.ModelProto): input onnx model.
        value_names (List[str]): outputs to delete.

    Returns:
        onnx.ModelProto: result onnx model.
    """
    # Preprocess
    m = preprocess(model)
    # Function
    modhelper.delete_output(m.graph, value_names)
    # Post process
    m = postprocess(m)
    # Done
    return m
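

# Illustrative usage of delete_inputs / delete_outputs (a sketch; the value names
# "aux_input" and "aux_output" are hypothetical placeholders):
#
#   model = delete_inputs(model, ["aux_input"])
#   model = delete_outputs(model, ["aux_output"])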


def cut_graph_from_nodes(model: onnx.ModelProto, node_names: List[str]) -> onnx.ModelProto:
    """Cut the graph from the given nodes. Unlike delete_nodes, this function also
    deletes all the nodes that follow the given ones.

    Args:
        model (onnx.ModelProto): the input onnx model.
        node_names (List[str]): names of nodes to cut from.

    Returns:
        onnx.ModelProto: the result onnx model.
    """
    # Preprocess
    m = preprocess(model)
    # Function
    other.remove_nodes(m.graph, cut_nodes=node_names)
    # Post process
    m = postprocess(m)
    # Done
    return m


def remove_nodes_with_types(model: onnx.ModelProto, type_names: List[str]) -> onnx.ModelProto:
    """Cut the graph from the nodes with specific operation types. Similar behaviour
    to cut_graph_from_nodes.

    Args:
        model (onnx.ModelProto): the input onnx model.
        type_names (List[str]): operator types to cut from.

    Returns:
        onnx.ModelProto: the result onnx model.
    """
    # Preprocess
    m = preprocess(model)
    # Function
    other.remove_nodes(m.graph, cut_types=type_names)
    # Post process
    m = postprocess(m)
    # Done
    return m
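

# Illustrative usage (a sketch; "postprocess_reshape" is a hypothetical node name):
#
#   # Drop a named node and everything after it:
#   model = cut_graph_from_nodes(model, ["postprocess_reshape"])
#   # Or drop every Softmax node and everything after it:
#   model = remove_nodes_with_types(model, ["Softmax"])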


def change_input_output_shapes(model: onnx.ModelProto, input_shape_mapping: Dict=None, output_shape_mapping: Dict=None) -> onnx.ModelProto:
    """Change input shapes and output shapes.

    Args:
        model (onnx.ModelProto): input onnx model.
        input_shape_mapping (Dict, optional): mapping from input names to the shapes to change. Defaults to None.
        output_shape_mapping (Dict, optional): mapping from output names to the shapes to change. Defaults to None.

    Returns:
        onnx.ModelProto: result onnx model.
    """
    # Preprocess
    m = preprocess(model)
    # Function
    if input_shape_mapping is not None:
        param = []
        for input_name in input_shape_mapping:
            temp_str = input_name
            for x in input_shape_mapping[input_name]:
                temp_str += ' ' + str(x)
            param.append(temp_str)
        other.change_input_shape(m.graph, param)
    if output_shape_mapping is not None:
        param = []
        for output_name in output_shape_mapping:
            temp_str = output_name
            for x in output_shape_mapping[output_name]:
                temp_str += ' ' + str(x)
            param.append(temp_str)
        other.change_output_shape(m.graph, param)
    # Post process
    m = postprocess(m)
    # Done
    return m
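

# Illustrative usage of change_input_output_shapes (a sketch; the value names
# "data" and "prob" and the shapes are hypothetical). Each mapping goes from a
# value name to the desired shape, given as a list of dimension values:
#
#   model = change_input_output_shapes(
#       model,
#       input_shape_mapping={"data": [1, 3, 224, 224]},
#       output_shape_mapping={"prob": [1, 1000]},
#   )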


def add_conv_after(model: onnx.ModelProto, value_names: List[str]) -> onnx.ModelProto:
    """Add a do-nothing Conv node after the specified values.

    Args:
        model (onnx.ModelProto): input onnx model.
        value_names (List[str]): values after which we add Conv.

    Returns:
        onnx.ModelProto: result onnx model.
    """
    # Preprocess
    m = preprocess(model)
    # Function
    other.add_nop_conv_after(m.graph, value_names)
    other.topological_sort(m.graph)
    # Post process
    m = postprocess(m)
    # Done
    return m


def add_bn_after(model: onnx.ModelProto, value_names: List[str]) -> onnx.ModelProto:
    """Add a do-nothing BN node after the specified values.

    Args:
        model (onnx.ModelProto): input onnx model.
        value_names (List[str]): values after which we add BN.

    Returns:
        onnx.ModelProto: result onnx model.
    """
    # Preprocess
    m = preprocess(model)
    # Function
    other.add_nop_bn_after(m.graph, value_names)
    other.topological_sort(m.graph)
    # Post process
    m = postprocess(m)
    # Done
    return m
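

# Illustrative usage (a sketch; "conv5_out" is a hypothetical value name). The
# inserted Conv/BN nodes are do-nothing (identity) operations:
#
#   model = add_conv_after(model, ["conv5_out"])
#   model = add_bn_after(model, ["conv5_out"])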


def rename_output(model: onnx.ModelProto, old_name: str, new_name: str) -> onnx.ModelProto:
    """Rename the specified output.

    Args:
        model (onnx.ModelProto): input onnx model.
        old_name (str): old output name.
        new_name (str): new output name.

    Returns:
        onnx.ModelProto: result onnx model.
    """
    # Preprocess
    m = preprocess(model)
    # Function
    other.rename_output_name(m.graph, old_name, new_name)
    # Post process
    m = postprocess(m)
    # Done
    return m
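

# Illustrative usage of rename_output (a sketch; the output names are hypothetical):
#
#   model = rename_output(model, "output_0", "logits")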


def pixel_modify(model: onnx.ModelProto, scale: List[float] = None, bias: List[float] = None) -> onnx.ModelProto:
    """Add a special BN node to adjust the input range. Currently only single-input
    models are supported.

    Args:
        model (onnx.ModelProto): input onnx model.
        scale (List[float], optional): the scale of the BN node. Defaults to all ones.
        bias (List[float], optional): the bias of the BN node. Defaults to all zeros.

    Returns:
        onnx.ModelProto: result onnx model.
    """
    # Check dimension size
    g = model.graph
    if len(g.input) > 1:
        raise ValueError("`pixel_modify` currently only supports single-input models.")
    i_n = g.input[0]
    channel = i_n.type.tensor_type.shape.dim[1].dim_value
    if scale is None:
        scale = [1] * channel
    if bias is None:
        bias = [0] * channel
    if channel != len(bias) or channel != len(scale):
        raise ValueError("scale (" + str(scale) + ") and bias (" + str(bias) +
                         ") should match the input channel dimension: " + str(channel))
    # Preprocess
    m = preprocess(model)
    # Function
    g = m.graph
    i_n = g.input[0]
    other.add_bias_scale_bn_after(g, i_n.name, bias, scale)
    # Post process
    m = postprocess(m)
    # Done
    return m
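

# Illustrative usage of pixel_modify (a sketch; the scale/bias values below are
# hypothetical and only show a per-channel setup for a 3-channel input; the
# exact arithmetic applied comes from other.add_bias_scale_bn_after):
#
#   model = pixel_modify(model, scale=[1 / 255.0] * 3, bias=[0.0] * 3)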


def keras2onnx_flow(keras_model_path: str, optimize: int=0, input_shape: List=None) -> onnx.ModelProto:
    """Convert keras model to onnx object.

    Args:
        keras_model_path (str): the input hdf5/h5 model path.
        optimize (int, optional): optimization level. Defaults to 0.
        input_shape (List, optional): change the input shape if set. Only single input model is supported. Defaults to None.

    Returns:
        onnx.ModelProto: the converted onnx.
    """
    onnx_keras.set_duplicate_weights(True)
    converter = onnx_keras.frontend.KerasFrontend()
    converter.loadFromFile(keras_model_path)
    return converter.convertToOnnx(optimize, input_shape)
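

# Illustrative usage of keras2onnx_flow (a sketch; the .h5 path is hypothetical):
#
#   onnx_model = keras2onnx_flow("model.h5")
#   onnx.save(onnx_model, "model_from_keras.onnx")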


def tflite2onnx_flow(tflite_path: str, release_mode: bool=True, bottom_nodes: List=[]) -> onnx.ModelProto:
    """Convert tflite model to onnx object.

    Args:
        tflite_path (str): the input tflite model path.
        release_mode (bool, optional): whether to eliminate the transposes introduced for the channel-first conversion. Defaults to True.
        bottom_nodes (List, optional): names of nodes in the tflite model to use as the bottom (output) nodes of the sub-graph. Defaults to [].

    Returns:
        onnx.ModelProto: the converted onnx.
    """
    return tflite2onnx.main(tflite_path, '/tmp/tflite_converted.onnx', not release_mode, bottom_nodes)
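

# Illustrative usage of tflite2onnx_flow (a sketch; the .tflite path is hypothetical):
#
#   onnx_model = tflite2onnx_flow("model.tflite", release_mode=True)
#   onnx.save(onnx_model, "model_from_tflite.onnx")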