16 lines
552 B
Python
16 lines
552 B
Python
import onnx
|
|
from typing import Sequence, Mapping
|
|
|
|
def get_shape_from_value_info(value, *, symbolic: bool = False) -> Sequence:
    """Get the shape of a tensor from a ValueInfoProto.

    :param value: the value_info proto (e.g. an entry of ``graph.input``)
    :param symbolic: if False (default), every dimension is reported as its
        ``dim_value``. Note that ``dim_value`` is 0 for dimensions that are
        symbolic (``dim_param``) or unknown. If True, dimensions that carry a
        non-empty ``dim_param`` are reported by that name (a ``str``) instead.
    :return: list of the shape, one entry per dimension
    """
    dims = value.type.tensor_type.shape.dim
    if not symbolic:
        return [d.dim_value for d in dims]
    # An unset protobuf string field reads as "", so a plain truthiness test
    # distinguishes named (symbolic) dimensions from concrete ones.
    return [d.dim_param if d.dim_param else d.dim_value for d in dims]
|
|
|
|
def extract_input_from_onnx(onnx_path: str) -> Mapping[str, Sequence]:
    """Map each graph input name of an ONNX model to its shape.

    :param onnx_path: path to the ``.onnx`` model file
    :return: dict of input name -> shape list (as produced by
        ``get_shape_from_value_info``)
    :raises FileNotFoundError: if ``onnx_path`` is not an existing file
    Errors raised by ``onnx.load`` while parsing the file propagate unchanged.

    NOTE(review): in some models (older IR versions, or exporters that keep
    initializers as inputs) ``graph.input`` also lists weight initializers;
    they are included here as-is.
    """
    import os

    # Fail early with a clear message instead of an error from deep in onnx.
    if not os.path.isfile(onnx_path):
        raise FileNotFoundError(f"ONNX model not found: {onnx_path!r}")

    # Only graph metadata is needed, so skip reading external weight data.
    model = onnx.load(onnx_path, load_external_data=False)
    return {v.name: get_shape_from_value_info(v) for v in model.graph.input}
|