35 lines
1.2 KiB
Python
35 lines
1.2 KiB
Python
import onnx
|
|
import ktc.onnx_optimizer as kneron_opt
|
|
from onnx import helper
|
|
|
|
def replace_sigmoid_with_identity(model):
    """Replace every top-level Sigmoid node in ``model`` with an Identity node.

    Each matching node is rewritten in place. It keeps its position in
    ``model.graph.node`` and its inputs and outputs, so the graph stays
    topologically sorted and every downstream edge remains connected. The
    node is renamed to ``<original name>_identity``.

    Note that this changes the model's numerics: tensors that used to be
    sigmoid-activated now pass through unchanged. Presumably the sigmoid is
    meant to be applied elsewhere (e.g. in host-side post-processing);
    confirm against the deployment pipeline.

    Only nodes directly in ``model.graph`` are visited. Sigmoid nodes inside
    control-flow subgraphs (If/Loop/Scan bodies) are left untouched.

    Args:
        model: An ``onnx.ModelProto``; it is modified in place.

    Returns:
        The same ``model`` object, for call chaining.
    """
    for node in model.graph.node:
        if node.op_type != "Sigmoid":
            continue
        print(f"Replacing {node.name} with Identity")
        # Mutate the node's fields instead of appending a new node and
        # removing the old one. Removing from the repeated field while
        # iterating it skipped the element after each removal, and appending
        # at the end broke the topological order ONNX requires.
        node.op_type = "Identity"
        node.domain = ""
        node.name = node.name + "_identity"
        # Identity takes no attributes; drop any legacy Sigmoid attributes
        # (e.g. opset-1 ``consumed_inputs``).
        del node.attribute[:]
    return model
|
|
|
|
def process_onnx(input_onnx_path, output_onnx_path):
    """Load an ONNX model, swap its Sigmoid nodes for Identity, and save it.

    Args:
        input_onnx_path: Path of the ONNX file to read.
        output_onnx_path: Path the modified model is written to.
    """
    source_model = onnx.load(input_onnx_path)
    patched_model = replace_sigmoid_with_identity(source_model)
    onnx.save(patched_model, output_onnx_path)
    print(f"Modified ONNX model saved to: {output_onnx_path}")
|
|
|
|
# Default paths, as mounted inside the Docker container. Override them on the
# command line:  python <this script> [INPUT_ONNX [OUTPUT_ONNX]]
input_onnx = "/workspace/yolov5/runs/train/exp24/weights/best_simplified.onnx"
output_onnx = "/workspace/yolov5/runs/train/exp24/weights/best_no_sigmoid.onnx"

# Only run the conversion when executed as a script, so importing this module
# (e.g. to reuse replace_sigmoid_with_identity) has no side effects.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Replace every Sigmoid node in an ONNX model with Identity."
    )
    parser.add_argument(
        "input",
        nargs="?",
        default=input_onnx,
        help="ONNX model to read (default: %(default)s)",
    )
    parser.add_argument(
        "output",
        nargs="?",
        default=output_onnx,
        help="where to write the modified model (default: %(default)s)",
    )
    args = parser.parse_args()
    process_onnx(args.input, args.output)
|