# Listing metadata: 218 lines, 3.8 KiB, Python
import numpy as np
import onnx
import onnx.checker
import onnx.helper

# Graph boundary tensors. The first node reads 'input'; the last node
# (OneHot) writes 'output'. Both are float tensors with fixed shapes.
input_value = onnx.helper.make_tensor_value_info(
    name='input',
    elem_type=onnx.TensorProto.FLOAT,
    shape=(1, 8, 32, 32),
)
output_value = onnx.helper.make_tensor_value_info(
    name='output',
    elem_type=onnx.TensorProto.FLOAT,
    shape=(1, 4, 30, 10, 8),
)

# Nodes are collected in execution order. Each node only reads the graph
# input or tensors produced by nodes appended before it. make_graph below
# consumes this list directly.
nodes = []

# AveragePool: 3x3 window over the graph input -> 'AP0'.
# No strides/pads are given, so the ONNX defaults apply.
AP_node = onnx.helper.make_node(
    op_type='AveragePool',
    inputs=['input'],
    outputs=['AP0'],
    name='AP0',
    kernel_shape=[3, 3],
)
nodes += [AP_node]

# MaxPool: 3x3 window over the average-pooled tensor 'AP0' -> 'MP0'.
MP_node = onnx.helper.make_node(
    'MaxPool', ['AP0'], ['MP0'], name='MP0', kernel_shape=[3, 3])
nodes.extend([MP_node])

# Upsample, opset-9 form: the per-axis scale factors arrive as a second
# input tensor ('scales0'), produced here by a Constant node.
# Scales [1, 1, 2, 2] leave the first two axes alone and double the last two.
scales_node0 = onnx.helper.make_node(
    op_type='Constant',
    inputs=[],
    outputs=['scales0'],
    name='scales0',
    value=onnx.helper.make_tensor(
        name='scales0',
        data_type=onnx.TensorProto.FLOAT,
        dims=[4],
        vals=[1., 1., 2., 2.],
    ),
)
US_node = onnx.helper.make_node(
    op_type='Upsample',
    inputs=['MP0', 'scales0'],
    outputs=['US0'],
    name='Upsample0',
    mode='nearest',
)
# The constant must precede its consumer in the node list.
nodes.extend([scales_node0, US_node])

# Clip, opset-9 form (bounds are attributes, not inputs): saturate the
# upsampled tensor 'US0' into [-30, 30] -> 'clip0'.
Clip_node = onnx.helper.make_node(
    'Clip',
    inputs=['US0'],
    outputs=['clip0'],
    name='Clip0',
    min=-30.,
    max=30.,
)
nodes += [Clip_node]

# Pad, opset-9 form (pads/value are attributes).
# `pads` lists all begin offsets, then all end offsets, one per axis of
# the 4-D input. Here 2 entries of value 0 are prepended along axis 1;
# every other axis is left unpadded. Output: 'pad0'.
Pad_node = onnx.helper.make_node(
    op_type='Pad',
    inputs=['clip0'],
    outputs=['pad0'],
    name='Pad0',
    value=0.,
    pads=[0, 2, 0, 0, 0, 0, 0, 0],
)
nodes.append(Pad_node)

# Cast the padded float tensor to int64 -> 'cast0'.
# 'cast0' is used below as the Scatter indices input.
Cast_node0 = onnx.helper.make_node(
    op_type='Cast',
    inputs=['pad0'],
    outputs=['cast0'],
    name='Cast0',
    to=onnx.TensorProto.INT64,
)
nodes.extend([Cast_node0])

# Scatter(data, indices, updates) -> 'scatter0'.
# The padded tensor 'pad0' supplies both data and updates; its int64 cast
# 'cast0' supplies the indices. No axis attribute is set, so axis 0 is used.
# NOTE(review): axis 0 has extent 1 throughout this graph, yet the indices
# are clipped values in [-30, 30] cast to int64. Any nonzero index would be
# out of range if the model were executed. Presumably the model is only used
# for structural/conversion tests; confirm before running it.
Scatter_node = onnx.helper.make_node(
    op_type='Scatter',
    inputs=['pad0', 'cast0', 'pad0'],
    outputs=['scatter0'],
    name='Scatter0',
)
nodes += [Scatter_node]

