#!/usr/bin/env python3
"""Run opt_compile.py on every model case found under a top-level directory.

For each case the script locates the scaled ONNX model and the compiler
config for the requested platform, launches ``opt_compile.py`` (located next
to this file) in a worker process, and finally logs a per-case summary.
"""

import sys
import os
import shutil
import subprocess
import multiprocessing
import argparse
import re
import logging
import time

from utils import log
from utils import util


def _find_first_match(case_dir, pattern):
    """Return the path of the first file under case_dir whose name matches.

    The directory tree is walked with os.walk and ``pattern`` is applied to
    the bare file name with re.match. Returns None when nothing matches.
    """
    regex = re.compile(pattern)
    for root, _dirs, files in os.walk(case_dir):
        for name in files:
            if regex.match(name):
                return os.path.join(root, name)
    return None


def find_model_path(case_dir, platform):
    """Return the path of the ``*kdp<platform>*scaled.onnx`` model, or None."""
    # Raw string: '\.' in a normal literal is an invalid escape sequence.
    return _find_first_match(
        case_dir, r'(.*kdp{}.*scaled\.onnx$)'.format(platform))


def find_compiler_config_path(case_dir, platform):
    """Return the path of the ``*config_<platform>.json`` file, or None."""
    return _find_first_match(
        case_dir, r'(.*config_{}\.json$)'.format(platform))


