"""Compare inference results of ONNX models with their optimized versions.

For every ``<name><ending2>`` model found in a directory that has a matching
``<name><ending1>`` model, both are run through
``onnx_tester.onnx_model_results`` and the outputs are compared with
``compare_results``.  The statistics are merged into a CSV file (rows for
models that were not re-evaluated in this run are kept) and, optionally,
summary plots are written to the current working directory.
"""
import argparse
import ast
import csv
import glob
import os
import re
import sys

import numpy as np
import matplotlib.pyplot as plt

from tools import helper
import onnx_vs_onnx as onnx_tester

# Column names of the statistics CSV; the order matches the order of the
# pairs returned by compare_results().
STATS_HEADER = [
    "Model",
    "max_rel_diff",
    "max_abs_diff",
    "mean_rel_diff",
    "mean_abs_diff",
    "std_rel_diff",
    "std_abs_diff",
    "acc_with_diff_precision",
    "rel_diff_percentiles",
    "abs_diff_percentiles",
]

# NumPy >= 2 renders scalars nested in a list as e.g. ``np.float32(0.5)``.
# Such wrappers may appear in list-valued CSV cells (assumption: the flattened
# inference values can be NumPy scalars -- verify against tools.helper) and
# must be stripped before the cell can be parsed as a literal.
_NUMPY_SCALAR_REPR = re.compile(r"\b(?:np|numpy)\.\w+\(([^()]*)\)")


def compare_results(results_a, results_b):
    """Compare two sets of inference results and compute basic statistics.

    Both inputs are flattened with ``helper.flatten_with_depth`` so that
    results of non-uniform shape can be compared element by element.

    Args:
        results_a: inference results of the first model (possibly collected
            over several inference runs).
        results_b: inference results of the second model, with the same
            structure as ``results_a``.

    Returns:
        A list of ``[name, value]`` pairs, in this order: ``max_rel_diff``,
        ``max_abs_diff``, ``mean_rel_diff``, ``mean_abs_diff``,
        ``std_rel_diff``, ``std_abs_diff``, ``acc_with_diff_precision``
        (list of ``[decimals, accuracy]``), ``rel_diff_percentiles`` and
        ``abs_diff_percentiles`` (lists of ``["<p>%", value]`` for
        p = 0, 5, ..., 95).

    Raises:
        AssertionError: if the structure information of the two flattened
            results differs.
        ValueError: if the results contain no values to compare.
    """
    # Each flattened item is indexed as (value, structure info); the second
    # element is what the two results must agree on.
    ra_flat = helper.flatten_with_depth(results_a, 0)
    rb_flat = helper.flatten_with_depth(results_b, 0)
    shape_a = [item[1] for item in ra_flat]
    shape_b = [item[1] for item in rb_flat]
    # Raised explicitly (instead of using ``assert``) so the check survives
    # ``python -O`` while callers still see the same exception type.
    if shape_a != shape_b:
        raise AssertionError("two results data shape doesn't match")
    ra_raw = [item[0] for item in ra_flat]
    rb_raw = [item[0] for item in rb_flat]
    if not ra_raw:
        raise ValueError("no inference results to compare")

    # abs_diff is abs(ra - rb); rel_diff is abs_diff / max(abs(ra), abs(rb)),
    # defined as 0 when both values are 0.
    abs_diff = [abs(a - b) for a, b in zip(ra_raw, rb_raw)]
    rel_diff = []
    for a, b, diff in zip(ra_raw, rb_raw, abs_diff):
        divider = max(abs(a), abs(b))
        rel_diff.append(diff / divider if divider != 0 else 0)

    max_rel_diff = max(rel_diff)
    max_abs_diff = max(abs_diff)
    mean_rel_diff = np.average(rel_diff)
    mean_abs_diff = np.average(abs_diff)
    std_rel_diff = np.std(rel_diff)
    std_abs_diff = np.std(abs_diff)

    # Accuracy at a given precision: the fraction of value pairs whose
    # fixed-point renderings with that many decimals are identical.
    total = len(ra_raw)
    acc_with_diff_precision = []
    for digit in range(8):
        spec = ".{}f".format(digit)
        correct = sum(
            1
            for a, b in zip(ra_raw, rb_raw)
            if format(a, spec) == format(b, spec)
        )
        acc_with_diff_precision.append(
            [digit, float(format(correct / total, ".3f"))]
        )

    # Distribution of the differences, sampled every 5% of the sorted values.
    # Integer arithmetic keeps the index exact (no float truncation error).
    sorted_rel = sorted(rel_diff)
    sorted_abs = sorted(abs_diff)
    rel_diff_percentiles = []
    abs_diff_percentiles = []
    for i in range(20):
        label = "{}%".format(i * 5)
        index = (i * total) // 20
        rel_diff_percentiles.append([label, sorted_rel[index]])
        abs_diff_percentiles.append([label, sorted_abs[index]])

    return [
        ["max_rel_diff", max_rel_diff],
        ["max_abs_diff", max_abs_diff],
        ["mean_rel_diff", mean_rel_diff],
        ["mean_abs_diff", mean_abs_diff],
        ["std_rel_diff", std_rel_diff],
        ["std_abs_diff", std_abs_diff],
        ["acc_with_diff_precision", acc_with_diff_precision],
        ["rel_diff_percentiles", rel_diff_percentiles],
        ["abs_diff_percentiles", abs_diff_percentiles],
    ]


