144 lines
4.4 KiB
Python
144 lines
4.4 KiB
Python
import numpy as np
|
|
import os
|
|
import shutil
|
|
from html.parser import HTMLParser
|
|
import onnx
|
|
import logging
|
|
import re
|
|
|
|
# Select the implementation of extract_bie_ioinfo for this environment.
# NOTE(review): onnx 1.7.0 gets a stub, presumably because the real
# extract_bie_info module cannot be used alongside that onnx release — confirm.
if onnx.__version__ == '1.7.0':

    def extract_bie_ioinfo(x):
        """Stub used with onnx 1.7.0: report no inputs and no outputs.

        Args:
            x: path of the BIE file (ignored).

        Returns:
            tuple[list, list]: always ``([], [])``.
        """
        return [], []

else:
    from extract_bie_info import extract_bie_ioinfo
|
|
|
|
# Output and scratch directories; each can be overridden via the environment
# variable named in the first argument, falling back to the literal default.
RESULT_FOLDER = os.getenv("KTC_OUTPUT_DIR", "/data1/kneron_flow")
TMP_FOLDER = os.getenv("KTC_WORKDIR", "/workspace/.tmp")

# Hard-coded locations (presumably inside the toolchain container — confirm).
LIBS_V1_FOLDER = "/workspace/libs"
LIBS_V2_FOLDER = "/workspace/libs_V2"
SCRIPT_FOLDER = "/workspace/scripts"
|
|
|
|
|
|
def convert_channel_last_to_first(image):
    """Reorder an image (or batch of images) from channel-last to channel-first.

    WARNING: This function is only valid for 2D image input.

    A 3-dimensional input first receives a leading batch axis of size 1.
    The last axis is then moved to position 1, so ``(H, W, C)`` becomes
    ``(1, C, H, W)`` and ``(N, H, W, C)`` becomes ``(N, C, H, W)``.
    Inputs of any other rank are not expanded before the axis move.

    Args:
        image (np.array or array-like): input image. Anything that is not
            already an ``np.ndarray`` is copied into a new array first.

    Returns:
        np.array: the (possibly batched) image with its last axis moved to
        position 1.
    """
    batch = image if isinstance(image, np.ndarray) else np.array(image)

    if batch.ndim == 3:
        # Only 3D input is unsqueezed: (H, W, C) -> (1, H, W, C).
        batch = batch[np.newaxis, ...]

    return np.moveaxis(batch, -1, 1)
|
|
|
|
|
|
def get_toolchain_version(version_file="/workspace/version.txt"):
    """Return the raw contents of the toolchain version file.

    Args:
        version_file (str): path of the file holding the version string.
            Defaults to ``/workspace/version.txt``.

    Returns:
        str: the file contents exactly as read, including any trailing
        newline. Returns ``"Unknown"`` if the file cannot be opened, read
        or decoded as UTF-8.
    """
    try:
        # ``with`` guarantees the handle is closed even if read() fails.
        with open(version_file, "r", encoding="utf-8") as f:
            return f.read()
    except (OSError, UnicodeDecodeError):
        # Best effort: a missing or unreadable version file must not abort
        # the caller. Unlike the former bare ``except:``, this no longer
        # swallows KeyboardInterrupt/SystemExit.
        return "Unknown"
|
|
|
|
|
|
def clean_up(output_dir, files_to_delete=None, dirs_to_delete=None):
    """Delete the named files and directories found under ``output_dir``.

    Every name is joined onto ``output_dir``. An entry is only removed when
    it exists with the expected kind: regular files for ``files_to_delete``
    and directories (removed recursively) for ``dirs_to_delete``. Anything
    else is silently skipped.

    Args:
        output_dir (str): directory the names are relative to.
        files_to_delete (iterable of str, optional): file names to remove.
        dirs_to_delete (iterable of str, optional): directory names to remove.
    """
    for name in files_to_delete or ():
        target = os.path.join(output_dir, name)
        if os.path.isfile(target):
            os.remove(target)

    for name in dirs_to_delete or ():
        target = os.path.join(output_dir, name)
        if os.path.isdir(target):
            shutil.rmtree(target)
