"""Build a small ONNX test model and save it to ``test.onnx``.

The graph flattens the last two axes of its input using only shape
arithmetic: Shape -> Slice x3 -> Mul -> Concat -> Reshape.
"""
import onnx
import onnx.helper

# Make inputs and outputs.
# Graph interface: one FLOAT input of shape (1, 3, 32, 32) and one FLOAT
# output 'reshape0' of shape (1, 3, 1024), i.e. the last two input axes
# merged into one (32 * 32 = 1024).
input_value = onnx.helper.make_tensor_value_info(
    'input',
    onnx.TensorProto.FLOAT,
    (1, 3, 32, 32),
)

output_value = onnx.helper.make_tensor_value_info(
    'reshape0',
    onnx.TensorProto.FLOAT,
    (1, 3, 1024),
)
# ONNX requires graph nodes to be topologically sorted, so nodes are
# appended to this list in dependency order.
nodes = []

# Make a Shape node: shape0 is a 1-D INT64 tensor holding the dimensions
# of 'input'.
shape_node = onnx.helper.make_node(
    'Shape',
    ['input'],
    ['shape0'],
    name='shape0',
)
nodes.append(shape_node)
# Make Slice nodes.
#
# Each Slice takes a contiguous range of shape0. The starts/ends are passed
# as 1-element INT64 Constant tensors (the input form of Slice used since
# opset 10). No axes input is given, so axis 0 is sliced.
# For the declared (1, 3, 32, 32) input:
#   slice0 = shape0[0:2] -> [1, 3]
#   slice1 = shape0[2:3] -> [32]
#   slice2 = shape0[3:4] -> [32]


def _make_int64_constant(name, value):
    """Return a Constant node that outputs a 1-element INT64 tensor.

    The node name, its output name and the embedded tensor name are all
    ``name``, matching the naming used for every Constant in this graph.
    """
    return onnx.helper.make_node(
        'Constant',
        [],
        [name],
        name=name,
        value=onnx.helper.make_tensor(
            name=name,
            data_type=onnx.TensorProto.INT64,
            dims=[1],
            vals=[value],
        ),
    )


def _make_shape_slice(index, start, end):
    """Build the nodes computing ``slice<index> = shape0[start:end]``.

    Returns a ``(start_node, end_node, slice_node)`` tuple whose outputs are
    named ``start<index>``, ``end<index>`` and ``slice<index>``.
    """
    suffix = str(index)
    start_name = 'start' + suffix
    end_name = 'end' + suffix
    start_node = _make_int64_constant(start_name, start)
    end_node = _make_int64_constant(end_name, end)
    slice_node = onnx.helper.make_node(
        'Slice',
        ['shape0', start_name, end_name],
        ['slice' + suffix],
        name='slice' + suffix,
    )
    return start_node, end_node, slice_node


start_node0, end_node0, slice_node0 = _make_shape_slice(0, 0, 2)
nodes.extend([start_node0, end_node0, slice_node0])

start_node1, end_node1, slice_node1 = _make_shape_slice(1, 2, 3)
nodes.extend([start_node1, end_node1, slice_node1])

start_node2, end_node2, slice_node2 = _make_shape_slice(2, 3, 4)
nodes.extend([start_node2, end_node2, slice_node2])
# Make a Mul node: mul0 = slice1 * slice2, the product of the last two input
# dimensions (32 * 32 = 1024 for the declared input).
mul_node = onnx.helper.make_node(
    'Mul',
    ['slice1', 'slice2'],
    ['mul0'],
    name='mul0',
)
nodes.append(mul_node)
# Make a Concat node: concat0 joins slice0 and mul0 along axis 0 to form the
# target shape (e.g. [1, 3, 1024] for the declared input).
concat_node = onnx.helper.make_node(
    'Concat',
    ['slice0', 'mul0'],
    ['concat0'],
    name='concat0',
    axis=0,
)
nodes.append(concat_node)
# Make a Reshape node: reshape0 = reshape(input, concat0). Its output name
# matches the declared graph output 'reshape0'.
reshape_node = onnx.helper.make_node(
    'Reshape',
    ['input', 'concat0'],
    ['reshape0'],
    name='reshape0',
)
nodes.append(reshape_node)
# Make model.
graph_def = onnx.helper.make_graph(
    nodes,
    'test-model',
    [input_value],
    [output_value],
)

# Pin the default-domain opset to 12. Slice takes starts/ends as inputs
# rather than attributes from opset 10 onward, which the Slice nodes in this
# graph rely on.
model_def = onnx.helper.make_model(
    graph_def,
    producer_name='onnx-example',
    opset_imports=[onnx.helper.make_opsetid("", 12)],
)

# Validate before writing, so a malformed graph fails here rather than when
# the file is later loaded by a runtime.
onnx.checker.check_model(model_def)

onnx.save(model_def, 'test.onnx')