def _warn(message):
    """Print a warning to stderr without interrupting the run."""
    print("warning: " + message, file=sys.stderr)


def _find_model_pairs(model_dir, ending1, ending2):
    """Return sorted (old_path, new_path) pairs of models present in both forms."""
    # os.path.join makes the pattern correct whether or not model_dir ends
    # with a path separator.
    old_paths = set(glob.glob(os.path.join(model_dir, "*" + ending1)))
    pairs = []
    for new_path in sorted(glob.glob(os.path.join(model_dir, "*" + ending2))):
        # Explicit end index: ``[:-len(ending2)]`` would yield "" for an
        # empty ending.
        stem = new_path[: len(new_path) - len(ending2)]
        old_path = stem + ending1
        if old_path in old_paths:
            pairs.append((old_path, new_path))
    return pairs


def _read_existing_rows(path):
    """Return the data rows (header skipped) of an existing stats CSV, if any."""
    rows = []
    try:
        with open(path, "r", newline="") as old_file:
            reader = csv.reader(old_file)
            next(reader, None)  # skip the header; tolerate an empty file
            # Blank lines are parsed as empty rows; they carry no model name.
            rows.extend(row for row in reader if row)
    except FileNotFoundError:
        # No previous statistics file: nothing to merge.
        pass
    except (OSError, csv.Error, UnicodeDecodeError) as err:
        # Best effort: keep whatever was read, but do not hide the problem.
        _warn("could not fully read '{}': {}".format(path, err))
    return rows


def _merge_rows(new_rows, old_rows):
    """Return new_rows followed by the old rows of models not re-evaluated."""
    new_names = {row[0] for row in new_rows}
    merged = list(new_rows)
    merged.extend(row for row in old_rows if row[0] not in new_names)
    return merged


def _write_table(path, table):
    """Write the table to a CSV file, overwriting any existing file."""
    with open(path, "w", newline="") as new_file:
        csv.writer(new_file).writerows(table)


def _pairs_from_row(row, column):
    """Return the list of 2-item pairs stored in row[column], or None.

    Freshly computed rows hold real lists; rows read back from the CSV hold
    the string rendering of those lists and are parsed here.
    """
    try:
        cell = row[column]
    except IndexError:
        return None
    if isinstance(cell, str):
        try:
            cell = ast.literal_eval(_NUMPY_SCALAR_REPR.sub(r"\1", cell))
        except (ValueError, TypeError, SyntaxError):
            return None
    if not isinstance(cell, (list, tuple)):
        return None
    for pair in cell:
        if not isinstance(pair, (list, tuple)) or len(pair) != 2:
            return None
    return cell


