297 lines
9.3 KiB
Python
297 lines
9.3 KiB
Python
# -*- coding: utf-8 -*-
|
|
|
|
# This file mainly contains the class and the functions of running knerex and compiler.
|
|
from typing import Dict, List
|
|
import sys
|
|
import subprocess
|
|
import os
|
|
import onnx
|
|
import logging
|
|
|
|
sys.path.insert(0, "/workspace/libs/ONNX_Convertor/optimizer_scripts")
|
|
from pytorch_exported_onnx_preprocess import torch_exported_onnx_flow
|
|
from onnx2onnx import onnx2onnx_flow
|
|
from tensorflow2onnx import tf2onnx_flow
|
|
from opset_8_to_9 import convert_opset_8_to_9
|
|
from opset_9_to_11 import onnx1_4to1_6, convert_opset_9_to_11
|
|
from opset_10_to_11 import convert_opset_10_to_11
|
|
from tools import modhelper, replacing, other
|
|
sys.path.insert(0, "/workspace/libs/ONNX_Convertor/keras-onnx")
|
|
import onnx_keras
|
|
|
|
sys.path.insert(0, "/workspace/libs/ONNX_Convertor/tflite-onnx/onnx_tflite")
|
|
import tflite2onnx
|
|
|
|
def preprocess(m):
    """Prepare a model before one of the editing helpers touches its graph.

    The model is first run through ``onnx.utils.polish_model``. The graph of
    the polished model is then handed to
    ``replacing.replace_initializer_with_Constant`` and
    ``other.topological_sort``.

    Args:
        m (onnx.ModelProto): the input onnx model.

    Returns:
        onnx.ModelProto: the model returned by ``polish_model`` after the two
        graph passes above have been applied to it.
    """
    polished = onnx.utils.polish_model(m)
    graph = polished.graph
    replacing.replace_initializer_with_Constant(graph)
    other.topological_sort(graph)
    return polished
|
|
|
|
def postprocess(m):
    """Clean up a model after one of the editing helpers changed its graph.

    Steps, in order: empty ``graph.value_info``, run the
    ``extract_constant_to_initializer`` optimizer pass, apply
    ``replacing.replace_initializer_with_Constant`` and
    ``other.topological_sort`` to the optimized graph, polish the model, and
    finally call ``other.add_output_to_value_info`` on the polished graph.

    Args:
        m (onnx.ModelProto): the edited onnx model.

    Returns:
        onnx.ModelProto: the polished result model.
    """
    graph = m.graph
    # Drop every stored value_info entry, one at a time, until none is left.
    while graph.value_info:
        graph.value_info.pop()

    optimized = onnx.optimizer.optimize(m, ['extract_constant_to_initializer'])
    optimized_graph = optimized.graph
    replacing.replace_initializer_with_Constant(optimized_graph)
    other.topological_sort(optimized_graph)

    # Polish and output
    polished = onnx.utils.polish_model(optimized)
    other.add_output_to_value_info(polished.graph)
    return polished
|
|
|
|
def delete_nodes(model: onnx.ModelProto, node_names: List[str]) -> onnx.ModelProto:
    """Remove the nodes that carry the given names.

    The model goes through ``preprocess``, then
    ``modhelper.delete_nodes`` is applied to its graph, then ``postprocess``.

    Args:
        model (onnx.ModelProto): the input onnx model.
        node_names (List[str]): names of the nodes to delete.

    Returns:
        onnx.ModelProto: the result onnx model.
    """
    prepared = preprocess(model)
    modhelper.delete_nodes(prepared.graph, node_names)
    return postprocess(prepared)
|
|
|
|
def delete_inputs(model: onnx.ModelProto, value_names: List[str]) -> onnx.ModelProto:
    """Remove the given inputs from the model.

    The model goes through ``preprocess``, then ``modhelper.delete_input``
    is applied to its graph, then ``postprocess``.

    Args:
        model (onnx.ModelProto): input onnx model.
        value_names (List[str]): names of the inputs to delete.

    Returns:
        onnx.ModelProto: result onnx model.
    """
    prepared = preprocess(model)
    modhelper.delete_input(prepared.graph, value_names)
    return postprocess(prepared)
