# Model-conversion flows: Keras / Caffe / TFLite / Torch-exported models -> ONNX.
import onnx
|
|
import kneronnxopt
|
|
import os
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
# Fallback stub used when a converter function is absent from this environment.
def print_function_not_defined_message(*args, **kwargs):
    """Placeholder that reports a missing function instead of crashing.

    Accepts and ignores any arguments, prints a hint about switching
    conda environments, and always returns None.
    """
    hint = (
        "The function is not defined in the current environment. "
        "Please try switching conda environment using `conda activate base`."
    )
    print(hint)
    return None
def kera2onnx_flow(keras_model_path: str, optimize: int = 0, input_shape=None) -> onnx.ModelProto:
    """Convert a Keras model file into an ONNX model.

    Locates the bundled keras-onnx converter relative to this file, makes
    it importable, and drives its ``KerasFrontend``.

    Args:
        keras_model_path: Path to the Keras model file to convert.
        optimize: Optimization level forwarded to the converter.
        input_shape: Optional input shape forwarded to the converter.

    Returns:
        The converted ``onnx.ModelProto``.

    Raises:
        FileNotFoundError: If the bundled converter directory is missing.
        RuntimeError: If the converter (or keras/tensorflow) cannot be imported.
    """
    root = Path(__file__).resolve().parents[1] / "libs" / "ONNX_Convertor" / "keras-onnx"
    if not root.exists():
        raise FileNotFoundError(f"keras-onnx converter not found: {root}")

    root_str = str(root)
    if root_str not in sys.path:
        sys.path.insert(0, root_str)

    try:
        import onnx_keras  # type: ignore
    except Exception as exc:
        raise RuntimeError(
            "kera2onnx_flow requires the keras-onnx converter and its dependencies (keras/tensorflow)."
        ) from exc

    # NOTE(review): presumably makes shared weights be emitted as separate
    # initializers — confirm against the onnx_keras converter.
    onnx_keras.set_duplicate_weights(True)
    frontend = onnx_keras.frontend.KerasFrontend()
    frontend.loadFromFile(keras_model_path)
    return frontend.convertToOnnx(optimize, input_shape)
def caffe2onnx_flow(caffe_prototxt_path: str, caffe_model_path: str) -> onnx.ModelProto:
    """Convert a Caffe model (prototxt + caffemodel) into an ONNX model.

    Locates the bundled caffe-onnx converter relative to this file, makes
    it importable, and drives its ``CaffeFrontend``.

    Args:
        caffe_prototxt_path: Path to the network definition (.prototxt).
        caffe_model_path: Path to the trained weights (.caffemodel).

    Returns:
        The converted ``onnx.ModelProto``.

    Raises:
        FileNotFoundError: If the bundled converter directory is missing.
        RuntimeError: If the converter (or caffe) cannot be imported.
    """
    root = Path(__file__).resolve().parents[1] / "libs" / "ONNX_Convertor" / "caffe-onnx"
    if not root.exists():
        raise FileNotFoundError(f"caffe-onnx converter not found: {root}")

    root_str = str(root)
    if root_str not in sys.path:
        sys.path.insert(0, root_str)

    try:
        import onnx_caffe  # type: ignore
    except Exception as exc:
        raise RuntimeError(
            "caffe2onnx_flow requires the caffe-onnx converter and its dependencies (caffe)."
        ) from exc

    frontend = onnx_caffe.frontend.CaffeFrontend()
    frontend.loadFromFile(caffe_prototxt_path, caffe_model_path)
    return frontend.convertToOnnx()
