# kneron_model_converter/ktc/onnx_optimizer_1_13.py
#
# 133 lines
# 5.2 KiB
# Python

import onnx
import kneronnxopt
import os
import sys
from pathlib import Path
# Create a dummy function to print a message when the function is not defined in the current environment.
def print_function_not_defined_message(*args, **kwargs):
    """Stand in for a function that is unavailable in this environment.

    Any positional or keyword arguments are accepted and ignored, so this can
    replace a function of arbitrary signature. It prints a hint suggesting a
    switch to the ``base`` conda environment and returns ``None``.
    """
    hint = (
        "The function is not defined in the current environment. "
        "Please try switching conda environment using `conda activate base`."
    )
    print(hint)
def kera2onnx_flow(keras_model_path: str, optimize: int = 0, input_shape=None) -> onnx.ModelProto:
    """Convert a Keras model file to ONNX with the bundled keras-onnx converter.

    The converter lives in ``<package root>/libs/ONNX_Convertor/keras-onnx``.
    That directory is prepended to ``sys.path`` (once) so ``onnx_keras`` can
    be imported.

    Args:
        keras_model_path: File handed to ``KerasFrontend.loadFromFile``.
        optimize: Forwarded unchanged as the first argument of
            ``KerasFrontend.convertToOnnx`` (its meaning is defined by
            onnx_keras — not visible here).
        input_shape: Forwarded unchanged as the second argument of
            ``convertToOnnx``.

    Returns:
        Whatever ``convertToOnnx`` returns (annotated as ``onnx.ModelProto``).

    Raises:
        FileNotFoundError: The bundled converter directory is missing.
        RuntimeError: ``onnx_keras`` (or one of its dependencies) fails to
            import; the original exception is chained.
    """
    package_dir = Path(__file__).resolve().parents[1].joinpath(
        "libs", "ONNX_Convertor", "keras-onnx"
    )
    if not package_dir.exists():
        raise FileNotFoundError(f"keras-onnx converter not found: {package_dir}")

    search_entry = str(package_dir)
    if search_entry not in sys.path:
        sys.path.insert(0, search_entry)

    try:
        import onnx_keras  # type: ignore
    except Exception as import_error:
        raise RuntimeError(
            "kera2onnx_flow requires the keras-onnx converter and its dependencies (keras/tensorflow)."
        ) from import_error

    # Global switch on the onnx_keras module; presumably controls whether
    # shared weights are duplicated in the output graph — confirm in onnx_keras.
    onnx_keras.set_duplicate_weights(True)
    frontend = onnx_keras.frontend.KerasFrontend()
    frontend.loadFromFile(keras_model_path)
    return frontend.convertToOnnx(optimize, input_shape)
def caffe2onnx_flow(caffe_prototxt_path: str, caffe_model_path: str) -> onnx.ModelProto:
    """Convert a Caffe model to ONNX with the bundled caffe-onnx converter.

    The converter lives in ``<package root>/libs/ONNX_Convertor/caffe-onnx``.
    That directory is prepended to ``sys.path`` (once) so ``onnx_caffe`` can
    be imported.

    Args:
        caffe_prototxt_path: First argument of ``CaffeFrontend.loadFromFile``
            (by its name, the network-definition prototxt).
        caffe_model_path: Second argument of ``CaffeFrontend.loadFromFile``
            (by its name, the trained weights file).

    Returns:
        Whatever ``CaffeFrontend.convertToOnnx`` returns (annotated as
        ``onnx.ModelProto``).

    Raises:
        FileNotFoundError: The bundled converter directory is missing.
        RuntimeError: ``onnx_caffe`` (or one of its dependencies) fails to
            import; the original exception is chained.
    """
    libs_dir = Path(__file__).resolve().parents[1] / "libs"
    package_dir = libs_dir / "ONNX_Convertor" / "caffe-onnx"
    if not package_dir.exists():
        raise FileNotFoundError(f"caffe-onnx converter not found: {package_dir}")

    search_entry = str(package_dir)
    if search_entry not in sys.path:
        sys.path.insert(0, search_entry)

    try:
        import onnx_caffe  # type: ignore
    except Exception as import_error:
        raise RuntimeError(
            "caffe2onnx_flow requires the caffe-onnx converter and its dependencies (caffe)."
        ) from import_error

    frontend = onnx_caffe.frontend.CaffeFrontend()
    frontend.loadFromFile(caffe_prototxt_path, caffe_model_path)
    return frontend.convertToOnnx()