|
|
|
|
def delete_outputs(model: onnx.ModelProto, value_names: List[str]) -> onnx.ModelProto:
    """Remove the given outputs from the model.

    The model goes through ``preprocess``, then ``modhelper.delete_output``
    is applied to its graph, then ``postprocess``.

    Args:
        model (onnx.ModelProto): input onnx model.
        value_names (List[str]): names of the outputs to delete.

    Returns:
        onnx.ModelProto: result onnx model.
    """
    prepared = preprocess(model)
    modhelper.delete_output(prepared.graph, value_names)
    return postprocess(prepared)
|
|
|
|
def cut_graph_from_nodes(model: onnx.ModelProto, node_names: List[str]) -> onnx.ModelProto:
    """Cut the graph starting at the given nodes.

    Unlike ``delete_nodes``, this is meant to also drop everything that
    follows the named nodes. The cut itself is delegated to
    ``other.remove_nodes`` with ``cut_nodes=node_names``.

    Args:
        model (onnx.ModelProto): the input onnx model.
        node_names (List[str]): names of nodes to cut from.

    Returns:
        onnx.ModelProto: the result onnx model.
    """
    prepared = preprocess(model)
    other.remove_nodes(prepared.graph, cut_nodes=node_names)
    return postprocess(prepared)
|
|
|
|
def remove_nodes_with_types(model: onnx.ModelProto, type_names: List[str]) -> onnx.ModelProto:
    """Cut the graph at nodes whose operator type is listed.

    Behaves like ``cut_graph_from_nodes`` but selects the cut points by
    operator type: ``other.remove_nodes`` is called with
    ``cut_types=type_names``.

    Args:
        model (onnx.ModelProto): the input onnx model.
        type_names (List[str]): operator types to cut from.

    Returns:
        onnx.ModelProto: the result onnx model.
    """
    prepared = preprocess(model)
    other.remove_nodes(prepared.graph, cut_types=type_names)
    return postprocess(prepared)
|
|
|
|
def change_input_output_shapes(model: onnx.ModelProto, input_shape_mapping: Dict=None, output_shape_mapping: Dict=None) -> onnx.ModelProto:
    """Overwrite the shapes of selected model inputs and/or outputs.

    Each mapping entry is turned into one string of the form
    ``"<name> <dim0> <dim1> ..."`` and the resulting list is passed to
    ``other.change_input_shape`` / ``other.change_output_shape``. A mapping
    left as None is skipped entirely.

    Args:
        model (onnx.ModelProto): input onnx model.
        input_shape_mapping (Dict, optional): mapping from input names to the
            shapes to change. Defaults to None.
        output_shape_mapping (Dict, optional): mapping from output names to
            the shapes to change. Defaults to None.

    Returns:
        onnx.ModelProto: result onnx model.
    """
    def as_params(shape_mapping):
        # One "<name> <d0> <d1> ..." string per entry, in mapping order.
        return [
            ' '.join([name] + [str(dim) for dim in shape])
            for name, shape in shape_mapping.items()
        ]

    prepared = preprocess(model)
    graph = prepared.graph

    if input_shape_mapping is not None:
        other.change_input_shape(graph, as_params(input_shape_mapping))
    if output_shape_mapping is not None:
        other.change_output_shape(graph, as_params(output_shape_mapping))

    return postprocess(prepared)
|
|
|
|
def add_conv_after(model: onnx.ModelProto, value_names: List[str]) -> onnx.ModelProto:
    """Insert a do-nothing Conv node after each of the given values.

    The insertion is delegated to ``other.add_nop_conv_after``; the graph is
    topologically re-sorted afterwards, before ``postprocess`` runs.

    Args:
        model (onnx.ModelProto): input onnx model.
        value_names (List[str]): values after which a Conv is added.

    Returns:
        onnx.ModelProto: result onnx model.
    """
    prepared = preprocess(model)
    graph = prepared.graph
    other.add_nop_conv_after(graph, value_names)
    other.topological_sort(graph)
    return postprocess(prepared)
|
|
|
|
def add_bn_after(model: onnx.ModelProto, value_names: List[str]) -> onnx.ModelProto:
    """Insert a do-nothing BN node after each of the given values.

    The insertion is delegated to ``other.add_nop_bn_after``; the graph is
    topologically re-sorted afterwards, before ``postprocess`` runs.

    Args:
        model (onnx.ModelProto): input onnx model.
        value_names (List[str]): values after which a BN is added.

    Returns:
        onnx.ModelProto: result onnx model.
    """
    prepared = preprocess(model)
    graph = prepared.graph
    other.add_nop_bn_after(graph, value_names)
    other.topological_sort(graph)
    return postprocess(prepared)
