#! /usr/bin/env python3
|
|
|
|
|
|
import sys
|
|
from collections import defaultdict
|
|
from pathlib import Path
|
|
from docopt import docopt
|
|
import pprint
|
|
import pydoc
|
|
import onnx
|
|
import pickle
|
|
from functools import lru_cache
|
|
|
|
import numpy as np
|
|
|
|
####################################################################################################################
|
|
def is_pproper_case(p_onnx):
    """Verify that *p_onnx* sits in the expected case folder layout.

    Expected layout: <case>/input/<case>.origin.onnx.  A file outside this
    structure could cause generated artifacts to be overwritten, so callers
    check this before producing any output.

    Args:
        p_onnx: Path to the candidate .onnx file.

    Returns:
        bool: True when the layout matches; otherwise False (after printing
        an explanation of the expected structure).
    """
    case_name = p_onnx.parent.parent.name
    in_input_dir = p_onnx.parent.name == "input"
    named_after_case = p_onnx.name == f"{case_name}.origin.onnx"
    if in_input_dir and named_after_case:
        return True
    print(f"ERROR: given {p_onnx} does not sit in proper case structure.")
    print(f"Expected: {case_name} / input / {case_name}.origin.onnx")
    return False
@lru_cache(maxsize=20)
def get_input_size(fn_onnx):
    """Load the shape of every graph input of an ONNX model.

    A list (not a dict) is returned because graph-input order must be
    preserved; results are memoized since callers may query the same
    model repeatedly.

    Args:
        fn_onnx: path of the .onnx file (must be hashable for the cache).

    Returns:
        list[tuple[str, list[int]]]: (input_name, [dim, ...]) per input node.
    """
    model = onnx.load(fn_onnx)
    dims = []
    for node in model.graph.input:
        shape = [axis.dim_value for axis in node.type.tensor_type.shape.dim]
        dims.append((node.name, shape))
    for node_name, dim in dims:
        print(f"{node_name}: {dim}")
    return dims
def create_1_random_npy(p_npy, dim):
    """Generate one random array in [-0.5, 0.5) and save it to *p_npy*.

    Args:
        p_npy: destination path for the .npy file.
        dim: shape of the generated array.

    Returns:
        numpy.ndarray: the generated array (also persisted to disk).
    """
    samples = np.random.random(dim)
    centered = samples - 0.5
    np.save(p_npy, centered)
    return centered
def create_random_npy_for_onnx(p_onnx, n=1):
    """Create random input files (.npy/.txt) for every input node of an ONNX model.

    Args:
        p_onnx: Path to <case>/input/<case>.origin.onnx (layout is validated
            via is_pproper_case; nothing is created when it fails).
        n: number of random samples to generate per input node.

    Raises:
        FileExistsError: if outputs from a previous generation are present.

    Side effects: creates knerex_input*/knerex_txt* folders next to the onnx,
    one .npy and one flattened .txt per sample, plus an overall np_in.pkl
    mapping input-node name -> list of generated arrays.
    """
    if not is_pproper_case(p_onnx):
        print(f"WARNING: given {p_onnx} does not sit in proper case structure.")
        return

    print(f"Create input.npy for {p_onnx.name}")
    p_input = p_onnx.parent
    # Refuse to clobber earlier results: the user must delete them explicitly.
    p_knerex_input = p_input / "knerex_input"
    if p_knerex_input.exists():
        raise FileExistsError(f"Knerex input {p_knerex_input}* exists already! DELETE it if need to re-generate")
    p_pkl = p_input / "np_in.pkl"
    if p_pkl.exists():
        raise FileExistsError(f"Overall np_in.pkl {p_pkl} exists already! DELETE it if need to re-generate")

    # First sample is named "test_input"; the rest stay random_00001, ...
    names = [f"random_{i:05d}" for i in range(n)]
    names[0] = "test_input"

    # One npy/txt folder pair per input node; node 0 keeps the un-suffixed
    # legacy folder names.
    p_npy_folder = []
    p_txt_folder = []
    for i_input, (node_name, dim) in enumerate(get_input_size(p_onnx)):
        p_npy_folder.append(p_input / f"knerex_input_{i_input}" if i_input > 0 else p_input / "knerex_input")
        p_txt_folder.append(p_input / f"knerex_txt_{i_input}" if i_input > 0 else p_input / "knerex_txt")
        p_npy_folder[-1].mkdir(exist_ok=True, parents=True)
        p_txt_folder[-1].mkdir(exist_ok=True, parents=True)

    d_npy_1 = defaultdict(list)
    for name in names:
        for i_input, (node_name, dim) in enumerate(get_input_size(p_onnx)):
            # create 1 input per input node.
            p_npy = p_npy_folder[i_input] / f"{name}.npy"
            p_txt = p_txt_folder[i_input] / f"{name}.txt"

            int_cf = create_1_random_npy(p_npy, dim)
            print(f" - Save {i_input}th input (\"{node_name}\") with dimension: {dim}. max: {int_cf.max():.4f}, min: {int_cf.min():.4f}, variance: {int_cf.var():.2f}")
            # .txt holds the same data flattened, for tools that read text.
            np.savetxt(p_txt, int_cf.ravel(), fmt="%.8f")
            d_npy_1[node_name].append(int_cf)

    # Persist all generated arrays keyed by input-node name for later replay.
    with open(p_pkl, "wb") as f_pkl:
        pickle.dump(d_npy_1, f_pkl, protocol=pickle.HIGHEST_PROTOCOL)
    print(f"Saved overall np_in.pkl to {p_pkl}")