def _patch_tflite_interpreter_compat() -> None:
    """Let ``tf.lite.Interpreter._get_tensor_details`` be called without ``subgraph_index``.

    Newer TensorFlow versions require a ``subgraph_index`` argument for this
    private method. The installed wrapper does the following:
    - Calls that already pass extra arguments are forwarded untouched.
    - Otherwise the one-argument form is tried first.
    - On ``TypeError`` the call is retried with subgraph index 0.

    The class is tagged with ``_ktc_compat_patched`` so the patch is applied
    at most once per process. This is best-effort: any failure is reported
    as a warning and conversion continues without the patch.
    """
    import functools

    try:
        import tensorflow as tf  # type: ignore

        interpreter_cls = tf.lite.Interpreter
        if getattr(interpreter_cls, "_ktc_compat_patched", False):
            return
        original = interpreter_cls._get_tensor_details

        @functools.wraps(original)
        def _get_tensor_details_compat(self, tensor_index, *args, **kwargs):
            if args or "subgraph_index" in kwargs:
                return original(self, tensor_index, *args, **kwargs)
            try:
                return original(self, tensor_index)
            except TypeError:
                # Signature demands subgraph_index: fall back to subgraph 0.
                return original(self, tensor_index, 0)

        interpreter_cls._get_tensor_details = _get_tensor_details_compat  # type: ignore
        interpreter_cls._ktc_compat_patched = True  # type: ignore
    except Exception as exc:
        # Deliberately non-fatal (older TF needs no patch), but no longer silent.
        print(f"WARNING: could not apply TFLite interpreter compatibility patch: {exc}")


def tflite2onnx_flow(tflite_path: str, release_mode: bool = True, bottom_nodes=None) -> onnx.ModelProto:
    """Convert a TFLite model to ONNX with the bundled tflite-onnx converter.

    The converter lives in
    ``<package root>/libs/ONNX_Convertor/tflite-onnx/onnx_tflite``. That
    directory is prepended to ``sys.path`` (once) so ``tflite2onnx`` can be
    imported.

    Args:
        tflite_path: Model file forwarded to ``tflite2onnx.main``.
        release_mode: When False, ``tflite2onnx.main`` is called with
            ``add_transpose_for_channel_last_first_issue=True``. When True,
            it is called with that flag set to False.
        bottom_nodes: Node names forwarded as ``bottom_nodes_name``; ``None``
            becomes a fresh empty list.

    Returns:
        Whatever ``tflite2onnx.main`` returns (annotated as ``onnx.ModelProto``).

    Raises:
        FileNotFoundError: The bundled converter directory is missing.
        RuntimeError: ``tflite2onnx`` (or one of its dependencies) fails to
            import; the original exception is chained.

    Notes:
        The output path handed to ``tflite2onnx.main`` is
        ``tflite_converted.onnx`` inside the first of:
        1. ``$KTC_WORKDIR``;
        2. ``$KTC_OUTPUT_DIR``;
        3. the platform temp directory.
        The directory is created if needed.
    """
    import tempfile

    if bottom_nodes is None:
        bottom_nodes = []
    converter_root = (
        Path(__file__).resolve().parents[1] / "libs" / "ONNX_Convertor" / "tflite-onnx" / "onnx_tflite"
    )
    if not converter_root.exists():
        raise FileNotFoundError(f"tflite2onnx converter not found: {converter_root}")
    converter_path = str(converter_root)
    if converter_path not in sys.path:
        sys.path.insert(0, converter_path)
    try:
        import tflite2onnx  # type: ignore
    except Exception as exc:
        raise RuntimeError(
            "tflite2onnx_flow requires the tflite-onnx converter and its dependencies (tensorflow)."
        ) from exc

    # Compatibility: newer TF requires subgraph_index in _get_tensor_details.
    _patch_tflite_interpreter_compat()

    # tempfile.gettempdir() honours TMPDIR/TEMP and is valid on every
    # platform, unlike a hard-coded "/tmp" (which maps to C:\tmp on Windows).
    output_dir = (
        os.environ.get("KTC_WORKDIR")
        or os.environ.get("KTC_OUTPUT_DIR")
        or tempfile.gettempdir()
    )
    os.makedirs(output_dir, exist_ok=True)
    output_path = str(Path(output_dir) / "tflite_converted.onnx")
    return tflite2onnx.main(
        tflite_path,
        output_path,
        add_transpose_for_channel_last_first_issue=not release_mode,
        bottom_nodes_name=bottom_nodes,
    )
