# (stripped non-Python paste metadata: timestamp / line-count / size banner)
#! /usr/bin/env python3
import sys
from collections import defaultdict
from pathlib import Path
from docopt import docopt
import pprint
import pydoc
import onnx
import pickle
from functools import lru_cache
import numpy as np
####################################################################################################################
def is_pproper_case(p_onnx):
    """Verify that an onnx file sits in the expected case folder layout.

    Expected layout is ``<case>/input/<case>.origin.onnx``; anything else may
    cause generated outputs to override unrelated files.

    Args:
        p_onnx: Path to the .origin.onnx file.

    Returns:
        bool: True when the layout matches, False otherwise (errors printed).
    """
    in_input_dir = p_onnx.parent.name == "input"
    named_after_case = p_onnx.name == f"{p_onnx.parent.parent.name}.origin.onnx"
    ok = in_input_dir and named_after_case
    if not ok:
        print(f"ERROR: given {p_onnx} does not sit in proper case structure.")
        print(f"Expected: {p_onnx.parent.parent.name} / input / {p_onnx.parent.parent.name}.origin.onnx")
    return ok
@lru_cache(maxsize=20)
def get_input_size(fn_onnx):
    """Return ``[(input_name, [dim, ...]), ...]`` for every graph input of an onnx model.

    A list (not a dict) is used to preserve the graph's declared input order.
    Results are cached because the same model file may be queried repeatedly.
    """
    model = onnx.load(fn_onnx)
    shapes = []
    for node in model.graph.input:
        dim_values = [d.dim_value for d in node.type.tensor_type.shape.dim]
        shapes.append((node.name, dim_values))
    for name, shape in shapes:
        print(f"{name}: {shape}")
    return shapes
def create_1_random_npy(p_npy, dim):
    """Generate one uniform-random array of shape *dim* in [-0.5, 0.5), save to *p_npy*, return it."""
    data = np.random.random(dim) - 0.5
    np.save(p_npy, data)
    return data