####################################################################################################################
|
|
|
|
|
|
def display_with_pprint_pager(data, compact=True, width=120, depth=None):
    """Pretty-print *data* and show the result through a pager.

    Args:
        data: object to display.
        compact: pack short sequences onto fewer lines when True.
        width: maximum width of each output line.
        depth: nesting depth to display; None means unlimited.  Setting it
            to 1 keeps lists on a single line.
    """
    # Format first, then hand the finished text to pydoc's pager.
    formatted = pprint.pformat(data, compact=compact, width=width, depth=depth)
    pydoc.pager(formatted)
def parse_bie(path_to_bie, platform):
    """Parse a .bie file and page-display its IO information.

    Args:
        path_to_bie: path of the .bie file to inspect.
        platform: platform identifier forwarded to get_ioinfo_from_bie.

    Raises:
        FileNotFoundError: if *path_to_bie* does not exist.
    """
    fn_bie = Path(path_to_bie)
    if not fn_bie.exists():
        # Fixed message grammar: "does not exists" -> "does not exist".
        raise FileNotFoundError(f"{fn_bie} does not exist!")
    # Imported lazily: sys_flow_v2 is only needed for this subcommand.
    from sys_flow_v2.flow_utils import get_ioinfo_from_bie
    dp_in, dp_out, dp_out_shape, ioinfo = get_ioinfo_from_bie(fn_bie, platform)

    rslt = {
        "dp_in": dp_in,
        "dp_out": dp_out,
        "dp_out shape": dp_out_shape,
        "ioinfo": ioinfo,
    }
    display_with_pprint_pager(rslt, depth=None)
def parse_nef(p_nef, hw_mode):
    """Unpack a .nef file and page-display the contained file map.

    Thin CLI wrapper around sys_flow_v2's unpack_nefs; reminds the user
    to remove the unpack output folder afterwards.

    Args:
        p_nef: path of the .nef file.
        hw_mode: hardware mode forwarded to unpack_nefs.
    """
    from sys_flow_v2.compiler_v2 import unpack_nefs
    fn_maps, p_out = unpack_nefs(p_nef, hw_mode)
    # Show only the second element of each map entry, keyed as before.
    summary = {key: entry[1] for key, entry in fn_maps.items()}
    display_with_pprint_pager(summary)
    print(f"Remember to clean up {p_out}")
def get_lib_version(lib_name):
    """Return (version, path) for an installed library.

    Args:
        lib_name: importable module name.

    Returns:
        tuple: (version, path).  Placeholder strings are returned when the
        library is missing or does not expose the queried attribute.
    """
    try:
        lib = __import__(lib_name)
    except ImportError:
        return f"⚠ {lib_name} not installed", "!!!!!!!!!!!!!!!!!!"
    # Not every module defines __version__ or __path__ (e.g. C builtins such
    # as math); fall back to placeholders instead of raising AttributeError.
    version = getattr(lib, "__version__", "unknown")
    path = getattr(lib, "__path__", "unknown")
    return version, path
def check_libs():
    """Print the Python version plus version/path of each library of interest."""
    print("Python version:", sys.version.split("\n")[0])

    libs_of_interest = ("onnx", "numpy", "pandas", "matplotlib", "sys_flow", "sys_flow_v2")
    for lib_name in libs_of_interest:
        print(f"checking \"{lib_name}\"")
        version, path = get_lib_version(lib_name)
        print(f" version: {version}")
        print(f" path: {path}")
