85 lines
3.3 KiB
Python
85 lines
3.3 KiB
Python
"""
Class to run internal runners through the simulator.
"""
|
|
import importlib
|
|
import json
|
|
|
|
import numpy as np
|
|
from PIL import Image
|
|
|
|
import python_flow.common.exceptions as exceptions
|
|
|
|
# Path of the JSON file that maps a model name to the dotted import string of
# its algorithm runner. The path is relative, so it is resolved against the
# current working directory when opened.
ALG_JSON = "alg/runners/kneron_globalconstant/base/alg_runner_map.json"
|
|
# NOTE(review): presumably the system-runner counterpart of ALG_JSON; it is not
# referenced by the code visible in this file -- confirm whether it is still used.
SYS_JSON = "sys/runners/sys_runner_map.json"
|
|
|
|
class Runner():
    """Map team runners to models and dispatch calls to them.

    Attributes:
        config: mapping from ``(model, num)`` keys to objects whose ``config``
            attribute holds the per-model configuration dictionary.
        runner_map: per-mode ("algorithm" / "system") mapping from a model
            name to the imported runner object.
        runner: per-mode mapping from ``(model, num)`` to the runner instance
            created for that key.
    """

    def __init__(self, config, **kwargs):
        """Create one runner instance for every key in config.

        Arguments:
            config: mapping keyed by ``(model, num)`` tuples; each value must
                expose a ``config`` dictionary attribute.
            **kwargs: accepted for call compatibility; not used here.
        """
        self.runner_map = {
            "algorithm": {},
            "system": {},
        }
        self.config = config
        self.runner = {
            "algorithm": {},
            "system": {},
        }

        for key in config.keys():
            self.init_runner(key)

    def init_runner(self, key):
        """Import and instantiate the runner for one ``(model, num)`` key.

        The runner is constructed with ``USER_CONFIG`` set to the model's
        configuration plus the keyword arguments found under
        ``config["pre"][<runner_mode>]``. Runner modes other than "algorithm"
        and "system" are silently ignored.

        Raises:
            KeyError: if no runner was registered for the model in the
                selected mode (for example because its import failed).
        """
        model, num = key
        model_config = self.config[key].config
        self.import_runner(model)

        mode = model_config["emu"]["runner_mode"]
        # NOTE(review): nothing in this class populates runner_map["system"],
        # so "system" mode raises KeyError here unless it is filled elsewhere.
        if mode in ("algorithm", "system"):
            self.runner[mode][(model, num)] = self.runner_map[mode][model](
                USER_CONFIG=model_config, **model_config["pre"][mode])

    def call_internal(self, key, function_name, *args, **kwargs):
        """Call a named method on the runner for key and return its result.

        Arguments:
            key: ``(model, num)`` key identifying the runner.
            function_name: name of the attribute to look up on the runner.
            *args, **kwargs: forwarded unchanged to that attribute.

        Returns:
            Whatever the called function returns.
        """
        mode = self.config[key].config["emu"]["runner_mode"]
        func = getattr(self.runner[mode][key], function_name)
        # Propagate the result so callers can use value-returning helpers.
        return func(*args, **kwargs)

    def run(self, key, *args, **kwargs):
        """Run the runner selected by the config's runner_mode for key.

        When ``config["flow"]["runner_dump"]`` is not None and contains the
        model name (``key[0]``), the results are additionally dumped as an
        image into ``config["flow"]["out_folder"]``.

        Returns:
            The value returned by the runner's ``run`` method.
        """
        model_config = self.config[key].config
        runner_dump = model_config["flow"]["runner_dump"]
        mode = model_config["emu"]["runner_mode"]

        results = self.runner[mode][key].run(*args, **kwargs)
        if runner_dump is not None and key[0] in runner_dump:
            dump_image_result(
                np.asarray(results), model_config["flow"]["out_folder"], key[0])
        return results

    def import_runner(self, model):
        """Register the algorithm runner for model if not already registered.

        The import string is looked up under the model name in the JSON file
        at ALG_JSON.

        Raises:
            KeyError: if the model has no entry in the JSON file.
        """
        if model in self.runner_map["algorithm"]:
            return

        with open(ALG_JSON, encoding="utf-8") as alg_json:
            data = json.load(alg_json)
        # The file is closed before importing; only the parsed data is needed.
        self.try_import("algorithm", model, data[model])

    def try_import(self, team, alias, import_string):
        """Import ``import_string`` and register it as ``runner_map[team][alias]``.

        Arguments:
            team: runner mode under which to register ("algorithm"/"system").
            alias: model name used as the registry key.
            import_string: dotted path of the form ``package.module.attr``.

        Import failures are reported on stdout and leave the registry
        untouched instead of raising.
        """
        try:
            mod, attr = import_string.rsplit(".", 1)
            module = importlib.import_module(mod)
            obj = getattr(module, attr)
        except (AttributeError, ModuleNotFoundError, ValueError) as error:
            # ValueError covers a string without a dot (unpacking fails) and
            # an empty module part, which import_module rejects.
            print(f"Failed to import {import_string}: {error}")
        else:
            self.runner_map[team][alias] = obj
|
|
|
|
def dump_image_result(results, out_folder, key):
    """Save a runner's result array as ``<key>.png`` inside out_folder.

    A 3-dimensional array is used as is; a 4-dimensional array is first
    reshaped to ``results.shape[1:]``. The array is then cast to ``uint8``
    and written as an RGB image.

    Arguments:
        results: array exposing ``shape``, ``reshape`` and ``astype``.
        out_folder: destination directory; must support ``/`` with a string
            (presumably a ``pathlib.Path`` -- confirm against callers).
        key: base name of the output file, without extension.

    Raises:
        exceptions.UnsupportedConfigError: if results is neither 3- nor
            4-dimensional.
    """
    rank = len(results.shape)
    if rank not in (3, 4):
        raise exceptions.UnsupportedConfigError(
            f"Runner result is not the right dimension. Image cannot be dumped.")

    if rank == 4:
        # reshape keeps the element count, so dropping the first axis this way
        # only succeeds when that axis has length 1.
        results = results.reshape(results.shape[1:])

    pixels = results.astype("uint8")
    image = Image.fromarray(pixels, "RGB")
    target = out_folder / (key + ".png")
    image.save(str(target))
|