2026-01-28 06:16:04 +00:00

160 lines
6.0 KiB
Python

#!/usr/bin/env python3
import sys
import os
import shutil
import subprocess
import multiprocessing
import argparse
import re
import logging
import time
from utils import log
from utils import util
def find_model_path(case_dir, platform):
    """Return the path of the first scaled onnx model for *platform* found
    under *case_dir*, or None when no matching file exists.

    A match is any file name containing 'kdp<platform>' and ending in
    'scaled.onnx' (e.g. 'foo_kdp720_w8a8_scaled.onnx').
    """
    # Raw string: '\.' in a plain string literal is an invalid escape and
    # raises a DeprecationWarning/SyntaxWarning on modern Python.
    regex = re.compile(r'.*kdp{}.*scaled\.onnx$'.format(platform))
    for root, _dirs, files in os.walk(case_dir):
        for file_name in files:
            if regex.match(file_name):
                return os.path.join(root, file_name)
    return None
def find_compiler_config_path(case_dir, platform):
    """Return the path of the first compiler config json for *platform*
    found under *case_dir*, or None when no matching file exists.

    A match is any file name ending in 'config_<platform>.json'.
    """
    # Raw string: '\.' in a plain string literal is an invalid escape and
    # raises a DeprecationWarning/SyntaxWarning on modern Python.
    regex = re.compile(r'.*config_{}\.json$'.format(platform))
    for root, _dirs, files in os.walk(case_dir):
        for file_name in files:
            if regex.match(file_name):
                # was misleadingly named 'model_path' — it is a config path
                return os.path.join(root, file_name)
    return None
def arg_parse():
    """Build the CLI parser and return the parsed arguments namespace."""
    desc = ''.join([
        'Requirement:\n - Run "gen_config.py" to generate config.json before using this script.\n\n',
        'Environment variables:\n - COMPILER_BIN_DIR: specify compiler bin dir\n\n',
        'Sub-modules:\n - image_cut_search: 720|530|730|630|540\n',
    ])
    arg_parser = argparse.ArgumentParser(
        description=desc,
        formatter_class=argparse.RawTextHelpFormatter)
    arg_parser.add_argument("platform", help="HW platform <520|720|530|730|540>")
    arg_parser.add_argument("src_model", help="top directory includes models")
    arg_parser.add_argument("output_dir", help="log output directory.")
    arg_parser.add_argument("-d", "--debug", action="store_true", default=False,
                            help="debug mode. Show more debug message")
    return arg_parser.parse_args()
