# Source snapshot metadata: 2026-03-11 16:13:59 +08:00, 209 lines, 6.2 KiB, Python

import argparse
import logging
import os

import onnx
import onnx.utils

from tools import combo
from tools import eliminating
from tools import other
from tools import special
# from tools import temp
def onnx2onnx_flow(
    m: onnx.ModelProto,
    disable_fuse_bn=False,
    bn_on_skip=False,
    bn_before_add=False,
    bgr=False,
    norm=False,
    rgba2yynn=False,
    eliminate_tail=False,
    opt_matmul=False,
    duplicate_shared_weights=True,
) -> onnx.ModelProto:
    """Run the full ONNX-to-ONNX optimization pipeline.

    The model goes through preprocessing, optional BN insertion, common
    graph optimizations, optional input-conversion tweaks, tail cleanup,
    postprocessing, and a final polish pass, in that fixed order.

    Args:
        m (ModelProto): the input onnx ModelProto
        disable_fuse_bn (bool, optional): do not fuse BN into Conv.
            Defaults to False.
        bn_on_skip (bool, optional): add BN operator on skip branches.
            Takes precedence over bn_before_add. Defaults to False.
        bn_before_add (bool, optional): add BN before Add node on every
            branch. Defaults to False.
        bgr (bool, optional): add a Conv layer to convert rgb input to bgr.
            Defaults to False.
        norm (bool, optional): add a Conv layer to add 0.5 to the input.
            Defaults to False.
        rgba2yynn (bool, optional): add a Conv layer to convert rgb to yynn.
            Defaults to False.
        eliminate_tail (bool, optional): remove trailing NPU unsupported
            nodes. Defaults to False.
        opt_matmul (bool, optional): optimize MatMul layers due to NPU
            limit. Defaults to False.
        duplicate_shared_weights (bool, optional): duplicate shared
            weights. Defaults to True.

    Returns:
        ModelProto: the optimized onnx model object.
    """
    # Stage 1: preprocessing (optionally fusing BN and duplicating weights).
    m = combo.preprocess(m, disable_fuse_bn, duplicate_shared_weights)

    # Stage 2: BN insertion — bn_on_skip wins when both flags are set.
    if bn_on_skip:
        other.add_bn_on_skip_branch(m.graph)
    elif bn_before_add:
        other.add_bn_before_add(m.graph)
        other.add_bn_before_activation(m.graph)

    # Stage 3: generic graph-level optimizations.
    m = combo.common_optimization(m)

    # Stage 4: optional input-format conversions.
    if bgr:
        special.change_input_from_bgr_to_rgb(m)
    if norm:
        special.add_0_5_to_normalized_input(m)
    if rgba2yynn:
        special.add_rgb2yynn_node(m)

    # Stage 5: optionally drop trailing nodes the NPU cannot run.
    if eliminate_tail:
        eliminating.remove_useless_last_nodes(m.graph)

    # Stage 6: postprocessing.
    m = combo.postprocess(m)

    # Stage 7: MatMul handling runs after postprocess so transposes do not
    # get moved downwards past it.
    if opt_matmul:
        special.special_MatMul_process(m.graph)

    # Stage 8: final polish before returning.
    m = other.polish_model(m)
    return m
# Main process: parse CLI options, configure logging, then run the
# onnx2onnx_flow pipeline on the input model and save the result.
if __name__ == "__main__":
    # Argument parser
    parser = argparse.ArgumentParser(
        description="Optimize an ONNX model for Kneron compiler"
    )
    parser.add_argument("in_file", help="input ONNX FILE")
    parser.add_argument(
        "-o", "--output", dest="out_file", type=str, help="output ONNX FILE"
    )
    parser.add_argument("--log", default="i", type=str, help="set log level")
    parser.add_argument(
        "--bgr",
        action="store_true",
        default=False,
        help="set if the model is trained in BGR mode",
    )
    parser.add_argument(
        "--norm",
        action="store_true",
        default=False,
        help="set if you have the input -0.5~0.5",
    )
    parser.add_argument(
        "--rgba2yynn",
        action="store_true",
        default=False,
        help="set if the model has yynn input but you want "
        "to take rgba images",
    )
    parser.add_argument(
        "--add-bn-on-skip",
        dest="bn_on_skip",
        action="store_true",
        default=False,
        help="set if you only want to add BN on skip branches",
    )
    parser.add_argument(
        "--add-bn",
        dest="bn_before_add",
        action="store_true",
        default=False,
        help="set if you want to add BN before Add",
    )
    parser.add_argument(
        "-t",
        "--eliminate-tail-unsupported",
        dest="eliminate_tail",
        action="store_true",
        default=False,
        help="whether remove the last unsupported node for hardware",
    )
    parser.add_argument(
        "--no-bn-fusion",
        dest="disable_fuse_bn",
        action="store_true",
        default=False,
        help="set if you have met errors which related to inferenced "
        "shape mismatch. This option will prevent fusing "
        "BatchNormalization into Conv.",
    )
    parser.add_argument(
        "--opt-matmul",
        dest="opt_matmul",
        action="store_true",
        default=False,
        help="set if you want to optimize MatMul operations "
        "for kneron hardware.",
    )
    parser.add_argument(
        "--no-duplicate-shared-weights",
        dest="no_duplicate_shared_weights",
        action="store_true",
        default=False,
        help="do not duplicate shared weights. Defaults to False.",
    )
    args = parser.parse_args()

    # Derive the default output path from the input path. splitext is used
    # instead of the old blind `in_file[:-5]` slice, which silently mangled
    # any input name that did not end with the 5-character ".onnx" suffix.
    if args.out_file is None:
        outfile = os.path.splitext(args.in_file)[0] + "_polished.onnx"
    else:
        outfile = args.out_file

    # Map the single-letter --log flag to a logging level; anything other
    # than w/d/e (including the default "i") falls back to INFO.
    log_levels = {
        "w": logging.WARN,
        "d": logging.DEBUG,
        "e": logging.ERROR,
    }
    logging.basicConfig(level=log_levels.get(args.log, logging.INFO))

    # onnx Polish model includes:
    #   -- nop
    #   -- eliminate_identity
    #   -- eliminate_nop_transpose
    #   -- eliminate_nop_pad
    #   -- eliminate_unused_initializer
    #   -- fuse_consecutive_squeezes
    #   -- fuse_consecutive_transposes
    #   -- fuse_add_bias_into_conv
    #   -- fuse_transpose_into_gemm
    # Basic model organize
    m = onnx.load(args.in_file)
    m = onnx2onnx_flow(
        m,
        args.disable_fuse_bn,
        args.bn_on_skip,
        args.bn_before_add,
        args.bgr,
        args.norm,
        args.rgba2yynn,
        args.eliminate_tail,
        args.opt_matmul,
        not args.no_duplicate_shared_weights,
    )
    onnx.save(m, outfile)