70 lines
2.7 KiB
Python

from base.preprocess import preprocess_hw
from base.postprocess import postprocess_hw
from .lite_hrnet_postprocess import postprocess_
from base.inference import inference_
class LiteHrnetRunner:
    """Runner that loads a Lite-HRNet model and executes inference plus postprocessing.

    Inference is delegated to ``base.inference.inference_``; postprocessing is
    delegated to ``base.postprocess.postprocess_hw`` with the model-specific
    ``postprocess_`` hook.
    """

    def __init__(self, model_path, **kwargs):
        """Load the model and record the constructor configuration.

        :param model_path:
            string, path of the pytorch model file.
        :param lib_path:
            (optional, passed via ``kwargs``) a folder that includes all
            dependency scripts of building-block classes.
        :param model_def_path:
            (optional, passed via ``kwargs``) a script that instantiates
            ``target_model`` based on the imported building blocks. If
            ``lib_path`` is not defined, one should define all building-block
            classes in this script before instantiating ``target_model``.
        """
        from kneron_utils.model import load_model
        self.model, self._inference_type = load_model(model_path, **kwargs)
        # NOTE: locals() snapshots every local name bound at this point --
        # ``self``, ``model_path``, ``kwargs`` and the imported ``load_model`` --
        # and all of them are later forwarded as keyword arguments to
        # inference/postprocessing. Kept as-is for backward compatibility.
        self.init_config = locals()
        self.init_config.update(kwargs)
        self.init_config['_type'] = self._inference_type

    def _exec_py_script(self, model_def_path):
        """Load the model-definition script at ``model_def_path``.

        If the script text mentions ``target_model``, instantiation is delegated
        to ``self._target_model_instantiate``; otherwise the whole script is
        executed into this module's global namespace so the runner can access
        what it defines.

        :param model_def_path: path of the Python script to load.
        """
        with open(model_def_path, 'r') as f:
            script_content = f.read()
        if 'target_model' in script_content:
            # NOTE(review): ``_target_model_instantiate`` is not defined on this
            # class; unless a subclass provides it, this raises AttributeError.
            self._target_model_instantiate(model_def_path)
        else:
            # Declares everything from the script as module-level globals -- dangerous!
            # This executes arbitrary file contents; only use with trusted scripts.
            # This path will be deprecated in the near future.
            exec(script_content, globals())

    def run(self, img, pre_results, **kwargs):
        """Run Lite-HRNet inference and postprocessing on preprocessed inputs.

        :param img: unused image path parameter (kept for interface compatibility).
        :param pre_results: a pair ``(inputs_list, pre_infos)``. Each input is
            documented as having dim (1, 256, 192, 3); each ``pre_info`` holds
            preprocessing information (for final coordinate scaling) and is
            merged into the postprocess configuration via ``dict.update``.
        :param kwargs: extra options forwarded to postprocessing; they take
            precedence over the constructor config and ``pre_info``.
        :return: a list with one postprocess result per input (each documented
            upstream as 17 x 2 keypoints as a flattened list -- not verified
            here). Returns ``[]`` when ``pre_results`` is ``None``, has fewer
            than two entries, or either entry is empty.
        """
        if (pre_results is None or len(pre_results) < 2
                or len(pre_results[0]) == 0 or len(pre_results[1]) == 0):
            return []
        inputs_list, pre_infos = pre_results

        # The inference config is identical for every input, so build it once.
        # Constructor config entries override the defaults set here.
        infer_config = {
            'model': self.model,
            '_type': self._inference_type,
        }
        infer_config.update(self.init_config)

        results = []
        for inputs, pre_info in zip(inputs_list, pre_infos):
            outputs = inference_(inputs, **infer_config)
            # Later updates win: constructor config < pre_info < call-time kwargs.
            post_config = {}
            post_config.update(self.init_config)
            post_config.update(pre_info)
            post_config.update(kwargs)
            results.append(postprocess_hw(outputs, postprocess_, **post_config))
        return results