2026-03-11 16:13:59 +08:00

249 lines
7.9 KiB
Python

import argparse
import ast
import csv
import glob
import os
import sys

import matplotlib.pyplot as plt
import numpy as np

import onnx_vs_onnx as onnx_tester
from tools import helper
def compare_results(results_a, results_b):
    """Compare two sets of ONNX inference results and summarize the differences.

    Both inputs are flattened with ``helper.flatten_with_depth`` (the input
    data may be of non-uniform shape) and then compared element by element.

    Args:
        results_a: results from running inference multiple times on the
            first model.
        results_b: results from running inference multiple times on the
            second model; must flatten to the same shape as ``results_a``.

    Returns:
        A list of ``[name, value]`` pairs, in this order:
        ``max_rel_diff``, ``max_abs_diff``, ``mean_rel_diff``,
        ``mean_abs_diff``, ``std_rel_diff``, ``std_abs_diff``,
        ``acc_with_diff_precision`` (``[digits, accuracy]`` for 0..7 decimal
        digits), ``rel_diff_percentiles`` and ``abs_diff_percentiles``
        (``[label, value]`` for 0%, 5%, ..., 95%).

    Raises:
        AssertionError: if the second field of the flattened items (used
            here as the shape description) differs between the two inputs.
        ValueError: if the inputs contain no values to compare.
    """
    ra_flat = helper.flatten_with_depth(results_a, 0)
    rb_flat = helper.flatten_with_depth(results_b, 0)
    shape_a = [item[1] for item in ra_flat]
    shape_b = [item[1] for item in rb_flat]
    assert shape_a == shape_b, "two results data shape doesn't match"
    ra_raw = [item[0] for item in ra_flat]
    rb_raw = [item[0] for item in rb_flat]

    total = len(ra_raw)
    if total == 0:
        # Without this guard the failure surfaces as the much less helpful
        # "max() arg is an empty sequence" further down.
        raise ValueError("no inference results to compare")

    # abs_diff is abs(ra - rb); rel_diff is abs_diff / max(abs(ra), abs(rb)),
    # defined as 0 when both values are 0.
    abs_diff = [abs(a - b) for a, b in zip(ra_raw, rb_raw)]
    rel_diff = []
    for a, b, diff in zip(ra_raw, rb_raw, abs_diff):
        divider = max(abs(a), abs(b))
        rel_diff.append(diff / divider if divider != 0 else 0)

    # Computed before the lists are sorted below.
    max_rel_diff = max(rel_diff)
    max_abs_diff = max(abs_diff)
    mean_rel_diff = np.average(rel_diff)
    mean_abs_diff = np.average(abs_diff)
    std_rel_diff = np.std(rel_diff)
    std_abs_diff = np.std(abs_diff)

    # Accuracy at a given precision: the fraction of value pairs whose
    # fixed-point renderings with that many decimal digits are identical.
    # NOTE(review): values that round to zero from opposite signs render as
    # "-0.00" and "0.00" and are therefore counted as a mismatch.
    acc_with_diff_precision = []
    for digit in range(8):
        spec = "." + str(digit) + "f"
        correct = sum(
            1
            for a, b in zip(ra_raw, rb_raw)
            if format(a, spec) == format(b, spec)
        )
        acc_with_diff_precision.append(
            [digit, float(format(correct / total, ".3f"))]
        )

    # Distribution of the differences: sample the sorted values at
    # 0%, 5%, ..., 95%.
    rel_diff.sort()
    abs_diff.sort()
    rel_diff_percentiles = []
    abs_diff_percentiles = []
    for i in range(20):
        index = int((i / 20) * total)
        label = "{}%".format(i * 5)
        rel_diff_percentiles.append([label, rel_diff[index]])
        abs_diff_percentiles.append([label, abs_diff[index]])

    return [
        ["max_rel_diff", max_rel_diff],
        ["max_abs_diff", max_abs_diff],
        ["mean_rel_diff", mean_rel_diff],
        ["mean_abs_diff", mean_abs_diff],
        ["std_rel_diff", std_rel_diff],
        ["std_abs_diff", std_abs_diff],
        ["acc_with_diff_precision", acc_with_diff_precision],
        ["rel_diff_percentiles", rel_diff_percentiles],
        ["abs_diff_percentiles", abs_diff_percentiles],
    ]
def _read_existing_rows(path):
    """Return the data rows (header dropped) of an existing stats CSV.

    Reading is best-effort: a missing file yields an empty list, and any
    other read/parse failure is reported on stderr while the rows read so
    far are still returned.
    """
    rows = []
    try:
        with open(path, "r", newline="") as old_file:
            reader = csv.reader(old_file)
            next(reader, None)  # drop the old header; the new one is used
            for row in reader:
                if row:  # ignore blank lines
                    rows.append(row)
    except FileNotFoundError:
        # First run: there is nothing to merge with.
        pass
    except (OSError, csv.Error, UnicodeError) as err:
        print(
            "warning: could not fully read {}: {}".format(path, err),
            file=sys.stderr,
        )
    return rows


def _parse_series(cell):
    """Return a list-valued table cell as a list, or None if undecodable.

    Rows computed in this run hold real lists, while rows merged from an
    existing CSV hold the text that was written for those lists.
    """
    if not isinstance(cell, str):
        return cell
    try:
        parsed = ast.literal_eval(cell)
    except (ValueError, TypeError, SyntaxError):
        return None
    return parsed if isinstance(parsed, (list, tuple)) else None