def tflite2onnx_flow(tflite_path: str, release_mode: bool = True, bottom_nodes=None) -> onnx.ModelProto:
    """Convert a TFLite model into an ONNX model via the bundled converter.

    Args:
        tflite_path: Path to the .tflite model file.
        release_mode: When True, do not add transposes for the
            channel-last/channel-first issue.
        bottom_nodes: Optional list of bottom node names; defaults to [].

    Returns:
        The converted ``onnx.ModelProto``.

    Raises:
        FileNotFoundError: If the bundled converter directory is missing.
        RuntimeError: If the converter (or tensorflow) cannot be imported.
    """
    # Avoid a mutable default argument; keep the caller's list object as-is.
    bottom_nodes = [] if bottom_nodes is None else bottom_nodes

    root = Path(__file__).resolve().parents[1] / "libs" / "ONNX_Convertor" / "tflite-onnx" / "onnx_tflite"
    if not root.exists():
        raise FileNotFoundError(f"tflite2onnx converter not found: {root}")

    root_str = str(root)
    if root_str not in sys.path:
        sys.path.insert(0, root_str)

    try:
        import tflite2onnx  # type: ignore
    except Exception as exc:
        raise RuntimeError(
            "tflite2onnx_flow requires the tflite-onnx converter and its dependencies (tensorflow)."
        ) from exc

    # Compatibility shim: newer TF requires a subgraph_index argument in
    # Interpreter._get_tensor_details, while the bundled converter calls it
    # with only a tensor index. Patch at most once per process.
    try:
        import tensorflow as tf  # type: ignore

        if not getattr(tf.lite.Interpreter, "_ktc_compat_patched", False):
            original_get_details = tf.lite.Interpreter._get_tensor_details

            def _get_tensor_details_compat(self, tensor_index, *args, **kwargs):
                # Caller already provided subgraph information: pass through.
                if args or "subgraph_index" in kwargs:
                    return original_get_details(self, tensor_index, *args, **kwargs)
                # Try the legacy single-argument call first, then retry with
                # subgraph index 0 for the newer TF signature.
                try:
                    return original_get_details(self, tensor_index)
                except TypeError:
                    return original_get_details(self, tensor_index, 0)

            tf.lite.Interpreter._get_tensor_details = _get_tensor_details_compat  # type: ignore
            tf.lite.Interpreter._ktc_compat_patched = True  # type: ignore
    except Exception:
        # Best effort: if tensorflow is absent or patching fails, let the
        # converter surface its own errors later.
        pass

    work_dir = os.environ.get("KTC_WORKDIR") or os.environ.get("KTC_OUTPUT_DIR") or "/tmp"
    os.makedirs(work_dir, exist_ok=True)
    onnx_out_path = str(Path(work_dir) / "tflite_converted.onnx")

    return tflite2onnx.main(
        tflite_path,
        onnx_out_path,
        add_transpose_for_channel_last_first_issue=not release_mode,
        bottom_nodes_name=bottom_nodes,
    )
def torch_exported_onnx_flow(
    m: onnx.ModelProto, disable_fuse_bn: bool = False
) -> onnx.ModelProto:
    """Optimize a torch-exported ONNX model with kneronnxopt.

    Args:
        m: The ONNX model exported from PyTorch.
        disable_fuse_bn: Not supported in this environment; kept for
            interface compatibility. A warning is printed when set.

    Returns:
        The optimized ``onnx.ModelProto`` produced by ``kneronnxopt.optimize``.
    """
    if disable_fuse_bn:
        # Fixed typo in the user-facing message: "WRANING" -> "WARNING".
        print("WARNING: disable_fuse_bn is not available in current conda environment.")
    return kneronnxopt.optimize(m)
def onnx2onnx_flow(
    m,
    disable_fuse_bn=False,
    bgr=False,
    norm=False,
    rgba2yynn=False,
    eliminate_tail=False,
    opt_matmul=False,
    opt_720=False,
    duplicate_shared_weights=False,
):
    """Optimize an ONNX model with kneronnxopt.

    Only ``opt_matmul`` and ``duplicate_shared_weights`` are forwarded to
    the optimizer; the remaining legacy flags are unsupported in this
    environment and only trigger a warning when set.

    Args:
        m: The ONNX model to optimize.
        disable_fuse_bn, bgr, norm, rgba2yynn, eliminate_tail, opt_720:
            Unsupported legacy flags, kept for interface compatibility.
        opt_matmul: Forwarded to ``kneronnxopt.optimize``.
        duplicate_shared_weights: When truthy, passes mode 2 to the
            optimizer's ``duplicate_shared_weights`` option, else mode 1.

    Returns:
        The optimized ``onnx.ModelProto``.
    """
    print("Using kneronnxopt.optimize as the optimizer.")

    # Fixed typo in all user-facing messages: "WRANING" -> "WARNING".
    unsupported_flags = (
        ("disable_fuse_bn", disable_fuse_bn),
        ("bgr", bgr),
        ("norm", norm),
        ("rgba2yynn", rgba2yynn),
        ("eliminate_tail", eliminate_tail),
        ("opt_720", opt_720),
    )
    for flag_name, enabled in unsupported_flags:
        if enabled:
            print(f"WARNING: {flag_name} is not available in current conda environment.")

    return kneronnxopt.optimize(
        m,
        duplicate_shared_weights=2 if duplicate_shared_weights else 1,
        opt_matmul=opt_matmul,
    )