# ArgMax over 'scatter0' with the default axis/keepdims -> 'AM0' (int64).
ArgMax_node = onnx.helper.make_node(
    'ArgMax', ['scatter0'], ['AM0'], name='argMax0')
nodes.append(ArgMax_node)

# Cast the int64 ArgMax result back to float -> 'cast1'.
Cast_node1 = onnx.helper.make_node(
    op_type='Cast',
    inputs=['AM0'],
    outputs=['cast1'],
    name='Cast1',
    to=onnx.TensorProto.FLOAT,
)
nodes += [Cast_node1]

# Dropout with ratio 0.1 on 'cast1' -> 'Dropout0'.
# No node in this script reads 'Dropout0', so this branch is a dead end.
# Presumably it exists only so the op appears in the model; confirm.
Dropout_node = onnx.helper.make_node(
    op_type='Dropout',
    inputs=['cast1'],
    outputs=['Dropout0'],
    name='Dropout0',
    ratio=0.1,
)
nodes.extend([Dropout_node])

# DepthToSpace with block size 2, reading the upsampled tensor 'US0' -> 'DTS0'.
# No node in this script reads 'DTS0'; it is a dead-end branch.
DTS_node = onnx.helper.make_node(
    'DepthToSpace', ['US0'], ['DTS0'], name='DTS0', blocksize=2)
nodes.append(DTS_node)

# Slice, opset-9 form (starts/ends are attributes): take the window
# [0:1, 0:4, 0:30, 0:30] of the upsampled tensor 'US0' -> 'slice0'.
Slice_node = onnx.helper.make_node(
    op_type='Slice',
    inputs=['US0'],
    outputs=['slice0'],
    name='Slice0',
    starts=[0, 0, 0, 0],
    ends=[1, 4, 30, 30],
)
nodes += [Slice_node]

# TopK, opset-9 form (k is an attribute): keep the 10 largest entries of
# 'slice0' along the default axis.
# Outputs are the values 'TopK0' and their positions 'indices'. Only
# 'indices' is consumed later, by the OneHot node.
TopK_node = onnx.helper.make_node(
    op_type='TopK',
    inputs=['slice0'],
    outputs=['TopK0', 'indices'],
    name='TopK',
    k=10,
)
nodes.append(TopK_node)

# OneHot (opset 9) takes three inputs, each supplied below:
#   indices -- the TopK positions ('indices'),
#   depth   -- a one-element float tensor holding 8 ('depth0'),
#   values  -- [off_value, on_value] = [1, 3] ('values0').
# Its result is written directly to the graph output 'output'.
depth_node0 = onnx.helper.make_node(
    op_type='Constant',
    inputs=[],
    outputs=['depth0'],
    name='depth0',
    value=onnx.helper.make_tensor(
        name='depth0',
        data_type=onnx.TensorProto.FLOAT,
        dims=[1],
        vals=[8],
    ),
)
values_node0 = onnx.helper.make_node(
    op_type='Constant',
    inputs=[],
    outputs=['values0'],
    name='values0',
    value=onnx.helper.make_tensor(
        name='values0',
        data_type=onnx.TensorProto.FLOAT,
        dims=[2],
        vals=[1, 3],
    ),
)
OneHot_node = onnx.helper.make_node(
    op_type='OneHot',
    inputs=['indices', 'depth0', 'values0'],
    outputs=['output'],
    name='OneHot0',
)
# Both constants must precede their consumer in the node list.
nodes.extend([depth_node0, values_node0, OneHot_node])

# Assemble the graph from the collected nodes and boundary tensors, wrap it
# in a model that imports opset 9 of the default ("") ONNX domain, validate
# it, and write it to 'update_test.onnx'.
graph_def = onnx.helper.make_graph(
    nodes,
    'test-model',
    [input_value],
    [output_value],
)
model_def = onnx.helper.make_model(
    graph_def,
    producer_name='onnx-example',
    opset_imports=[onnx.helper.make_opsetid("", 9)],
)
# Validate before saving. An unknown attribute, an op signature that does
# not match opset 9, or a node input never produced earlier then fails
# here, instead of silently producing a broken fixture file that only
# errors when a later consumer loads it.
onnx.checker.check_model(model_def)
onnx.save(model_def, 'update_test.onnx')