def _plot_histogram(rows, column, title, xlabel, out_path):
    """Save a 15-bin histogram of one numeric column, rounded to 2 digits."""
    values = [round(float(row[column]), 2) for row in rows]
    plt.hist(values, bins=15)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel("Counts")
    plt.savefig(out_path)
    plt.close()


def _plot_lines(rows, column, title, xlabel, ylabel, out_path, spread_x):
    """Save one line per row for a column holding ``[x, y]`` pairs.

    When ``spread_x`` is true the x values are evenly spaced fractions in
    [0, 1) instead of the first element of each pair.
    """
    for row in rows:
        series = _parse_series(row[column])
        if series is None:
            print(
                "warning: skipping {} in {}: unreadable data".format(
                    row[0], out_path
                ),
                file=sys.stderr,
            )
            continue
        if spread_x:
            x = [
                round(i * (1 / len(series)), 2) for i in range(len(series))
            ]
        else:
            x = [point[0] for point in series]
        y = [point[1] for point in series]
        plt.plot(x, y, label=row[0])
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.legend()
    plt.savefig(out_path)
    plt.close()


def _make_plots(stats_rows):
    """Save the histogram and line plots for the merged statistics rows."""
    # Line plots are limited to the first five rows to stay readable.
    sample_rows = stats_rows[:5]
    _plot_histogram(
        stats_rows,
        1,
        "Max Relative Difference Histogram",
        "Max Relative Difference",
        "max_rel_diff_hist.png",
    )
    _plot_histogram(
        stats_rows,
        2,
        "Max Absolute Difference Histogram",
        "Max Absolute Difference",
        "max_abs_diff_hist.png",
    )
    _plot_lines(
        sample_rows,
        -2,
        "Rel_diff Percentiles of Raw and Optimized Models",
        "percentage",
        "relative difference",
        "rel_diff_percentiles.png",
        True,
    )
    _plot_lines(
        sample_rows,
        -1,
        "Abs_diff Percentiles of Raw and Optimized Models",
        "percentage",
        "absolute difference",
        "abs_diff_percentiles.png",
        True,
    )
    _plot_lines(
        sample_rows,
        -3,
        "Accuracies with Different Precisions",
        "Decimals",
        "Precision",
        "precisions.png",
        False,
    )


def main():
    """Compare each model with its optimized counterpart and record the stats.

    Pairs up model files in the given directory by their file name endings,
    runs inference on each pair, merges the new statistics with any rows
    already present in the output CSV (new results win on name clashes),
    rewrites the CSV and optionally saves plots.
    """
    parser = argparse.ArgumentParser(
        description="test model optimization results"
    )
    parser.add_argument(
        "dir", type=str, help="the directory that stores onnx models"
    )
    parser.add_argument(
        "ending1", type=str, help="model file name ending(eg, .onnx)"
    )
    parser.add_argument(
        "ending2", type=str, help="opt model file name ending(eg. _opt.onnx)"
    )
    parser.add_argument("out_file", type=str, help="output csv file name")
    parser.add_argument("-p", "--plot", default="N", help="get plots (Y/N)")
    parser.add_argument(
        "-i", "--iter_times", default=10, type=int, help="inference times"
    )
    args = parser.parse_args()

    # os.path.join makes the pattern work whether or not ``dir`` ends with a
    # path separator.
    old_models_paths = set(
        glob.glob(os.path.join(args.dir, "*" + args.ending1))
    )
    new_models_paths = glob.glob(os.path.join(args.dir, "*" + args.ending2))

    header = [
        "Model",
        "max_rel_diff",
        "max_abs_diff",
        "mean_rel_diff",
        "mean_abs_diff",
        "std_rel_diff",
        "std_abs_diff",
        "acc_with_diff_precision",
        "rel_diff_percentiles",
        "abs_diff_percentiles",
    ]
    stats_rows = []
    for new_model_path in new_models_paths:
        # Explicit end index so that an empty ``ending2`` strips nothing.
        stem = new_model_path[: len(new_model_path) - len(args.ending2)]
        old_model_path = stem + args.ending1
        if old_model_path not in old_models_paths:
            continue
        # run inference
        results_a, results_b = onnx_tester.onnx_model_results(
            old_model_path, new_model_path, total_times=args.iter_times
        )
        # compare inference results
        comparison = compare_results(results_a, results_b)
        stats_rows.append(
            [os.path.basename(old_model_path)]
            + [item[1] for item in comparison]
        )

    # Merge with a possible existing stats file: rows for models that were
    # not re-measured in this run are kept.
    new_model_names = {row[0] for row in stats_rows}
    for row in _read_existing_rows(args.out_file):
        if row[0] not in new_model_names:
            stats_rows.append(row)

    # write a new stat data file, overwrite old file
    with open(args.out_file, "w", newline="") as new_file:
        writer = csv.writer(new_file)
        writer.writerow(header)
        writer.writerows(stats_rows)

    # make some plots (nothing to plot when there are no data rows)
    if args.plot == "Y" and stats_rows:
        _make_plots(stats_rows)


if __name__ == "__main__":
    main()