def run_opt_compile(script, case_name, platform, model, config, log_dir, debug):
    """Compile one test case by running the opt_compile script as a subprocess.

    Stores 'SUCCESS' or 'FAILED [compilation failed]' in the shared
    ``result`` dict and the elapsed time in ``result_tm`` (both
    module-level multiprocessing.Manager dicts filled by pool workers),
    and forwards the child's stdout/stderr through the module-level ``log``.

    Args:
        script: path of the opt_compile.py wrapper to execute.
        case_name: test-case name, used as the key in the result dicts.
        platform: HW platform id string (e.g. '720').
        model: path of the onnx model file.
        config: path of the compiler config json.
        log_dir: output directory passed to the wrapper.
        debug: '-d' to enable the wrapper's debug mode, or '' for none.
    """
    # Build argv as a list so that paths containing spaces stay intact;
    # the old "format a string then .split()" round-trip broke such paths.
    cmd_args = [script, platform, model, config, log_dir]
    if debug:
        cmd_args.append(debug)
    command = ' '.join(cmd_args)  # human-readable form, for logging only
    start_tm = time.time()
    o = subprocess.run(cmd_args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    end_tm = time.time()
    log.logger.info('================={}=================='.format(case_name))
    log.logger.info('run model : {} ...'.format(case_name))
    log.logger.info('opt_wrapper : {}'.format(command))
    if o.returncode != 0:
        result[case_name] = 'FAILED [compilation failed]'
    else:
        result[case_name] = 'SUCCESS'
    tm = (end_tm - start_tm) / 60  # minutes
    result_tm[case_name] = "{:.2f} minutes".format(tm)
    if o.stdout:
        log.logger.info('\n\n' + o.stdout.decode())
    if o.stderr:
        log.logger.error('\n\n' + o.stderr.decode())
if __name__ == "__main__":
    # get opt wrapper directory (the directory containing this script)
    opt_dir = os.path.dirname(os.path.abspath(__file__))
    # get arguments
    args = arg_parse()
    platform = args.platform
    src_model = args.src_model
    output_dir = args.output_dir
    # '-d' is forwarded verbatim as a CLI flag to each opt_compile.py child
    debug = '-d' if args.debug else ''
    log_level = logging.DEBUG if args.debug else logging.INFO
    opt_compile = '{}/opt_compile.py'.format(opt_dir)
    if output_dir and not util.is_file_name_valid(output_dir):
        print('Please give a valid -o/OUTPUT_DIR. "{}" is invalid'.format(output_dir))
        sys.exit(-1)
    # set compiler binary path to environment
    compiler_bin = util.get_compiler_bin(opt_dir)
    os.environ['COMPILER_BIN'] = compiler_bin
    # create logger
    log_path = output_dir
    # NOTE: this rebinds the imported module name 'log' to a Log instance;
    # all later 'log.logger' accesses go through this object, not the module.
    log = log.Log(log_path, False)
    log.logger.setLevel(log_level)
    # run opt_compile: iterate all test cases under top case directory
    # Manager dicts are shared with the pool workers, which fill them in.
    result = multiprocessing.Manager().dict()
    result_tm = multiprocessing.Manager().dict()
    cases = util.get_case_name(src_model)
    pool = multiprocessing.Pool()
    for case_name in cases:
        # step 1: check if test case exist
        case_path = os.path.join(src_model, case_name)
        if not util.check_file_exist(case_path):
            log.logger.error('case directory for [{}] not exist'.format(case_name))
            continue
        # step 2: check if model (onnx) exist
        model_path = find_model_path(case_path, platform)
        if not util.check_file_exist(model_path):
            result[case_name] = 'FAILED [no onnx model]'
            log.logger.error('model for [{}] not exist'.format(case_name))
            continue
        # step 3: check if compiler config exist
        config_path = find_compiler_config_path(case_path, platform)
        if not util.check_file_exist(config_path):
            result[case_name] = 'FAILED [no compiler config]'
            log.logger.error('compiler config for [{}] not exist'.format(case_name))
            continue
        # step 4: run opt_compiler (asynchronously on the worker pool)
        log_dir = '{}/{}'.format(output_dir, case_name)
        pool.apply_async(run_opt_compile,
                         args=(opt_compile, case_name, platform, model_path, config_path, log_dir, debug))
    # wait for all queued compilations to finish before summarizing
    pool.close()
    pool.join()
    # show compile result
    log.logger.info('=== Summary ===')
    success_cnt = 0
    failure_cnt = 0
    for model, status in result.items():
        # any status not containing 'SUCCESS' counts as a failure
        if (status.find('SUCCESS') != -1):
            success_cnt += 1
        else:
            failure_cnt += 1
        log.logger.info('%-60s: %-30s %-12s' % (model, status, result_tm.get(model, 'N/A')))
    log.logger.info('==================================================')
    log.logger.info('=> [{}] cases generated, [{}] cases failed'.format(success_cnt, failure_cnt))
    # show configs
    log.logger.info('ARGS:')
    log.logger.info('platform={}'.format(platform))
    log.logger.info('compiler bin = {}'.format(compiler_bin))
    log.logger.info('src_model={}'.format(src_model))
    log.logger.info('Generate to [{}]'.format(log_path))