48 lines
1.7 KiB
Python
48 lines
1.7 KiB
Python
import onnx
|
|
import kneronnxopt
|
|
|
|
# Stand-in used for converter entry points that this environment does not provide.
def print_function_not_defined_message(*args, **kwargs):
    """Tell the user the requested function is unavailable in this environment.

    Any positional or keyword arguments are accepted and ignored, so this can be
    bound in place of a converter function with any signature.

    Returns:
        None, always.
    """
    hint = (
        "The function is not defined in the current environment. "
        "Please try switching conda environment using `conda activate base`."
    )
    print(hint)
    return None
|
|
|
|
# Converter flows that are not available in this environment. Each one is bound
# to print_function_not_defined_message, which prints a hint about switching
# conda environments and returns None.
keras2onnx_flow = print_function_not_defined_message
# Misspelled name retained so existing callers of ``kera2onnx_flow`` keep working.
kera2onnx_flow = keras2onnx_flow
caffe2onnx_flow = print_function_not_defined_message
tflite2onnx_flow = print_function_not_defined_message
|
|
|
|
def torch_exported_onnx_flow(
    m: onnx.ModelProto, disable_fuse_bn=False
) -> onnx.ModelProto:
    """Optimize an ONNX model exported from PyTorch using ``kneronnxopt``.

    Args:
        m: The ONNX model to optimize.
        disable_fuse_bn: Not supported in this environment. When truthy, a
            warning is printed and the flag is otherwise ignored.

    Returns:
        The model returned by ``kneronnxopt.optimize(m)``.
    """
    if disable_fuse_bn:
        # kneronnxopt.optimize is called without any fuse-BN option below, so
        # the request cannot be honored; warn instead of failing.
        print("WARNING: disable_fuse_bn is not available in current conda environment.")
    return kneronnxopt.optimize(m)
|
|
|
|
|
|
def onnx2onnx_flow(
    m,
    disable_fuse_bn=False,
    bgr=False,
    norm=False,
    rgba2yynn=False,
    eliminate_tail=False,
    opt_matmul=False,
    opt_720=False,
    duplicate_shared_weights=False,
):
    """Optimize an ONNX model with ``kneronnxopt.optimize``.

    Only ``opt_matmul`` and ``duplicate_shared_weights`` are forwarded to the
    optimizer. The other options are accepted for interface compatibility but
    are not supported in this environment: each truthy one causes a printed
    warning and is otherwise ignored.

    Args:
        m: The ONNX model to optimize.
        disable_fuse_bn: Unsupported here; warns when truthy.
        bgr: Unsupported here; warns when truthy.
        norm: Unsupported here; warns when truthy.
        rgba2yynn: Unsupported here; warns when truthy.
        eliminate_tail: Unsupported here; warns when truthy.
        opt_matmul: Passed through to ``kneronnxopt.optimize``.
        opt_720: Unsupported here; warns when truthy.
        duplicate_shared_weights: When truthy, ``kneronnxopt.optimize`` is
            called with ``duplicate_shared_weights=2``; otherwise with ``1``.

    Returns:
        Whatever ``kneronnxopt.optimize`` returns for ``m``.
    """
    print("Using kneronnxopt.optimize as the optimizer.")

    # Options accepted by this signature that kneronnxopt is not given here.
    # Listed in signature order so the warnings print in a stable order.
    unsupported_options = (
        ("disable_fuse_bn", disable_fuse_bn),
        ("bgr", bgr),
        ("norm", norm),
        ("rgba2yynn", rgba2yynn),
        ("eliminate_tail", eliminate_tail),
        ("opt_720", opt_720),
    )
    for option_name, enabled in unsupported_options:
        if enabled:
            print(f"WARNING: {option_name} is not available in current conda environment.")

    # NOTE(review): the meaning of the integer levels 1 vs 2 is defined by
    # kneronnxopt.optimize, not by this file — confirm against its docs.
    return kneronnxopt.optimize(
        m,
        duplicate_shared_weights=2 if duplicate_shared_weights else 1,
        opt_matmul=opt_matmul,
    )
|