2022-04-12 14:26:54 +08:00

268 lines
8.1 KiB
Python

"""Combo functions that are usually called together.
"""
import logging
try:
from onnx import optimizer
except ImportError:
import onnxoptimizer as optimizer
from . import helper
from . import other
from . import replacing
from . import eliminating
from . import fusing
from . import constant_folding
from . import removing_transpose
from .common_pattern import torch_pattern_match, tf_pattern_match
from .helper import logger
def preprocess(
model_proto, disable_fuse_bn=False, duplicate_shared_weights=True
):
"""The most common used functions before other processing.
Args:
model_proto: the original model input
duplicate_shared_weights(bool, optional): duplicate shared weights.
Defaults to True.
Return:
the new model after preprocessing
It includes:
- inference shapes
- optimize model by ONNX library
- give names to the nodes
- replace initializer with Constant node
- replace -1 batch size with 1
- eliminate dropout and identity
- eliminate no children inputs
- topological sort
The optimizations provided by ONNX:
- eliminate_identity
- eliminate_nop_dropout
- eliminate_nop_transpose
- eliminate_nop_pad
- eliminate_unused_initializer
- eliminate_deadend
- fuse_consecutive_squeezes
- fuse_consecutive_transposes
- fuse_add_bias_into_conv
- fuse_transpose_into_gemm
- fuse_matmul_add_bias_into_gemm
- fuse_bn_into_conv
- fuse_pad_into_conv
"""
logger.info("Preprocessing the model...")
helper.setup_current_opset_version(model_proto)
eliminating.eliminate_empty_value_infos(model_proto.graph)
other.add_name_to_node(model_proto.graph)
other.rename_all_node_name(model_proto.graph)
replacing.replace_initializer_with_Constant(model_proto.graph)
other.topological_sort(model_proto.graph)
m = other.polish_model(model_proto)
passes = [
"extract_constant_to_initializer",
"eliminate_nop_dropout",
"eliminate_deadend",
"fuse_matmul_add_bias_into_gemm",
"fuse_pad_into_conv",
]
if not disable_fuse_bn:
passes.append("fuse_bn_into_conv")
m = optimizer.optimize(m, passes)
g = m.graph
# Add name again since onnx optimizer higher than 1.7 may remove node names
other.add_name_to_node(g)
if duplicate_shared_weights:
replacing.replace_initializer_with_Constant(
g, duplicate_shared_weights=True
)
other.duplicate_param_shared_constant(g)
else:
replacing.replace_initializer_with_Constant(
g, duplicate_shared_weights=False
)
other.topological_sort(g)
m = other.polish_model(m)
g = m.graph
eliminating.eliminate_consecutive_Cast(m.graph)
eliminating.eliminate_Cast_after_input(m.graph)
eliminating.eliminate_nop_pads(g)
eliminating.eliminate_nop_cast(g)
eliminating.eliminate_Identify_and_Dropout(g)
eliminating.eliminate_trivial_maxpool(g)
eliminating.eliminate_no_children_input(g)
other.format_value_info_shape(g)
other.topological_sort(g)
m = other.inference_shapes(m)
g = m.graph
replacing.replace_split_with_slices(g)
other.topological_sort(g)
return m
def common_optimization(m):
"""Common optimizations can be used in most cases.
:param m: the original model input\\
:return: the new model after preprocessing
It includes:
- transpose B in Gemm
- fuse BN into Gemm
- fuse consecutive Gemm
- replace AveragePool with GAP
- replace Squeeze/Unsqueeze with Reshape
- replace Reshape with Flatten
"""
logger.info("Doing nodes fusion and replacement... ")
m = other.polish_model(m)
g = m.graph
other.transpose_B_in_Gemm(g)
fusing.fuse_BN_into_Gemm(g)
fusing.fuse_BN_with_Reshape_into_Gemm(g)
fusing.fuse_Gemm_into_Gemm(g)
fusing.fuse_consecutive_reducemean(g)
fusing.fuse_slice_nodes_into_conv(g)
fusing.fuse_relu_min_into_clip(g)
other.duplicate_shared_Flatten(g)
replacing.replace_average_pool_with_GAP(g)
m = other.polish_model(m)
g = m.graph
replacing.replace_Squeeze_with_Reshape(g)
replacing.replace_Unsqueeze_with_Reshape(g)
replacing.replace_Reshape_with_Flatten(g)
replacing.replace_ReduceMean_with_GlobalAveragePool(g)
replacing.replace_Sum_with_Adds(g)
replacing.replace_constant_input_concat_with_pad(g)
other.topological_sort(g)
return m
def pytorch_constant_folding(m):
"""Constant folding needed by Pytorch exported models. It should be done
before using onnx optimizers since the dynamic shape structure may affect
the optimizations.
:param m: the original model input\\
:return: the new model after preprocessing
"""
logger.info("Working on constant folding.")
replacing.replace_shape_with_constant(m.graph)
replacing.replace_ConstantOfShape_with_constant(m.graph)
# constant_folding
m = other.inference_shapes(m)
while constant_folding.constant_folding(m.graph):
logging.debug("After constant folding jobs.")
other.topological_sort(m.graph)
while len(m.graph.value_info) != 0:
m.graph.value_info.pop()
m = other.inference_shapes(m)
replacing.replace_shape_with_constant(m.graph)
other.topological_sort(m.graph)
m = torch_pattern_match(m)
m = optimizer.optimize(m, ["eliminate_deadend"])
return m
def tensorflow_optimization(m):
"""Optimizations for tf models can be used in most cases.
:param m: the original model input\\
:return: the new model after preprocessing
It includes:
- eliminate shape change after input
- eliminate Reshape cast
- eliminate Squeeze before Reshape
- fuse Transpose into Constant
- replace Shape with Constant
"""
fusing.fuse_Transpose_into_Constant(m.graph)
fusing.fuse_MatMul_and_Add_into_Gemm(m.graph)
other.topological_sort(m.graph)
m = other.polish_model(m)
# constant folding
replacing.replace_shape_with_constant(m.graph)
# constant_folding
m = other.inference_shapes(m)
while constant_folding.constant_folding(m.graph):
logging.debug("After constant folding jobs.")
other.topological_sort(m.graph)
while len(m.graph.value_info) != 0:
m.graph.value_info.pop()
m = other.inference_shapes(m)
replacing.replace_shape_with_constant(m.graph)
other.topological_sort(m.graph)
m = tf_pattern_match(m)
m = optimizer.optimize(m, ["eliminate_deadend"])
eliminating.eliminate_consecutive_reshape(m.graph)
eliminating.eliminate_Squeeze_before_Reshape(m.graph)
other.topological_sort(m.graph)
return m
def postprocess(m):
"""Inference the shape and prepare for export.
:param m: the original model input\\
:return: the new model after preprocessing
"""
logger.info("Postprocessing the model...")
while len(m.graph.value_info) > 0:
m.graph.value_info.pop()
m = other.polish_model(m)
eliminating.eliminate_single_input_Concat(m.graph)
eliminating.eliminate_nop_Maxpool_and_AveragePool(m.graph)
eliminating.eliminate_trivial_elementwise_calculation(m.graph)
m = other.polish_model(m)
replacing.replace_depthwise_1x1_with_bn(m.graph)
m = other.polish_model(m)
# removing transpose
m = removing_transpose.eliminate_transposes(m)
m = other.polish_model(m)
removing_transpose.remove_trivial_transpose(m.graph)
removing_transpose.fuse_Transpose_into_Gemm_weight(m.graph)
# fuse some nodes
fusing.fuse_mul_and_add_into_bn(m.graph)
m = other.polish_model(m)
fusing.fuse_mul_and_add_into_gemm(m.graph)
m = other.polish_model(m)
fusing.fuse_conv_and_add_into_conv(m.graph)
m = other.polish_model(m)
replacing.replace_mul_to_bn(m.graph)
replacing.replace_div_to_bn(m.graph)
replacing.replace_add_to_bn(m.graph)
replacing.replace_sub_to_bn(m.graph)
replacing.replace_sub_with_bn_and_add(m.graph)
m = other.polish_model(m)
other.add_output_to_value_info(m.graph)
m = optimizer.optimize(m, ["eliminate_deadend"])
m.producer_name = "kneron_formatter"
return m