import onnx import onnx.utils import argparse import logging from tools import eliminating from tools import other from tools import special from tools import combo # from tools import temp def onnx2onnx_flow( m: onnx.ModelProto, disable_fuse_bn=False, bn_on_skip=False, bn_before_add=False, bgr=False, norm=False, rgba2yynn=False, eliminate_tail=False, opt_matmul=False, duplicate_shared_weights=True, ) -> onnx.ModelProto: """Optimize the onnx. Args: m (ModelProto): the input onnx ModelProto disable_fuse_bn (bool, optional): do not fuse BN into Conv. Defaults to False. bn_on_skip (bool, optional): add BN operator on skip branches. Defaults to False. bn_before_add (bool, optional): add BN before Add node on every branch. Defaults to False. bgr (bool, optional): add an Conv layer to convert rgb input to bgr. Defaults to False. norm (bool, optional): add an Conv layer to add 0.5 tp the input. Defaults to False. rgba2yynn (bool, optional): add an Conv layer to convert rgb to yynn. Defaults to False. eliminate_tail (bool, optional): remove trailing NPU unsupported nodes. Defaults to False. opt_matmul(bool, optional): optimize MatMul layers due to NPU limit. Defaults to False. duplicate_shared_weights(bool, optional): duplicate shared weights. Defaults to True. Returns: ModelProto: the optimized onnx model object. """ # temp.weight_broadcast(m.graph) m = combo.preprocess(m, disable_fuse_bn, duplicate_shared_weights) # temp.fuse_bias_in_consecutive_1x1_conv(m.graph) # Add BN on skip branch if bn_on_skip: other.add_bn_on_skip_branch(m.graph) elif bn_before_add: other.add_bn_before_add(m.graph) other.add_bn_before_activation(m.graph) # My optimization m = combo.common_optimization(m) # Special options if bgr: special.change_input_from_bgr_to_rgb(m) if norm: special.add_0_5_to_normalized_input(m) if rgba2yynn: special.add_rgb2yynn_node(m) # Remove useless last node if eliminate_tail: eliminating.remove_useless_last_nodes(m.graph) # Postprocessing m = combo.postprocess(m) # Put matmul after postprocess to avoid transpose moving downwards if opt_matmul: special.special_MatMul_process(m.graph) m = other.polish_model(m) return m # Main process if __name__ == "__main__": # Argument parser parser = argparse.ArgumentParser( description="Optimize an ONNX model for Kneron compiler" ) parser.add_argument("in_file", help="input ONNX FILE") parser.add_argument( "-o", "--output", dest="out_file", type=str, help="ouput ONNX FILE" ) parser.add_argument("--log", default="i", type=str, help="set log level") parser.add_argument( "--bgr", action="store_true", default=False, help="set if the model is trained in BGR mode", ) parser.add_argument( "--norm", action="store_true", default=False, help="set if you have the input -0.5~0.5", ) parser.add_argument( "--rgba2yynn", action="store_true", default=False, help="set if the model has yynn input but you want " "to take rgba images", ) parser.add_argument( "--add-bn-on-skip", dest="bn_on_skip", action="store_true", default=False, help="set if you only want to add BN on skip branches", ) parser.add_argument( "--add-bn", dest="bn_before_add", action="store_true", default=False, help="set if you want to add BN before Add", ) parser.add_argument( "-t", "--eliminate-tail-unsupported", dest="eliminate_tail", action="store_true", default=False, help="whether remove the last unsupported node for hardware", ) parser.add_argument( "--no-bn-fusion", dest="disable_fuse_bn", action="store_true", default=False, help="set if you have met errors which related to inferenced " "shape mismatch. This option will prevent fusing " "BatchNormalization into Conv.", ) parser.add_argument( "--opt-matmul", dest="opt_matmul", action="store_true", default=False, help="set if you want to optimize MatMul operations " "for kneron hardware.", ) parser.add_argument( "--no-duplicate-shared-weights", dest="no_duplicate_shared_weights", action="store_true", default=False, help="do not duplicate shared weights. Defaults to False.", ) args = parser.parse_args() if args.out_file is None: outfile = args.in_file[:-5] + "_polished.onnx" else: outfile = args.out_file if args.log == "w": logging.basicConfig(level=logging.WARN) elif args.log == "d": logging.basicConfig(level=logging.DEBUG) elif args.log == "e": logging.basicConfig(level=logging.ERROR) else: logging.basicConfig(level=logging.INFO) # onnx Polish model includes: # -- nop # -- eliminate_identity # -- eliminate_nop_transpose # -- eliminate_nop_pad # -- eliminate_unused_initializer # -- fuse_consecutive_squeezes # -- fuse_consecutive_transposes # -- fuse_add_bias_into_conv # -- fuse_transpose_into_gemm # Basic model organize m = onnx.load(args.in_file) m = onnx2onnx_flow( m, args.disable_fuse_bn, args.bn_on_skip, args.bn_before_add, args.bgr, args.norm, args.rgba2yynn, args.eliminate_tail, args.opt_matmul, not args.no_duplicate_shared_weights, ) onnx.save(m, outfile)