def convert_all_npy_to_txt(p_base, dry_run=False, replace=False):
    """Convert every .npy under *p_base* (recursively) to a flattened .txt.

    Args:
        p_base: root folder to search.
        dry_run: when True, only list the .npy files; convert nothing.
        replace: when True, delete each .npy once its .txt exists.

    Raises:
        FileNotFoundError: if *p_base* does not exist.
    """
    root = Path(p_base)
    if not root.exists():
        raise FileNotFoundError(f"{root} does not exist!!!")

    for npy_path in list(root.rglob("*.npy")):
        print(npy_path)
        if dry_run:
            continue
        txt_path = npy_path.with_suffix(".txt")
        # Skip conversion when a .txt is already there.
        if not txt_path.exists():
            np.savetxt(txt_path, np.load(npy_path).ravel(), fmt="%.6f")
        if replace:
            print(f"delete: {npy_path}")
            npy_path.unlink()
def extract_nonqat(p_txt):
    """Return the stripped lines of *p_txt* mentioning 'ptq' or 'constraint'.

    Args:
        p_txt (str): path of the text file to read.

    Returns:
        list[str]: matching lines, whitespace-stripped, in file order.
    """
    matches = []
    with open(p_txt, 'r') as fh:
        for raw_line in fh:
            if 'ptq' in raw_line or 'constraint' in raw_line:
                matches.append(raw_line.strip())
    return matches
def collect_nonqat(p_base, p_text, platform=730):
    """Collect non-QAT ('ptq'/'constraint') radix_status lines across a tree.

    Args:
        p_base: Path searched recursively for knerex_<platform>/radix_status.
        p_text: output text file; written with one '# <model>' section per model.
        platform: platform number used in the knerex folder name.
    """
    radix_st_fns = list(p_base.rglob(f"knerex_{platform}/radix_status"))
    print(f"Found {len(radix_st_fns)} radix_status files.")

    # Model name is taken three levels up from radix_status — assumes the
    # <model>/<...>/knerex_<platform>/radix_status layout; confirm if the
    # tree structure changes.
    info = {fn.parent.parent.parent.name: extract_nonqat(fn) for fn in radix_st_fns}
    with open(p_text, 'w') as file:
        for model_name, lines in info.items():
            file.write(f"# {model_name}\n")
            for line in lines:
                file.write(line + '\n')
            file.write("\n\n")
            # Report per model inside the loop; previously this ran after the
            # loop, so it summarized only the last model and raised NameError
            # when no radix_status files were found.
            print(f"saved {len(lines)} lines for {model_name} to {p_text}")
def main():
    """CLI interface for kneron tools.

    Usage:
        kneron_cli run <path_to_json>
        kneron_cli parse_bie <platform> <path_to_bie>
        kneron_cli parse_nef <platform> <path_to_nef>
        kneron_cli check_libs
        kneron_cli npy2txt [--dry] [--replace] <path_with_npy>
        kneron_cli check_qat [--platform=<platform>] <p_base> <p_result>
        kneron_cli create_random_input_for_onnx <p_onnx> [--n=<n>]
        kneron_cli (-h | --help)
        kneron_cli --version

    Options:
        -d, --dry              Dry run. Just print without actually executing.
        --replace              Delete .npy after generate .txt files.
        --platform=<platform>  Platform setting [default: 730].
        -h --help              Show this screen.
        --version              Show version.
    """
    # NOTE: docopt parses this docstring at runtime to build the CLI —
    # edit the Usage/Options sections with care.
    arguments = docopt(main.__doc__, version="kneron_cli 0.3")
    # print(arguments)

    # Dispatch on the subcommand flags docopt sets to True.
    if arguments["parse_bie"]:
        parse_bie(arguments["<path_to_bie>"], arguments["<platform>"])

    elif arguments["parse_nef"]:
        parse_nef(arguments["<path_to_nef>"], arguments["<platform>"])

    elif arguments["run"]:
        raise NotImplementedError("subcommand `run` is not ready yet.")

    elif arguments["check_libs"]:
        check_libs()

    elif arguments["npy2txt"]:
        dry_run = arguments["--dry"]
        # --replace is ignored during a dry run so nothing gets deleted.
        replace = arguments["--replace"] and not dry_run
        p_base = arguments["<path_with_npy>"]
        convert_all_npy_to_txt(p_base, dry_run=dry_run, replace=replace)

    elif arguments["check_qat"]:
        p_base = Path(arguments['<p_base>'])
        p_result = Path(arguments['<p_result>'])
        platform = int(arguments['--platform'])
        assert platform in [730], f"only support platform 730, but got {platform}"

        collect_nonqat(p_base, p_result, platform)

    elif arguments["create_random_input_for_onnx"]:
        p_onnx = Path(arguments["<p_onnx>"])
        # Default to a single random sample when --n is omitted.
        n = int(arguments["--n"] or 1)
        create_random_npy_for_onnx(p_onnx, n=n)
# Script entry point.
if __name__ == "__main__":
    main()