def _plot_histogram(rows, column, title, xlabel, filename):
    """Save a histogram of the numeric values found in the given column."""
    values = []
    for row in rows:
        try:
            values.append(round(float(row[column]), 2))
        except (IndexError, TypeError, ValueError):
            _warn(
                "skipping unreadable value in column {} of '{}'".format(
                    column, row[0]
                )
            )
    plt.hist(values, bins=15)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel("Counts")
    plt.savefig(filename)
    plt.close()


def _plot_pair_series(
    rows, column, fraction_axis, title, xlabel, ylabel, filename
):
    """Save a line plot with one curve per row from a list-of-pairs column.

    With fraction_axis the x values are i / len(pairs) rounded to two
    decimals; otherwise the first item of each pair is used as x.
    """
    for row in rows:
        pairs = _pairs_from_row(row, column)
        if pairs is None:
            _warn(
                "skipping unreadable column {} of '{}'".format(column, row[0])
            )
            continue
        if fraction_axis:
            x = [round(i * (1 / len(pairs)), 2) for i in range(len(pairs))]
        else:
            x = [pair[0] for pair in pairs]
        y = [pair[1] for pair in pairs]
        plt.plot(x, y, label=row[0])
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.legend()
    plt.savefig(filename)
    plt.close()


def _make_plots(data_rows):
    """Write the summary plots (PNG files) for the given data rows."""
    # Histograms use every row; the per-model curves only the first five
    # rows, to keep the plots readable.
    sample_rows = data_rows[:5]

    _plot_histogram(
        data_rows,
        1,
        "Max Relative Difference Histogram",
        "Max Relative Difference",
        "max_rel_diff_hist.png",
    )
    _plot_histogram(
        data_rows,
        2,
        "Max Absolute Difference Histogram",
        "Max Absolute Difference",
        "max_abs_diff_hist.png",
    )
    _plot_pair_series(
        sample_rows,
        -2,
        True,
        "Rel_diff Percentiles of Raw and Optimized Models",
        "percentage",
        "relative difference",
        "rel_diff_percentiles.png",
    )
    _plot_pair_series(
        sample_rows,
        -1,
        True,
        "Abs_diff Percentiles of Raw and Optimized Models",
        "percentage",
        "absolute difference",
        "abs_diff_percentiles.png",
    )
    _plot_pair_series(
        sample_rows,
        -3,
        False,
        "Accuracies with Different Precisions",
        "Decimals",
        "Precision",
        "precisions.png",
    )


def main():
    """Run the comparison for every model pair and update the stats CSV."""
    parser = argparse.ArgumentParser(
        description="test model optimization results"
    )
    parser.add_argument(
        "dir", type=str, help="the directory that stores onnx models"
    )
    parser.add_argument(
        "ending1", type=str, help="model file name ending(eg, .onnx)"
    )
    parser.add_argument(
        "ending2", type=str, help="opt model file name ending(eg. _opt.onnx)"
    )
    parser.add_argument("out_file", type=str, help="output csv file name")
    parser.add_argument("-p", "--plot", default="N", help="get plots (Y/N)")
    parser.add_argument(
        "-i", "--iter_times", default=10, type=int, help="inference times"
    )
    args = parser.parse_args()

    new_rows = []
    model_pairs = _find_model_pairs(args.dir, args.ending1, args.ending2)
    for old_model_path, new_model_path in model_pairs:
        # run inference
        results_a, results_b = onnx_tester.onnx_model_results(
            old_model_path, new_model_path, total_times=args.iter_times
        )
        # compare inference results
        comparison = compare_results(results_a, results_b)
        new_rows.append(
            [os.path.basename(old_model_path)]
            + [item[1] for item in comparison]
        )

    # Merge with a possibly existing stats file, then overwrite it.
    data_rows = _merge_rows(new_rows, _read_existing_rows(args.out_file))
    _write_table(args.out_file, [STATS_HEADER] + data_rows)

    if args.plot == "Y":
        if not data_rows:
            return
        _make_plots(data_rows)


if __name__ == "__main__":
    main()