2026-01-28 06:16:04 +00:00

73 lines
1.9 KiB
Python

#!/usr/bin/env python
# coding: utf-8
# In[2]:
import os
import numpy as np
import argparse
# In[3]:
def softmax(x, axis=None):
    """Compute softmax values for the scores in ``x``.

    The maximum is subtracted before exponentiating so that large scores do
    not overflow ``np.exp``; the shift cancels out and does not change the
    result mathematically.

    Args:
        x: Array-like of scores.
        axis: Axis along which to normalize. ``None`` (the default)
            normalizes over all elements, matching the historical behavior;
            pass e.g. ``axis=-1`` to normalize each row of a 2-D array
            independently.

    Returns:
        ndarray with the same shape as ``x`` whose entries sum to 1 along
        ``axis`` (or over the whole array when ``axis`` is None).

    Raises:
        ValueError: If ``x`` is empty (``np.max`` has no identity).
    """
    x = np.asarray(x)
    # keepdims=True keeps the reduced axes as size-1 dimensions so the max
    # and the sum broadcast back against ``x`` for any ``axis`` value.
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / e_x.sum(axis=axis, keepdims=True)
def _read_scores(path):
    """Read one float per non-blank line of the text file at ``path``.

    Args:
        path: Path of an emulator output file.

    Returns:
        list[float]: The parsed values, in file order.

    Raises:
        ValueError: If a non-blank line is not a valid float.
    """
    with open(path, "r", encoding="utf-8") as fh:
        # Blank lines (e.g. a trailing empty line) are skipped rather than
        # passed to float(''), which would raise. float() itself tolerates
        # surrounding whitespace such as the newline.
        return [float(line) for line in fh if line.strip()]


def _write_scores(path, values):
    """Write ``values`` to ``path`` as one ``str(value)`` per line, overwriting it."""
    with open(path, "w", encoding="utf-8") as fh:
        fh.writelines(str(v) + "\n" for v in values)


def main_(args):
    """Post-process emulator outputs for the model selected in ``args``.

    For ``pn_color``, each immediate sub-directory of ``args.input_folder``
    is scanned for files whose names start with ``"temp"``. Each such file
    must contain exactly 10 float scores, one per line. Their softmax is
    written to ``softmax_result.txt`` in the same sub-directory. Non-directory
    entries of the input folder are ignored.

    Args:
        args: Namespace with ``input_folder`` (path string) and ``model``
            (model name string).

    Raises:
        ValueError: If ``model`` is unsupported, if a temp file does not
            hold exactly 10 values, or if a line is not a valid float.
        OSError: If a folder or file cannot be listed, read or written.
    """
    expected_outputs = 10
    emulator_folder_path = args.input_folder
    model = args.model
    if model != "pn_color":
        # Previously an unknown model finished silently with no work done.
        raise ValueError(f"unsupported model: {model!r}")

    # Sorted so that processing (and which file fails first) is deterministic.
    for folder in sorted(os.listdir(emulator_folder_path)):
        folder_path = os.path.join(emulator_folder_path, folder)
        if not os.path.isdir(folder_path):
            continue  # stray files next to the run folders are not runs
        for name in sorted(os.listdir(folder_path)):
            file_path = os.path.join(folder_path, name)
            if not name.startswith("temp") or not os.path.isfile(file_path):
                continue
            scores = _read_scores(file_path)
            if len(scores) != expected_outputs:
                raise ValueError("illegal output nums")
            probs = softmax(np.asarray(scores))
            # NOTE(review): every "temp*" file in a folder writes the same
            # softmax_result.txt, so only the last one processed survives.
            # Confirm that a single temp file per folder is expected.
            _write_scores(os.path.join(folder_path, "softmax_result.txt"), probs)
if __name__ == "__main__":
    # Command-line entry point: parse the input folder and model name, then
    # run the post-processing in main_.
    argparser = argparse.ArgumentParser(
        description="post process for different models"
    )
    argparser.add_argument(
        '-i',
        '--input_folder',
        help="input folder, i.e. emulator result folder, default as /data1/emulator",
        default="/data1/emulator"
    )
    argparser.add_argument(
        '-m',
        '--model',
        # main_ only handles "pn_color"; restricting choices makes argparse
        # reject any other value with a usage error at parse time.
        choices=["pn_color"],
        help="model type, default as pn_color",
        default="pn_color"
    )
    args = argparser.parse_args()
    main_(args)