70 lines
2.7 KiB
Python
70 lines
2.7 KiB
Python
from base.preprocess import preprocess_hw
|
|
from base.postprocess import postprocess_hw
|
|
from .lite_hrnet_postprocess import postprocess_
|
|
from base.inference import inference_
|
|
class LiteHrnetRunner:
    """Runner that feeds preprocessed inputs through a Lite-HRNet model.

    The model is loaded once in ``__init__``; ``run`` then calls
    ``inference_`` and ``postprocess_hw`` (with ``postprocess_`` as the
    post-processing callback) once per preprocessed input.
    """

    def __init__(self, model_path, **kwargs):
        """
        Load the model and remember the construction-time configuration.

        :param model_path:
            string, path of the pytorch model file.
        :param lib_path:
            a folder that includes all dependency scripts of building block
            classes. (optional, forwarded to ``load_model`` through ``kwargs``)
        :param model_def_path:
            a script that instantiates `target_model` based on the imported
            building blocks. If `lib_path` is not defined, one should define
            all building block classes in this script before instantiating
            `target_model`. (forwarded to ``load_model`` through ``kwargs``)
        """
        # Imported lazily so that importing this module does not require
        # kneron_utils until a runner is actually constructed.
        from kneron_utils.model import load_model

        self.model, self._inference_type = load_model(model_path, **kwargs)

        # Build the config explicitly instead of storing `locals()`: that call
        # also captured `self` (creating a reference cycle) and the imported
        # `load_model` function, and both were then forwarded as keyword
        # arguments to every inference/post-processing call.
        # The 'model_path' and 'kwargs' keys are kept for backward compatibility.
        self.init_config = {
            'model_path': model_path,
            'kwargs': kwargs,
        }
        self.init_config.update(kwargs)
        self.init_config['_type'] = self._inference_type

    def _exec_py_script(self, model_def_path):
        """
        Execute a model-definition script so the runner can access the
        `target_model` it defines.

        :param model_def_path: path of the python script to read.
        """
        with open(model_def_path, 'r') as f:
            script_content = f.read()

        if 'target_model' in script_content:
            # NOTE(review): `_target_model_instantiate` is not defined in this
            # class as visible here -- confirm it is provided elsewhere,
            # otherwise this branch raises AttributeError.
            self._target_model_instantiate(model_def_path)
        else:
            # declare everything from the script into global variable, dangerous!
            # this part will be deprecated in near future
            # SECURITY: this runs arbitrary code from `model_def_path`; only
            # use it with trusted scripts.
            exec(script_content, globals())

    def run(self, img, pre_results, **kwargs):
        """
        Lite hrnet runner

        :param img: unused img path parameter
        :param pre_results: pair ``(inputs_list, pre_infos)``; should contain
            inputs (dim (1, 256, 192, 3)) and the preprocessing information
            used for final coords scaling. Each ``pre_info`` is merged into
            the post-processing keyword arguments.
        :param kwargs: extra options; they override ``init_config`` and
            ``pre_info`` entries in the post-processing keyword arguments.
        :return: list with one post-processed result per input (per the
            original note, 17 x 2 kpts as flattened list each); ``[]`` when
            ``pre_results`` is missing or has no inputs / pre-infos.
        """
        if (pre_results is None
                or len(pre_results) < 2
                or len(pre_results[0]) == 0
                or len(pre_results[1]) == 0):
            return []

        # Index instead of tuple-unpacking: the guard above accepts more than
        # two entries, which a strict two-name unpack would reject.
        inputs_list, pre_infos = pre_results[0], pre_results[1]

        # Loop-invariant, so built once. Entries of init_config take
        # precedence over the two defaults, as before.
        infer_config = {
            'model': self.model,
            '_type': self._inference_type,
        }
        infer_config.update(self.init_config)

        results = []
        for inputs, pre_info in zip(inputs_list, pre_infos):
            outputs = inference_(inputs, **infer_config)

            # Precedence (lowest to highest): init_config, pre_info, kwargs.
            post_config = dict(self.init_config)
            post_config.update(pre_info)
            post_config.update(kwargs)

            preds = postprocess_hw(outputs, postprocess_, **post_config)
            results.append(preds)
        return results
|