73 lines
1.9 KiB
Python
73 lines
1.9 KiB
Python
#!/usr/bin/env python
|
|
# coding: utf-8
|
|
|
|
# In[2]:
|
|
|
|
|
|
import os
|
|
import numpy as np
|
|
import argparse
|
|
|
|
|
|
# In[3]:
|
|
|
|
def softmax(x, axis=None):
    """Compute the softmax of the scores in ``x``.

    The maximum score is subtracted before exponentiating so that large
    inputs do not overflow ``np.exp``; the shift cancels out in the ratio
    and therefore does not change the result.

    Args:
        x: Array-like of scores.
        axis: Axis along which each set of scores is normalised. ``None``
            (the default) normalises over all elements of ``x`` together,
            which is what this function did before ``axis`` existed.

    Returns:
        Array with the same shape as ``x``; the entries along ``axis``
        (or all entries when ``axis`` is ``None``) sum to 1.

    Raises:
        ValueError: If ``x`` is empty (raised by ``np.max``).
    """
    x = np.asarray(x)
    # keepdims=True keeps the reduced axis so the results broadcast against x.
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / e_x.sum(axis=axis, keepdims=True)
|
|
|
|
def main_(args):
    """Post-process emulator outputs by writing their softmax to disk.

    For the ``pn_color`` model, every sub-folder of ``args.input_folder`` is
    scanned for files whose name starts with ``temp``. Each such file is read
    as one float per line (blank lines are skipped), must hold exactly 10
    values, and the softmax of those values is written, one per line, to
    ``softmax_result.txt`` in the same sub-folder. If a sub-folder holds
    several ``temp*`` files, each one overwrites the same result file.

    Any other ``args.model`` value is ignored and nothing is written.

    Args:
        args: Object exposing ``input_folder`` (path of the emulator result
            folder) and ``model`` (model type string) attributes.

    Raises:
        ValueError: If a ``temp*`` file does not contain exactly 10 values
            ("illegal output nums"), or a non-blank line is not a float.
    """
    expected_outputs = 10  # number of scores each "temp*" file must contain

    emulator_folder_path = args.input_folder
    model = args.model

    if model != "pn_color":
        # NOTE(review): unknown model types are silently ignored, exactly as
        # before; confirm whether this should become an error.
        return

    for folder in os.listdir(emulator_folder_path):
        folder_path = os.path.join(emulator_folder_path, folder)
        if not os.path.isdir(folder_path):
            # A stray regular file next to the result folders would make
            # os.listdir() below raise; skip it instead.
            continue

        for file_name in os.listdir(folder_path):
            if not file_name.startswith("temp"):
                continue

            file_path = os.path.join(folder_path, file_name)
            # "with" guarantees the handle is closed even if float() raises.
            with open(file_path, "r") as src:
                # float() ignores surrounding whitespace, so no explicit
                # newline stripping is needed; blank lines are skipped.
                res = [float(line) for line in src if line.strip()]

            if len(res) != expected_outputs:
                raise ValueError("illegal output nums")

            res = softmax(np.asarray(res))

            res_file_path = os.path.join(folder_path, "softmax_result.txt")
            with open(res_file_path, "w") as res_file:
                res_file.writelines(str(value) + "\n" for value in res)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the command-line interface, parse the
    # arguments and hand them to main_().
    cli = argparse.ArgumentParser(description="post process for different models")

    # Folder that contains one sub-folder per emulator run.
    cli.add_argument(
        "-i",
        "--input_folder",
        default="/data1/emulator",
        help="input folder, i.e. emulator result folder, default as /data1/emulator",
    )

    # Selects which post-processing branch main_() takes.
    cli.add_argument(
        "-m",
        "--model",
        default="pn_color",
        help="model type, default as pn_color",
    )

    main_(cli.parse_args())
|
|
|
|
|
|
|