222 lines
7.0 KiB
Python
222 lines
7.0 KiB
Python
import onnx
|
|
import argparse
|
|
import glob
|
|
import csv
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
|
|
from tools import helper
|
|
import onnx_vs_onnx as onnx_tester
|
|
|
|
def compare_results(results_a, results_b):
|
|
""" compare onnx model inference results
|
|
calculate basic statistical values
|
|
results: results from inference multiple times
|
|
returns: list of basic statistical values
|
|
"""
|
|
# input results data can be of nonuniform shape
|
|
# get flatten data to compare
|
|
ra_flat = helper.flatten_with_depth(results_a, 0)
|
|
rb_flat = helper.flatten_with_depth(results_b, 0)
|
|
shape_a = [item[1] for item in ra_flat]
|
|
shape_b = [item[1] for item in rb_flat]
|
|
assert shape_a == shape_b, 'two results data shape doesn\'t match'
|
|
ra_raw = [item[0] for item in ra_flat]
|
|
rb_raw = [item[0] for item in rb_flat]
|
|
|
|
# the statistical values
|
|
max_rel_diff = 0 # defined to be max( { abs(diff)/max(abs(ra), abs(rb) ) } )
|
|
max_abs_diff = 0 # defined to be max( { abs(ra-rb) } )
|
|
mean_rel_diff = 0
|
|
mean_abs_diff = 0
|
|
std_rel_diff = 0
|
|
std_abs_diff = 0
|
|
acc_with_diff_precision = []
|
|
rel_diff = []
|
|
abs_diff_percentiles = [] # rel_diff percentiles
|
|
rel_diff_percentiles = [] # abs_diff precentiles
|
|
|
|
raw_diff = [ra_raw[i]-rb_raw[i] for i in range(len(ra_raw))]
|
|
abs_diff = [abs(num) for num in raw_diff]
|
|
for i in range(len(ra_raw)):
|
|
divider = max([abs(ra_raw[i]), abs(rb_raw[i])])
|
|
val = abs_diff[i]/divider if divider != 0 else 0
|
|
rel_diff.append(val)
|
|
|
|
max_rel_diff = max(rel_diff)
|
|
max_abs_diff = max(abs_diff)
|
|
mean_rel_diff = np.average(rel_diff)
|
|
mean_abs_diff = np.average(abs_diff)
|
|
std_rel_diff = np.std(rel_diff)
|
|
std_abs_diff = np.std(abs_diff)
|
|
|
|
# calculate accuracy with different precison
|
|
for digit in range(8):
|
|
correct = 0
|
|
for i in range(len(ra_raw)):
|
|
if format(ra_raw[i], '.'+str(digit)+'f')\
|
|
== format(rb_raw[i], '.'+str(digit)+'f'):
|
|
correct += 1
|
|
acc_with_diff_precision.append([digit, float(format(correct/len(ra_raw), '.3f'))])
|
|
|
|
# analyze rel_diff distribution
|
|
rel_diff.sort()
|
|
abs_diff.sort()
|
|
for i in range(20):
|
|
rel_diff_percentiles.append(['{}%'.format(i*5), rel_diff[int((i/20)*len(rel_diff))]])
|
|
abs_diff_percentiles.append(['{}%'.format(i*5), abs_diff[int((i/20)*len(abs_diff))]])
|
|
|
|
results = [
|
|
['max_rel_diff', max_rel_diff],
|
|
['max_abs_diff', max_abs_diff],
|
|
['mean_rel_diff', mean_rel_diff],
|
|
['mean_abs_diff', mean_abs_diff],
|
|
['std_rel_diff', std_rel_diff],
|
|
['std_abs_diff', std_abs_diff],
|
|
['acc_with_diff_precision', acc_with_diff_precision],
|
|
['rel_diff_percentiles', rel_diff_percentiles],
|
|
['abs_diff_percentiles', abs_diff_percentiles]
|
|
]
|
|
|
|
return results
|
|
|
|
if __name__ == '__main__':
|
|
parser = argparse.ArgumentParser(description='test model optimization results')
|
|
|
|
parser.add_argument('dir', type=str, help='the directory that stores onnx models')
|
|
parser.add_argument('ending1', type=str, help='model file name ending(eg, .onnx)')
|
|
parser.add_argument('ending2', type=str, help='opt model file name ending(eg. _opt.onnx)')
|
|
parser.add_argument('out_file', type=str, help='output csv file name')
|
|
parser.add_argument('-p', '--plot', default='N', help='get plots (Y/N)')
|
|
parser.add_argument('-i', '--iter_times', default=10, type=int, help='inference times')
|
|
|
|
args = parser.parse_args()
|
|
|
|
old_models_paths = glob.glob(args.dir+'*'+args.ending1)
|
|
new_models_paths = glob.glob(args.dir+'*'+args.ending2)
|
|
|
|
stats_table = [[
|
|
'Model',
|
|
'max_rel_diff',
|
|
'max_abs_diff',
|
|
'mean_rel_diff',
|
|
'mean_abs_diff',
|
|
'std_rel_diff',
|
|
'std_abs_diff',
|
|
'acc_with_diff_precision',
|
|
'rel_diff_percentiles',
|
|
'abs_diff_percentiles'
|
|
]]
|
|
|
|
for new_model_path in new_models_paths:
|
|
old_model_path = new_model_path[:-len(args.ending2)] + args.ending1
|
|
if old_model_path not in old_models_paths:
|
|
continue
|
|
|
|
# run inference
|
|
results_a, results_b = onnx_tester.onnx_model_results(old_model_path, new_model_path, total_times=args.iter_times)
|
|
|
|
# compare inference results
|
|
comparision = compare_results(results_a, results_b)
|
|
|
|
new_line = [old_model_path.split('/')[-1]]
|
|
for item in comparision:
|
|
new_line.append(item[1])
|
|
|
|
stats_table.append(new_line)
|
|
|
|
# try to read existing file
|
|
old_stats_table = []
|
|
try:
|
|
old_file = open(args.out_file, 'r')
|
|
reader = csv.reader(old_file)
|
|
old_header = reader.__next__()
|
|
for row in reader:
|
|
old_stats_table.append(row)
|
|
old_file.close()
|
|
except:
|
|
pass
|
|
|
|
# compare and merge possible old stat data file with new stat data file
|
|
header = stats_table[0]
|
|
stats_table = stats_table[1:]
|
|
new_model_names = set([item[0] for item in stats_table])
|
|
for row in old_stats_table:
|
|
if row[0] not in new_model_names:
|
|
stats_table.append(row)
|
|
stats_table.insert(0, header)
|
|
|
|
# write a new stat data file, overwrite old file
|
|
new_file = open(args.out_file, 'w', newline='')
|
|
writer = csv.writer(new_file)
|
|
for row in stats_table:
|
|
writer.writerow(row)
|
|
new_file.close()
|
|
|
|
# make some plots
|
|
if args.plot == 'Y':
|
|
if len(stats_table) < 2:
|
|
exit(0)
|
|
|
|
sample_table = stats_table[1:] if len(stats_table) < 6 else stats_table[1:6]
|
|
|
|
max_rel_diffs = [round(float(item[1]), 2) for item in stats_table[1:]]
|
|
plt.hist(max_rel_diffs, bins=15)
|
|
plt.title('Max Relavtive Difference Histogram')
|
|
plt.xlabel('Max Relative Difference')
|
|
plt.ylabel('Counts')
|
|
plt.savefig('max_rel_diff_hist.png')
|
|
plt.close()
|
|
|
|
max_abs_diffs = [round(float(item[2]), 2) for item in stats_table[1:]]
|
|
plt.hist(max_abs_diffs, bins=15)
|
|
plt.title('Max Absolute Difference Histogram')
|
|
plt.xlabel('Max Absolute Difference')
|
|
plt.ylabel('Counts')
|
|
plt.savefig('max_abs_diff_hist.png')
|
|
plt.close()
|
|
|
|
for line in sample_table:
|
|
model_name = line[0]
|
|
percentiles = line[-2]
|
|
x = [round(i*(1/len(percentiles)), 2) for i in range(len(percentiles))]
|
|
y = [ele[1] for ele in percentiles]
|
|
plt.plot(x, y, label=model_name)
|
|
plt.title('Rel_diff Percentiles of Raw and Optimized Models')
|
|
plt.xlabel('percentage')
|
|
plt.ylabel('relative difference')
|
|
plt.legend()
|
|
plt.savefig('rel_diff_percentiles.png')
|
|
plt.close()
|
|
|
|
for line in sample_table:
|
|
model_name = line[0]
|
|
percentiles = line[-1]
|
|
x = [round(i*(1/len(percentiles)), 2) for i in range(len(percentiles))]
|
|
y = [ele[1] for ele in percentiles]
|
|
plt.plot(x, y, label=model_name)
|
|
plt.title('Abs_diff Percentiles of Raw and Optimized Models')
|
|
plt.xlabel('percentage')
|
|
plt.ylabel('absolute difference')
|
|
plt.legend()
|
|
plt.savefig('abs_diff_percentiles.png')
|
|
plt.close()
|
|
|
|
for line in sample_table:
|
|
model_name = line[0]
|
|
accuracies = line[-3]
|
|
x = [acc[0] for acc in accuracies]
|
|
y = [acc[1] for acc in accuracies]
|
|
plt.plot(x, y, label=model_name)
|
|
plt.title('Accuracies with Different Precisions')
|
|
plt.xlabel('Decimals')
|
|
plt.ylabel('Precision')
|
|
plt.legend()
|
|
plt.savefig('precisions.png')
|
|
plt.close()
|
|
|
|
|
|
|
|
|
|
|