import importlib
import os
import sys
import tempfile
from pathlib import Path

import onnx
import kneronnxopt


# Create a dummy function to print a message when the function is not defined in the current environment.
def print_function_not_defined_message(*args, **kwargs):
    """Stand-in for functions that are unavailable in this environment.

    Accepts and ignores any positional/keyword arguments so it can replace a
    function of any signature. Prints a hint to switch conda environment and
    returns ``None``.
    """
    print(
        "The function is not defined in the current environment. Please try switching conda environment using `conda activate base`."
    )
    return None


def _import_converter(relative_parts, module_name, label, requirement_message):
    """Put a bundled converter directory on ``sys.path`` and import its module.

    The directory is resolved as
    ``<parents[1] of this file>/libs/ONNX_Convertor/<relative_parts...>``.

    Args:
        relative_parts: Path components below ``libs/ONNX_Convertor``.
        module_name: Top-level module to import once the directory is on
            ``sys.path``.
        label: Converter name used in the "not found" error message.
        requirement_message: Message of the ``RuntimeError`` raised when the
            import fails.

    Returns:
        The imported module object.

    Raises:
        FileNotFoundError: If the converter directory does not exist.
        RuntimeError: If importing ``module_name`` raises anything; the
            original exception is chained as ``__cause__``.
    """
    converter_root = (
        Path(__file__).resolve().parents[1] / "libs" / "ONNX_Convertor"
    ).joinpath(*relative_parts)
    if not converter_root.exists():
        raise FileNotFoundError(f"{label} converter not found: {converter_root}")
    converter_path = str(converter_root)
    # Prepend (once) so the bundled converter wins over any installed copy.
    # This is a process-wide, persistent change to sys.path.
    if converter_path not in sys.path:
        sys.path.insert(0, converter_path)
    try:
        return importlib.import_module(module_name)
    except Exception as exc:
        # Any failure here (missing converter deps, broken framework install)
        # is reported uniformly; the root cause stays available via chaining.
        raise RuntimeError(requirement_message) from exc


def _warn_unavailable(option_name):
    """Print a warning that the option *option_name* is ignored in this environment."""
    print(f"WARNING: {option_name} is not available in current conda environment.")


def _patch_tflite_interpreter_compat():
    """Best-effort shim for ``tf.lite.Interpreter._get_tensor_details``.

    Newer TensorFlow versions require a ``subgraph_index`` argument that the
    bundled converter does not pass. The wrapper first tries the one-argument
    call and, on ``TypeError``, retries with ``subgraph_index=0`` (positional).
    The patch is applied at most once per process (guarded by the
    ``_ktc_compat_patched`` class attribute). Failure to patch never aborts
    the conversion.
    """
    try:
        import tensorflow as tf  # type: ignore

        interpreter_cls = tf.lite.Interpreter
        if getattr(interpreter_cls, "_ktc_compat_patched", False):
            return
        original = interpreter_cls._get_tensor_details

        def _get_tensor_details_compat(self, tensor_index, *args, **kwargs):
            # Caller already supplied a subgraph index: pass through untouched.
            if args or "subgraph_index" in kwargs:
                return original(self, tensor_index, *args, **kwargs)
            try:
                return original(self, tensor_index, **kwargs)
            except TypeError:
                # NOTE(review): any TypeError (not only a signature mismatch)
                # triggers this retry with subgraph 0.
                return original(self, tensor_index, 0, **kwargs)

        interpreter_cls._get_tensor_details = _get_tensor_details_compat  # type: ignore
        interpreter_cls._ktc_compat_patched = True  # type: ignore
    except ImportError:
        # TensorFlow is not installed: nothing to patch.
        return
    except Exception as exc:
        # Deliberately best-effort, but no longer silent.
        print(f"WARNING: TFLite interpreter compatibility patch not applied: {exc}")


# NOTE: the name "kera2onnx_flow" (sic) is kept for backward compatibility.
def kera2onnx_flow(keras_model_path: str, optimize: int = 0, input_shape=None) -> onnx.ModelProto:
    """Convert a Keras model file with the bundled keras-onnx converter.

    Args:
        keras_model_path: Path passed to ``KerasFrontend.loadFromFile``.
        optimize: Forwarded unchanged to ``KerasFrontend.convertToOnnx``.
        input_shape: Forwarded unchanged to ``KerasFrontend.convertToOnnx``.

    Returns:
        Whatever ``KerasFrontend.convertToOnnx`` returns (annotated as an
        ``onnx.ModelProto``).

    Raises:
        FileNotFoundError: If the bundled keras-onnx directory is missing.
        RuntimeError: If ``onnx_keras`` (or its keras/tensorflow deps) fails to import.
    """
    onnx_keras = _import_converter(
        ("keras-onnx",),
        "onnx_keras",
        "keras-onnx",
        "kera2onnx_flow requires the keras-onnx converter and its dependencies (keras/tensorflow).",
    )
    # Module-level converter switch; presumably affects later conversions too — verify.
    onnx_keras.set_duplicate_weights(True)
    converter = onnx_keras.frontend.KerasFrontend()
    converter.loadFromFile(keras_model_path)
    return converter.convertToOnnx(optimize, input_shape)