def torch_exported_onnx_flow(
    m: onnx.ModelProto, disable_fuse_bn=False
) -> onnx.ModelProto:
    """Optimize an ONNX model exported from PyTorch with ``kneronnxopt``.

    Args:
        m: The exported ONNX model, passed straight to ``kneronnxopt.optimize``.
        disable_fuse_bn: Not supported in this environment. When truthy, a
            warning is printed; the flag is not forwarded to the optimizer.

    Returns:
        The result of ``kneronnxopt.optimize(m)``.
    """
    if disable_fuse_bn:
        print("WARNING: disable_fuse_bn is not available in current conda environment.")
    return kneronnxopt.optimize(m)
def onnx2onnx_flow(
    m,
    disable_fuse_bn=False,
    bgr=False,
    norm=False,
    rgba2yynn=False,
    eliminate_tail=False,
    opt_matmul=False,
    opt_720=False,
    duplicate_shared_weights=False,
):
    """Optimize an ONNX model with ``kneronnxopt.optimize``.

    Only ``opt_matmul`` and ``duplicate_shared_weights`` are forwarded to
    the optimizer. The options ``disable_fuse_bn``, ``bgr``, ``norm``,
    ``rgba2yynn``, ``eliminate_tail`` and ``opt_720`` are not forwarded.
    Each one that is truthy triggers a warning, printed in that order.

    Args:
        m: The ONNX model, passed straight to ``kneronnxopt.optimize``.
        disable_fuse_bn, bgr, norm, rgba2yynn, eliminate_tail, opt_720:
            Unsupported here; they only produce a warning when set.
        opt_matmul: Forwarded unchanged as ``opt_matmul``.
        duplicate_shared_weights: Truthy maps to
            ``duplicate_shared_weights=2``, falsy to ``1``.

    Returns:
        The result of ``kneronnxopt.optimize``.
    """
    print("Using kneronnxopt.optimize as the optimizer.")
    unsupported_options = (
        ("disable_fuse_bn", disable_fuse_bn),
        ("bgr", bgr),
        ("norm", norm),
        ("rgba2yynn", rgba2yynn),
        ("eliminate_tail", eliminate_tail),
        ("opt_720", opt_720),
    )
    for option_name, requested in unsupported_options:
        if requested:
            print(f"WARNING: {option_name} is not available in current conda environment.")
    # NOTE(review): 2/1 are kneronnxopt's own encodings for this option
    # (presumably 2 = duplicate shared weights, 1 = leave them shared) —
    # confirm against kneronnxopt.optimize before changing.
    return kneronnxopt.optimize(
        m,
        duplicate_shared_weights=2 if duplicate_shared_weights else 1,
        opt_matmul=opt_matmul,
    )