# kneron_model_converter/libs/kneronnxopt/UnitTest/gen_constant_folding_test.py
# Repository-viewer metadata: 2026-01-28 06:16:04 +00:00
#
# 160 lines
# 3.1 KiB
# Python

import onnx
import onnx.checker
import onnx.helper
# Build a small ONNX model whose Reshape target is computed from the input's
# shape, presumably as a fixture for constant-folding tests (per the file name):
#
#   shape0  = Shape(input)                  -> [1, 3, 32, 32]
#   slice0  = Slice(shape0, [0], [2])       -> [1, 3]
#   slice1  = Slice(shape0, [2], [3])       -> [32]
#   slice2  = Slice(shape0, [3], [4])       -> [32]
#   mul0    = Mul(slice1, slice2)           -> [1024]
#   concat0 = Concat(slice0, mul0, axis=0)  -> [1, 3, 1024]
#   reshape0 = Reshape(input, concat0)      -> shape (1, 3, 1024)
#
# Because 'input' has a fully static shape, the Shape/Slice/Mul/Concat
# subgraph depends only on constants. The model is written to OUTPUT_PATH.

# Static shape of the graph input (N, C, H, W).
INPUT_SHAPE = (1, 3, 32, 32)
# The Reshape keeps the first two dims and flattens the last two (H * W).
OUTPUT_SHAPE = (INPUT_SHAPE[0], INPUT_SHAPE[1], INPUT_SHAPE[2] * INPUT_SHAPE[3])
# Slice takes starts/ends as tensor inputs (rather than attributes) from
# opset 10 on, which this graph relies on.
OPSET_VERSION = 12
OUTPUT_PATH = 'test.onnx'


def make_int64_constant(name, value):
    """Return a Constant node producing a 1-element INT64 tensor.

    Args:
        name: Used as the node name, the output tensor name, and the name of
            the embedded tensor (the naming convention used in this script).
        value: The single integer stored in the tensor.

    Returns:
        An ``onnx.NodeProto`` of op type ``Constant``.
    """
    return onnx.helper.make_node(
        'Constant',
        [],
        [name],
        name=name,
        value=onnx.helper.make_tensor(
            name=name,
            data_type=onnx.TensorProto.INT64,
            dims=[1],
            vals=[value],
        ),
    )


# Make inputs and outputs.
input_value = onnx.helper.make_tensor_value_info(
    'input',
    onnx.TensorProto.FLOAT,
    INPUT_SHAPE,
)
output_value = onnx.helper.make_tensor_value_info(
    'reshape0',
    onnx.TensorProto.FLOAT,
    OUTPUT_SHAPE,
)
nodes = []
# Make a Shape node: shape0 holds the dims of 'input' as a 1-D INT64 tensor.
shape_node = onnx.helper.make_node(
    'Shape',
    ['input'],
    ['shape0'],
    name='shape0',
)
nodes.append(shape_node)
# Make Slice nodes splitting shape0 into [N, C], [H] and [W].
start_node0 = make_int64_constant('start0', 0)
end_node0 = make_int64_constant('end0', 2)
slice_node0 = onnx.helper.make_node(
    'Slice',
    ['shape0', 'start0', 'end0'],
    ['slice0'],
    name='slice0',
)
nodes.extend([start_node0, end_node0, slice_node0])
start_node1 = make_int64_constant('start1', 2)
end_node1 = make_int64_constant('end1', 3)
slice_node1 = onnx.helper.make_node(
    'Slice',
    ['shape0', 'start1', 'end1'],
    ['slice1'],
    name='slice1',
)
nodes.extend([start_node1, end_node1, slice_node1])
start_node2 = make_int64_constant('start2', 3)
end_node2 = make_int64_constant('end2', 4)
slice_node2 = onnx.helper.make_node(
    'Slice',
    ['shape0', 'start2', 'end2'],
    ['slice2'],
    name='slice2',
)
nodes.extend([start_node2, end_node2, slice_node2])
# Make a Mul node: [H] * [W] -> [H * W].
mul_node = onnx.helper.make_node(
    'Mul',
    ['slice1', 'slice2'],
    ['mul0'],
    name='mul0',
)
nodes.append(mul_node)
# Make a Concat node: [N, C] ++ [H * W] -> Reshape target [N, C, H * W].
concat_node = onnx.helper.make_node(
    'Concat',
    ['slice0', 'mul0'],
    ['concat0'],
    name='concat0',
    axis=0,
)
nodes.append(concat_node)
# Make a Reshape node producing the graph output.
reshape_node = onnx.helper.make_node(
    'Reshape',
    ['input', 'concat0'],
    ['reshape0'],
    name='reshape0',
)
nodes.append(reshape_node)
# Make model.
graph_def = onnx.helper.make_graph(
    nodes,
    'test-model',
    [input_value],
    [output_value],
)
model_def = onnx.helper.make_model(
    graph_def,
    producer_name='onnx-example',
    opset_imports=[onnx.helper.make_opsetid("", OPSET_VERSION)],
)
# Validate before writing so a malformed fixture fails here, at generation
# time, rather than inside whatever later loads the file.
onnx.checker.check_model(model_def)
onnx.save(model_def, OUTPUT_PATH)