def caffe2onnx_flow(caffe_prototxt_path: str, caffe_model_path: str) -> onnx.ModelProto:
    """Convert a Caffe model with the bundled caffe-onnx converter.

    Args:
        caffe_prototxt_path: Network definition passed to ``CaffeFrontend.loadFromFile``.
        caffe_model_path: Weights file passed to ``CaffeFrontend.loadFromFile``.

    Returns:
        Whatever ``CaffeFrontend.convertToOnnx`` returns (annotated as an
        ``onnx.ModelProto``).

    Raises:
        FileNotFoundError: If the bundled caffe-onnx directory is missing.
        RuntimeError: If ``onnx_caffe`` (or its caffe dependency) fails to import.
    """
    onnx_caffe = _import_converter(
        ("caffe-onnx",),
        "onnx_caffe",
        "caffe-onnx",
        "caffe2onnx_flow requires the caffe-onnx converter and its dependencies (caffe).",
    )
    converter = onnx_caffe.frontend.CaffeFrontend()
    converter.loadFromFile(caffe_prototxt_path, caffe_model_path)
    return converter.convertToOnnx()


def tflite2onnx_flow(tflite_path: str, release_mode: bool = True, bottom_nodes=None) -> onnx.ModelProto:
    """Convert a TFLite model with the bundled tflite-onnx converter.

    The converter writes an intermediate ``tflite_converted.onnx`` into
    ``$KTC_WORKDIR``, else ``$KTC_OUTPUT_DIR``, else the system temp directory.
    That directory is created if needed.

    Args:
        tflite_path: Path of the ``.tflite`` model.
        release_mode: When False, passes
            ``add_transpose_for_channel_last_first_issue=True`` to the converter.
        bottom_nodes: Node names forwarded as ``bottom_nodes_name``; ``None``
            means an empty list.

    Returns:
        Whatever ``tflite2onnx.main`` returns (annotated as an ``onnx.ModelProto``).

    Raises:
        FileNotFoundError: If the bundled tflite-onnx directory is missing.
        RuntimeError: If ``tflite2onnx`` (or its tensorflow dependency) fails to import.
    """
    if bottom_nodes is None:
        bottom_nodes = []
    tflite2onnx = _import_converter(
        ("tflite-onnx", "onnx_tflite"),
        "tflite2onnx",
        "tflite2onnx",
        "tflite2onnx_flow requires the tflite-onnx converter and its dependencies (tensorflow).",
    )
    # Compatibility: newer TF requires subgraph_index in _get_tensor_details.
    _patch_tflite_interpreter_compat()

    output_dir = (
        os.environ.get("KTC_WORKDIR")
        or os.environ.get("KTC_OUTPUT_DIR")
        or tempfile.gettempdir()
    )
    os.makedirs(output_dir, exist_ok=True)
    # NOTE(review): fixed file name — concurrent conversions sharing the same
    # output_dir overwrite each other's intermediate file.
    output_path = str(Path(output_dir) / "tflite_converted.onnx")
    return tflite2onnx.main(
        tflite_path,
        output_path,
        add_transpose_for_channel_last_first_issue=not release_mode,
        bottom_nodes_name=bottom_nodes,
    )


def torch_exported_onnx_flow(
    m: onnx.ModelProto, disable_fuse_bn=False
) -> onnx.ModelProto:
    """Optimize a torch-exported ONNX model with ``kneronnxopt.optimize``.

    Args:
        m: The ONNX model to optimize.
        disable_fuse_bn: Not supported here; a warning is printed if set and
            the flag is otherwise ignored.

    Returns:
        The result of ``kneronnxopt.optimize(m)``.
    """
    if disable_fuse_bn:
        _warn_unavailable("disable_fuse_bn")
    return kneronnxopt.optimize(m)


def onnx2onnx_flow(
    m,
    disable_fuse_bn=False,
    bgr=False,
    norm=False,
    rgba2yynn=False,
    eliminate_tail=False,
    opt_matmul=False,
    opt_720=False,
    duplicate_shared_weights=False,
):
    """Optimize an ONNX model with ``kneronnxopt.optimize``.

    Only ``opt_matmul`` and ``duplicate_shared_weights`` are honoured. Every
    other truthy option triggers a warning, in parameter order, and is ignored.

    Args:
        m: The ONNX model to optimize.
        opt_matmul: Forwarded to ``kneronnxopt.optimize``.
        duplicate_shared_weights: Truthy maps to ``duplicate_shared_weights=2``,
            falsy to ``1`` (NOTE(review): confirm the meaning of these levels
            in kneronnxopt).

    Returns:
        The result of ``kneronnxopt.optimize``.
    """
    print("Using kneronnxopt.optimize as the optimizer.")
    # Options accepted for API compatibility but not supported by kneronnxopt.
    # Dict order is preserved, so warnings appear in parameter order.
    unsupported_options = {
        "disable_fuse_bn": disable_fuse_bn,
        "bgr": bgr,
        "norm": norm,
        "rgba2yynn": rgba2yynn,
        "eliminate_tail": eliminate_tail,
        "opt_720": opt_720,
    }
    for option_name, enabled in unsupported_options.items():
        if enabled:
            _warn_unavailable(option_name)
    return kneronnxopt.optimize(
        m,
        duplicate_shared_weights=2 if duplicate_shared_weights else 1,
        opt_matmul=opt_matmul,
    )