|
|
|
|
|
|
def evalutor_result_html_parse(html_content):
    """Flatten the tables of an evaluator ``result.html`` page into text.

    The HTML is scanned with a three-state machine and the collected text is
    returned (nothing is printed):

    * Info table: entered on the first ``<tbody>`` seen while nothing has
      been collected yet. Each closing ``</th>`` appends ``": "``, each
      ``</tr>`` a newline, and the closing ``</tbody>`` appends a newline
      and returns to the idle state.
    * Details table: entered on a ``<thead>`` seen while idle once some text
      has been collected. Each closing ``</th>``/``</td>`` appends ``", "``
      and each ``</tr>`` a newline. This state is never left, so all later
      text in the document is collected as well.
    * Idle: text is ignored.

    Cell text is whitespace-stripped before it is appended.

    Args:
        html_content (str): HTML source of the result page.

    Returns:
        str: the flattened text, or ``""`` if no matching table was found.
    """
    IDLE, INFO_TABLE, DETAILS_TABLE = 0, 1, 2

    class EvaluatorHTMLParser(HTMLParser):
        def __init__(self):
            super().__init__()
            self.result = ""
            self.status = IDLE

        def handle_starttag(self, tag, attrs):
            if self.status != IDLE:
                return
            if tag == "tbody" and not self.result:
                self.status = INFO_TABLE
            elif tag == "thead" and self.result:
                self.status = DETAILS_TABLE

        def handle_endtag(self, tag):
            if self.status == IDLE:
                return
            if tag == "tr":
                self.result += "\n"
            elif self.status == INFO_TABLE:
                if tag == "tbody":
                    self.status = IDLE
                    self.result += "\n"
                elif tag == "th":
                    self.result += ": "
            elif self.status == DETAILS_TABLE and tag in ("th", "td"):
                self.result += ", "

        def handle_data(self, data):
            if self.status != IDLE:
                self.result += data.strip()

    parser = EvaluatorHTMLParser()
    parser.feed(html_content)
    parser.close()
    return parser.result
|
|
|
|
|
|
def get_input_names_from_onnx(onnx_model):
    """Return the names listed in an ONNX model's ``graph.input``.

    Args:
        onnx_model (onnx.ModelProto): loaded ONNX model.

    Returns:
        list[str]: the ``name`` of every ``graph.input`` entry, in order.
        Returns an empty list (and logs an error) if ``onnx_model`` is
        ``None`` or not an ``onnx.ModelProto``.
    """
    if onnx_model is None:
        logging.error("onnx_model is None.")
        return []

    if isinstance(onnx_model, onnx.ModelProto):
        return [graph_input.name for graph_input in onnx_model.graph.input]

    logging.error("onnx_model should be an instance of onnx.ModelProto.")
    return []
|
|
|
|
|
|
def get_input_names_from_bie(bie_path):
    """Return the input names recorded in a BIE file.

    Args:
        bie_path (str): path to the BIE file.

    Returns:
        list: the first element of the ``(inputs, outputs)`` pair returned by
        ``extract_bie_ioinfo``. Returns an empty list (and logs an error) if
        ``bie_path`` is ``None`` or is not an existing regular file.
    """
    if bie_path is None:
        logging.error("bie_path is None.")
        return []

    if os.path.isfile(bie_path):
        names, _outputs = extract_bie_ioinfo(bie_path)
        return names

    logging.error(f"bie file {bie_path} does not exist.")
    return []
|
|
|
|
|
|
def check_filename_validity(filename):
    """Check whether the stem of ``filename`` uses only allowed characters.

    Directory components are dropped with ``os.path.basename`` and the last
    extension with ``os.path.splitext``, so only the bare stem is validated.
    The stem must consist entirely of (Unicode) word characters, i.e. letters,
    digits and underscores, plus dots and hyphens.

    Args:
        filename (str): file name or path to check.

    Returns:
        bool: ``True`` if the stem is valid. ``False`` otherwise, with a log
        message explaining why.
    """
    if not isinstance(filename, str):
        logging.error("Filename should be a string.")
        return False
    if not filename:
        logging.error("Filename should not be empty.")
        return False

    filename = os.path.splitext(os.path.basename(filename))[0]

    # Use fullmatch rather than match with "$": "$" also matches just before
    # a trailing newline, which let stems such as "model\n" pass validation.
    if not re.fullmatch(r"[\w.-]+", filename):
        logging.warning(f"Invalid filename: {filename}. Only alphanumeric characters, underscores, hyphens, and dots are allowed.")
        return False

    return True
|