def arg_parse():
    """Parse the command line and return the argparse namespace."""
    # NOTE(review): the help text names COMPILER_BIN_DIR, while main() sets
    # COMPILER_BIN -- confirm which variable the compiler actually reads.
    desc = (
        'Requirement:\n'
        ' - Run "gen_config.py" to generate config.json before using this script.\n\n'
        'Environment variables:\n'
        ' - COMPILER_BIN_DIR: specify compiler bin dir\n\n'
        'Sub-modules:\n'
        ' - image_cut_search: 720|530|730|630|540\n'
    )
    parser = argparse.ArgumentParser(
        description=desc, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument("platform", help="HW platform <520|720|530|730|540>")
    parser.add_argument("src_model", help="top directory includes models")
    parser.add_argument("output_dir", help="log output directory.")
    parser.add_argument("-d", "--debug", action="store_true", default=False,
                        help="debug mode. Show more debug message")
    return parser.parse_args()


def run_opt_compile(script, case_name, platform, model, config, log_dir, debug):
    """Run the compile script for one case and return its outcome.

    This function executes in a pool worker, so it touches no module-level
    state: everything the parent needs comes back in the returned dict. That
    keeps it working under every multiprocessing start method.

    Args:
        script: path of the opt_compile script to execute.
        case_name: name of the case, used only for reporting.
        platform: hardware platform identifier passed to the script.
        model: path of the ONNX model.
        config: path of the compiler config.
        log_dir: output directory passed to the script.
        debug: extra flag string (e.g. '-d'); empty means no extra flag.

    Returns:
        dict with keys 'case_name', 'command', 'status', 'elapsed',
        'stdout' and 'stderr' (all strings).
    """
    # Build argv directly: splitting a formatted string would break any path
    # that contains whitespace.
    cmd_args = [str(arg) for arg in (script, platform, model, config, log_dir)]
    if debug:
        cmd_args.extend(str(debug).split())

    start_tm = time.monotonic()
    try:
        proc = subprocess.run(
            cmd_args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except OSError as err:
        # Script missing or not executable: report it instead of letting the
        # case silently vanish from the summary.
        status = 'FAILED [cannot launch compiler]'
        stdout_text = ''
        stderr_text = str(err)
    else:
        if proc.returncode != 0:
            status = 'FAILED [compilation failed]'
        else:
            status = 'SUCCESS'
        # Compiler output is not guaranteed to be valid UTF-8.
        stdout_text = proc.stdout.decode(errors='replace')
        stderr_text = proc.stderr.decode(errors='replace')
    minutes = (time.monotonic() - start_tm) / 60

    return {
        'case_name': case_name,
        'command': ' '.join(cmd_args),
        'status': status,
        'elapsed': "{:.2f} minutes".format(minutes),
        'stdout': stdout_text,
        'stderr': stderr_text,
    }


def _log_outcome(logger, outcome):
    """Write the per-case report for one run_opt_compile outcome."""
    case_name = outcome['case_name']
    logger.info('================={}=================='.format(case_name))
    logger.info('run model : {} ...'.format(case_name))
    logger.info('opt_wrapper : {}'.format(outcome['command']))
    if outcome['stdout']:
        logger.info('\n\n' + outcome['stdout'])
    if outcome['stderr']:
        logger.error('\n\n' + outcome['stderr'])


def main():
    """Compile every case under the source directory and log a summary."""
    # get opt wrapper directory
    opt_dir = os.path.dirname(os.path.abspath(__file__))

    # get arguments
    args = arg_parse()
    platform = args.platform
    src_model = args.src_model
    output_dir = args.output_dir
    debug = '-d' if args.debug else ''
    log_level = logging.DEBUG if args.debug else logging.INFO
    opt_compile = '{}/opt_compile.py'.format(opt_dir)

    if output_dir and not util.is_file_name_valid(output_dir):
        print('Please give a valid -o/OUTPUT_DIR. "{}" is invalid'.format(output_dir))
        sys.exit(-1)

    # set compiler binary path to environment (inherited by the workers)
    compiler_bin = util.get_compiler_bin(opt_dir)
    os.environ['COMPILER_BIN'] = compiler_bin

    # create logger; a distinct name avoids rebinding the imported log module
    log_path = output_dir
    run_log = log.Log(log_path, False)
    logger = run_log.logger
    logger.setLevel(log_level)

    # Only the parent process records results, so plain dicts are enough.
    result = {}
    result_tm = {}
    submitted = []

    # run opt_compile: iterate all test cases under top case directory
    cases = util.get_case_name(src_model)
    with multiprocessing.Pool() as pool:
        for case_name in cases:
            # step 1: check if test case exist
            case_path = os.path.join(src_model, case_name)
            if not util.check_file_exist(case_path):
                logger.error('case directory for [{}] not exist'.format(case_name))
                continue

            # step 2: check if model (onnx) exist
            model_path = find_model_path(case_path, platform)
            if model_path is None or not util.check_file_exist(model_path):
                result[case_name] = 'FAILED [no onnx model]'
                logger.error('model for [{}] not exist'.format(case_name))
                continue

            # step 3: check if compiler config exist
            config_path = find_compiler_config_path(case_path, platform)
            if config_path is None or not util.check_file_exist(config_path):
                result[case_name] = 'FAILED [no compiler config]'
                logger.error('compiler config for [{}] not exist'.format(case_name))
                continue

            # step 4: run opt_compiler
            log_dir = '{}/{}'.format(output_dir, case_name)
            pending = pool.apply_async(
                run_opt_compile,
                args=(opt_compile, case_name, platform, model_path,
                      config_path, log_dir, debug))
            submitted.append((case_name, pending))

        pool.close()

        # Retrieve every result so that a worker exception is reported
        # instead of being silently discarded.
        for case_name, pending in submitted:
            try:
                outcome = pending.get()
            except Exception as err:  # boundary between worker and parent
                result[case_name] = 'FAILED [worker error]'
                logger.error('worker for [{}] raised: {!r}'.format(case_name, err))
                continue
            _log_outcome(logger, outcome)
            result[case_name] = outcome['status']
            result_tm[case_name] = outcome['elapsed']

        pool.join()

    # show compile result
    logger.info('=== Summary ===')
    success_cnt = 0
    failure_cnt = 0
    for model, status in result.items():
        if 'SUCCESS' in status:
            success_cnt += 1
        else:
            failure_cnt += 1
        logger.info('%-60s: %-30s %-12s' % (model, status, result_tm.get(model, 'N/A')))
    logger.info('==================================================')
    logger.info('=> [{}] cases generated, [{}] cases failed'.format(success_cnt, failure_cnt))

    # show configs
    logger.info('ARGS:')
    logger.info('platform={}'.format(platform))
    logger.info('compiler bin = {}'.format(compiler_bin))
    logger.info('src_model={}'.format(src_model))
    logger.info('Generate to [{}]'.format(log_path))


if __name__ == "__main__":
    main()