|
|
|
|
def rename_output(model: onnx.ModelProto, old_name: str, new_name: str) -> onnx.ModelProto:
    """Give one model output a new name.

    The model goes through ``preprocess``, then
    ``other.rename_output_name`` is applied to its graph, then
    ``postprocess``.

    Args:
        model (onnx.ModelProto): input onnx model.
        old_name (str): current name of the output.
        new_name (str): name the output should have afterwards.

    Returns:
        onnx.ModelProto: result onnx model.
    """
    prepared = preprocess(model)
    other.rename_output_name(prepared.graph, old_name, new_name)
    return postprocess(prepared)
|
|
|
|
def pixel_modify(model: onnx.ModelProto, scale: List[float] = None, bias: List[float] = None) -> onnx.ModelProto:
    """Add a special BN node to adjust the input range. Currently only support single input model.

    The BN node is inserted after the first graph input through
    ``other.add_bias_scale_bn_after``. The channel count is read from
    dimension 1 of that input's declared shape.

    Args:
        model (onnx.ModelProto): input onnx model.
        scale (List[float], optional): the scale of the BN node, one value
            per channel. Defaults to all ones.
        bias (List[float], optional): the bias of the BN node, one value per
            channel. Defaults to all zeros.

    Returns:
        onnx.ModelProto: result onnx model.

    Raises:
        ValueError: if the model has more than one input, or if the length
            of `scale` or `bias` differs from the input's dimension 1.
    """
    # Check dimension size
    g = model.graph
    if len(g.input) > 1:
        raise ValueError("`pixel_modify` only support single input model currently.")
    i_n = g.input[0]
    channels = i_n.type.tensor_type.shape.dim[1].dim_value
    if scale is None:
        scale = [1] * channels
    if bias is None:
        bias = [0] * channels
    if channels != len(bias) or channels != len(scale):
        raise ValueError("scale (" + str(scale) + ") and bias (" + str(bias) + ") should be same as input dimension:" + str(channels) )
    # Preprocess
    m = preprocess(model)
    # Function
    # Edit the preprocessed model, not the caller's `model`: preprocess()
    # returns the output of polish_model, so changes made to `model.graph`
    # here would not show up in the returned model.
    g = m.graph
    i_n = g.input[0]
    other.add_bias_scale_bn_after(g, i_n.name, bias, scale)
    # Post process
    m = postprocess(m)
    # Done
    return m
|
|
|
|
def keras2onnx_flow(keras_model_path: str, optimize: int=0, input_shape: List=None) -> onnx.ModelProto:
    """Load a keras model file and convert it into an onnx object.

    Conversion is done by ``onnx_keras.frontend.KerasFrontend`` after
    ``onnx_keras.set_duplicate_weights(True)`` has been called.

    Args:
        keras_model_path (str): the input hdf5/h5 model path.
        optimize (int, optional): optimization level. Defaults to 0.
        input_shape (List, optional): change the input shape if set. Only
            single input model is supported. Defaults to None.

    Returns:
        onnx.ModelProto: the converted onnx.
    """
    onnx_keras.set_duplicate_weights(True)
    frontend = onnx_keras.frontend.KerasFrontend()
    frontend.loadFromFile(keras_model_path)
    onnx_model = frontend.convertToOnnx(optimize, input_shape)
    return onnx_model
|
|
|
|
|
|
def tflite2onnx_flow(tflite_path: str, release_mode: bool=True, bottom_nodes: List=None) -> onnx.ModelProto:
    """Convert tflite model to onnx object.

    The conversion is delegated to ``tflite2onnx.main``, which is also given
    the fixed path ``/tmp/tflite_converted.onnx``.
    NOTE(review): that path is presumably where the converter saves its
    result; if so, concurrent conversions would overwrite each other.
    Confirm against ``tflite2onnx.main``.

    Args:
        tflite_path (str): the input tflite model path.
        release_mode (bool, optional): whether eliminate the transpose for
            channel first. Its negation is what gets passed to the
            converter. Defaults to True.
        bottom_nodes (List, optional): nodes name in tflite model which is
            the bottom node of sub-graph. Defaults to None, which is treated
            as an empty list.

    Returns:
        onnx.ModelProto: the converted onnx.
    """
    # A fresh list per call: a `[]` default in the signature would be one
    # shared object reused (and possibly mutated) across calls.
    if bottom_nodes is None:
        bottom_nodes = []
    return tflite2onnx.main(tflite_path, '/tmp/tflite_converted.onnx', not release_mode, bottom_nodes)
|