"""Build and save a small ONNX model that flattens the trailing dims of its input.

The graph computes

    reshape0 = Reshape(input, Concat(shape[0:2], shape[2] * shape[3]))

where ``shape = Shape(input)``. With the static input shape below this turns a
``(1, 3, 32, 32)`` FLOAT tensor into ``(1, 3, 1024)``. The target shape is
computed inside the graph from Shape/Slice/Mul/Concat rather than stored as a
constant. The model (opset 12) is validated and written to ``test.onnx``
when this script runs.
"""
import onnx
import onnx.checker
import onnx.helper

# Static input shape. The declared output shape is derived from it (first two
# dims kept, last two multiplied) so the graph's declared output cannot drift
# out of sync with what the Reshape actually produces.
INPUT_SHAPE = (1, 3, 32, 32)
OUTPUT_SHAPE = (INPUT_SHAPE[0], INPUT_SHAPE[1], INPUT_SHAPE[2] * INPUT_SHAPE[3])


def _make_int64_constant(name, value):
    """Return a Constant node whose output ``name`` is a 1-element INT64 tensor.

    The node, its output, and the embedded tensor all share ``name``.
    """
    return onnx.helper.make_node(
        'Constant',
        [],
        [name],
        name=name,
        value=onnx.helper.make_tensor(
            name=name,
            data_type=onnx.TensorProto.INT64,
            dims=[1],
            vals=[value],
        ),
    )


def _make_shape_slice(suffix, start, end, data='shape0'):
    """Return ``(start_node, end_node, slice_node)`` computing ``data[start:end]``.

    Tensors/nodes are named ``start<suffix>``, ``end<suffix>`` and
    ``slice<suffix>``. Since opset 10, Slice takes its starts/ends as graph
    inputs rather than attributes, which is why the two Constant nodes are needed.
    """
    start_name = 'start' + suffix
    end_name = 'end' + suffix
    slice_name = 'slice' + suffix
    start_node = _make_int64_constant(start_name, start)
    end_node = _make_int64_constant(end_name, end)
    slice_node = onnx.helper.make_node(
        'Slice', [data, start_name, end_name], [slice_name], name=slice_name
    )
    return start_node, end_node, slice_node


# Make inputs and outputs.
input_value = onnx.helper.make_tensor_value_info(
    'input', onnx.TensorProto.FLOAT, INPUT_SHAPE
)
output_value = onnx.helper.make_tensor_value_info(
    'reshape0', onnx.TensorProto.FLOAT, OUTPUT_SHAPE
)
nodes = []

# Make a Shape node: shape0 = shape(input), a 1-D INT64 tensor.
shape_node = onnx.helper.make_node(
    'Shape', ['input'], ['shape0'], name='shape0'
)
nodes.append(shape_node)

# Make Slice nodes: slice0 = shape0[0:2], slice1 = shape0[2:3],
# slice2 = shape0[3:4].
start_node0, end_node0, slice_node0 = _make_shape_slice('0', 0, 2)
nodes.extend([start_node0, end_node0, slice_node0])

start_node1, end_node1, slice_node1 = _make_shape_slice('1', 2, 3)
nodes.extend([start_node1, end_node1, slice_node1])

start_node2, end_node2, slice_node2 = _make_shape_slice('2', 3, 4)
nodes.extend([start_node2, end_node2, slice_node2])

# Make a Mul node: mul0 = slice1 * slice2 (the flattened trailing dimension).
mul_node = onnx.helper.make_node(
    'Mul', ['slice1', 'slice2'], ['mul0'], name='mul0'
)
nodes.append(mul_node)

# Make a Concat node: concat0 = [slice0..., mul0] is the target shape.
concat_node = onnx.helper.make_node(
    'Concat', ['slice0', 'mul0'], ['concat0'], name='concat0', axis=0
)
nodes.append(concat_node)

# Make a Reshape node that applies the computed target shape to the input.
reshape_node = onnx.helper.make_node(
    'Reshape', ['input', 'concat0'], ['reshape0'], name='reshape0'
)
nodes.append(reshape_node)

# Make model.
graph_def = onnx.helper.make_graph(
    nodes, 'test-model', [input_value], [output_value]
)
model_def = onnx.helper.make_model(
    graph_def,
    producer_name='onnx-example',
    opset_imports=[onnx.helper.make_opsetid("", 12)],
)

# Validate before writing so a malformed graph fails here instead of being
# saved and only rejected later by a runtime.
onnx.checker.check_model(model_def)
onnx.save(model_def, 'test.onnx')