def create_random_npy_for_onnx(p_onnx, n=1):
    """Create *n* random input samples for every input node of an onnx model.

    Produces, next to the model (inside its ``input`` folder):
      - ``knerex_input[_i]/<name>.npy`` : raw numpy arrays, one per sample
      - ``knerex_txt[_i]/<name>.txt``   : flattened text dumps of the same data
      - ``np_in.pkl``                   : {input_name: [array, ...]} over all samples

    The first sample is always named ``test_input``; the rest are
    ``random_00001`` etc.  Refuses to overwrite previously generated outputs.

    Args:
        p_onnx: Path to ``<case>/input/<case>.origin.onnx``.
        n: Number of samples per input node; must be >= 1.

    Raises:
        ValueError: if n < 1 (previously an opaque IndexError).
        FileExistsError: if knerex_input or np_in.pkl already exist.
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")
    if not is_pproper_case(p_onnx):
        print(f"WARNING: given {p_onnx} does not sit in proper case structure.")
        return
    print(f"Create input.npy for {p_onnx.name}")
    p_input = p_onnx.parent
    p_knerex_input = p_input / "knerex_input"
    if p_knerex_input.exists():
        raise FileExistsError(f"Knerex input {p_knerex_input}* exists already! DELETE it if need to re-generate")
    p_pkl = p_input / "np_in.pkl"
    if p_pkl.exists():
        raise FileExistsError(f"Overall np_in.pkl {p_pkl} exists already! DELETE it if need to re-generate")
    names = [f"random_{i:05d}" for i in range(n)]
    names[0] = "test_input"  # first sample doubles as the canonical test vector
    # Hoisted: one (lru_cached) lookup instead of two separate calls.
    input_sizes = get_input_size(p_onnx)
    p_npy_folder = []
    p_txt_folder = []
    for i_input, (node_name, dim) in enumerate(input_sizes):
        # The first input keeps the legacy un-suffixed folder names.
        p_npy_folder.append(p_input / f"knerex_input_{i_input}" if i_input > 0 else p_input / "knerex_input")
        p_txt_folder.append(p_input / f"knerex_txt_{i_input}" if i_input > 0 else p_input / "knerex_txt")
        p_npy_folder[-1].mkdir(exist_ok=True, parents=True)
        p_txt_folder[-1].mkdir(exist_ok=True, parents=True)
    d_npy_1 = defaultdict(list)
    for name in names:
        for i_input, (node_name, dim) in enumerate(input_sizes):
            # create 1 input per input node.
            p_npy = p_npy_folder[i_input] / f"{name}.npy"
            p_txt = p_txt_folder[i_input] / f"{name}.txt"
            int_cf = create_1_random_npy(p_npy, dim)
            print(f" - Save {i_input}th input (\"{node_name}\") with dimension: {dim}. max: {int_cf.max():.4f}, min: {int_cf.min():.4f}, variance: {int_cf.var():.2f}")
            np.savetxt(p_txt, int_cf.ravel(), fmt="%.8f")
            d_npy_1[node_name].append(int_cf)
    with open(p_pkl, "wb") as f_pkl:
        pickle.dump(d_npy_1, f_pkl, protocol=pickle.HIGHEST_PROTOCOL)
    print(f"Saved overall np_in.pkl to {p_pkl}")
####################################################################################################################
def display_with_pprint_pager(data, compact=True, width=120, depth=None):
    """Pretty-print *data* through a terminal pager.

    Args:
        data: object to display.
        compact: whether to use pprint's compact layout.
        width: maximum line width.
        depth: nesting depth to display; None means unlimited.  Setting it
            to 1 keeps lists on a single line.
    """
    printer = pprint.PrettyPrinter(compact=compact, width=width, depth=depth)
    # Format first, then hand the whole string to the pager for scrolling.
    pydoc.pager(printer.pformat(data))
def parse_bie(path_to_bie, platform):
    """Parse a .bie file and page its io information to the terminal.

    Args:
        path_to_bie: path to the .bie file (str or Path).
        platform: target platform, forwarded to get_ioinfo_from_bie.

    Raises:
        FileNotFoundError: if path_to_bie does not point to an existing file.
    """
    fn_bie = Path(path_to_bie)
    if not fn_bie.exists():
        # Fixed grammar in the user-facing message ("does not exists" -> "does not exist").
        raise FileNotFoundError(f"{fn_bie} does not exist!")
    # Imported lazily: sys_flow_v2 is only needed for this subcommand.
    from sys_flow_v2.flow_utils import get_ioinfo_from_bie
    dp_in, dp_out, dp_out_shape, ioinfo = get_ioinfo_from_bie(fn_bie, platform)
    rslt = {
        "dp_in": dp_in,
        "dp_out": dp_out,
        "dp_out shape": dp_out_shape,
        "ioinfo": ioinfo,
    }
    display_with_pprint_pager(rslt, depth=None)
def parse_nef(p_nef, hw_mode):
    """CLI wrapper for unpack_nefs: unpack *p_nef* and page the resulting file map."""
    # Imported lazily so the rest of the CLI works without sys_flow_v2.
    from sys_flow_v2.compiler_v2 import unpack_nefs as unpack
    fn_maps, p_out = unpack(p_nef, hw_mode)
    summary = {key: pair[1] for key, pair in fn_maps.items()}
    display_with_pprint_pager(summary)
    print(f"Remember to clean up {p_out}")
def get_lib_version(lib_name):
    """Return (version, path) for an installed library.

    Instead of raising, returns a human-readable marker pair when the library
    is missing, or when it defines no __version__ / __path__ (true of many
    stdlib modules) — the original only caught ImportError and crashed with
    AttributeError on such modules.
    """
    try:
        lib = __import__(lib_name)
        return lib.__version__, lib.__path__
    except ImportError:
        return f"{lib_name} not installed", "!!!!!!!!!!!!!!!!!!"
    except AttributeError:
        # Importable, but exposes no version/path metadata.
        return f"{lib_name} has no version info", str(getattr(lib, "__file__", "?"))
def check_libs():
    """Print the Python version plus version/path of each library we care about."""
    print("Python version:", sys.version.split("\n")[0])
    libs_of_interest = ("onnx", "numpy", "pandas", "matplotlib", "sys_flow", "sys_flow_v2")
    for lib_name in libs_of_interest:
        print(f"checking \"{lib_name}\"")
        version, path = get_lib_version(lib_name)
        print(f" version: {version}")
        print(f" path: {path}")
def convert_all_npy_to_txt(p_base, dry_run=False, replace=False):
    """Recursively convert every .npy under *p_base* into a flattened .txt twin.

    Args:
        p_base: root folder to scan (str or Path).
        dry_run: only list the .npy files; write and delete nothing.
        replace: delete each source .npy once its .txt exists.

    Raises:
        FileNotFoundError: if p_base does not exist.
    """
    p_base = Path(p_base)
    if not p_base.exists():
        raise FileNotFoundError(f"{p_base} does not exist!!!")
    # Materialize first so deletions below cannot disturb the traversal.
    all_npy = list(p_base.rglob("*.npy"))
    for p_npy in all_npy:
        print(p_npy)
        if dry_run:
            continue
        p_txt = p_npy.with_suffix(".txt")
        if not p_txt.exists():
            np.savetxt(p_txt, np.load(p_npy).ravel(), fmt="%.6f")
        if replace:
            print(f"delete: {p_npy}")
            p_npy.unlink()
def extract_nonqat(p_txt):
    """Collect the lines of a text file that contain 'ptq' or 'constraint'.

    Args:
        p_txt (str): path of the text file to read.

    Returns:
        list: the stripped matching lines, in file order.
    """
    matches = []
    with open(p_txt, 'r') as fh:
        for raw in fh:
            if 'ptq' in raw or 'constraint' in raw:
                matches.append(raw.strip())
    return matches
def collect_nonqat(p_base, p_text, platform=730):
    """Gather non-QAT (ptq/constraint) radix_status lines across all models.

    Recursively scans *p_base* for ``knerex_<platform>/radix_status`` files,
    extracts the relevant lines from each, and writes them grouped by model
    name into *p_text*.

    Args:
        p_base: root folder containing the model cases (str or Path; the
            original required a Path and crashed on plain strings).
        p_text: output text file path.
        platform: platform number used in the knerex folder name.
    """
    p_base = Path(p_base)  # robustness: accept plain strings too
    radix_st_fns = list(p_base.rglob(f"knerex_{platform}/radix_status"))
    print(f"Found {len(radix_st_fns)} radix_status files.")
    # Model name is taken 3 levels above radix_status — TODO confirm this
    # matches the actual case layout on disk.
    info = {fn.parent.parent.parent.name: extract_nonqat(fn) for fn in radix_st_fns}
    with open(p_text, 'w') as file:
        for model_name, lines in info.items():
            file.write(f"# {model_name}\n")
            for line in lines:
                file.write(line + '\n')
            file.write("\n\n")
            print(f"saved {len(lines)} lines for {model_name} to {p_text}")
def main():
    """CLI interface for kneron tools.

    Usage:
        kneron_cli run <path_to_json>
        kneron_cli parse_bie <platform> <path_to_bie>
        kneron_cli parse_nef <platform> <path_to_nef>
        kneron_cli check_libs
        kneron_cli npy2txt [--dry] [--replace] <path_with_npy>
        kneron_cli check_qat [--platform=<platform>] <p_base> <p_result>
        kneron_cli create_random_input_for_onnx <p_onnx> [--n=<n>]
        kneron_cli (-h | --help)
        kneron_cli --version

    Options:
        -d, --dry              Dry run. Just print without actually executing.
        --replace              Delete .npy after generate .txt files.
        --platform=<platform>  Platform setting [default: 730].
        -h --help              Show this screen.
        --version              Show version.
    """
    # NOTE: docopt parses the docstring above, so the Usage/Options text is
    # part of the program's behavior — do not edit it casually.
    args = docopt(main.__doc__, version="kneron_cli 0.3")
    if args["parse_bie"]:
        parse_bie(args["<path_to_bie>"], args["<platform>"])
    elif args["parse_nef"]:
        parse_nef(args["<path_to_nef>"], args["<platform>"])
    elif args["run"]:
        raise NotImplementedError("subcommand `run` is not ready yet.")
    elif args["check_libs"]:
        check_libs()
    elif args["npy2txt"]:
        dry = args["--dry"]
        # --replace is neutralized during a dry run so nothing gets deleted.
        convert_all_npy_to_txt(args["<path_with_npy>"], dry_run=dry, replace=args["--replace"] and not dry)
    elif args["check_qat"]:
        platform = int(args['--platform'])
        assert platform in [730], f"only support platform 730, but got {platform}"
        collect_nonqat(Path(args['<p_base>']), Path(args['<p_result>']), platform)
    elif args["create_random_input_for_onnx"]:
        sample_count = int(args["--n"] or 1)
        create_random_npy_for_onnx(Path(args["<p_onnx>"]), n=sample_count)


if __name__ == "__main__":
    main()