from base.preprocess import preprocess_hw
from base.postprocess import postprocess_hw
from .lite_hrnet_postprocess import postprocess_
from base.inference import inference_


class LiteHrnetRunner:
    """Runner that loads a Lite-HRNet model and runs inference + postprocessing."""

    def __init__(self, model_path, **kwargs):
        """Load the model and remember the construction-time configuration.

        :param model_path: string, path of the model file.
        :param kwargs: extra keyword arguments forwarded unchanged to
            ``load_model`` and merged into ``self.init_config``. The original
            documentation lists two optional entries:

            * ``lib_path``: a folder that includes all dependency scripts of
              the building block classes.
            * ``model_def_path``: a script that instantiates ``target_model``
              based on the imported building blocks. If ``lib_path`` is not
              defined, one should define all building block classes in this
              script before instantiating ``target_model``.
        """
        from kneron_utils.model import load_model

        self.model, self._inference_type = load_model(model_path, **kwargs)

        # locals() at this point holds `self`, `model_path`, `kwargs` and
        # `load_model`; all of them end up in init_config and are later passed
        # as keyword arguments to the inference/postprocess helpers.
        # NOTE(review): kept as-is because downstream helpers may rely on these
        # keys -- confirm before narrowing this to an explicit dict.
        self.init_config = locals()
        self.init_config.update(kwargs)
        self.init_config['_type'] = self._inference_type

    def _exec_py_script(self, model_def_path):
        """Execute a model-definition script so the runner can access it.

        If the script text mentions ``target_model`` it is handed to
        ``self._target_model_instantiate``; otherwise the script is executed
        directly into this module's globals.

        :param model_def_path: path of the python script to read.
        """
        with open(model_def_path, 'r') as f:
            script_content = f.read()

        if 'target_model' in script_content:
            # NOTE(review): `_target_model_instantiate` is not defined in this
            # class as shown here -- confirm it is provided elsewhere.
            self._target_model_instantiate(model_def_path)
        else:
            # declare everything from the script into global variable, dangerous!
            # this part will be deprecated in near future
            # SECURITY: exec() runs arbitrary code; only use trusted scripts.
            exec(script_content, globals())

    def run(self, img, pre_results, **kwargs):
        """Run Lite-HRNet inference and postprocessing on preprocessed inputs.

        :param img: unused img path parameter.
        :param pre_results: pair ``(inputs_list, pre_infos)``; per the original
            documentation each input has dim (1, 256, 192, 3) and each
            ``pre_info`` carries the preprocessing information used for the
            final coords scaling.
        :param kwargs: extra options merged last into the postprocess config.
        :return: list with one ``postprocess_hw`` result per input (per the
            original documentation, 17 x 2 kpts as a flattened list); an empty
            list when ``pre_results`` has fewer than two entries or either of
            the first two entries is empty.
        """
        if len(pre_results) < 2 or len(pre_results[0]) == 0 or len(pre_results[1]) == 0:
            return []

        inputs_list, pre_infos = pre_results

        # The inference config does not depend on the individual input, so it
        # is assembled once; init_config entries override the two defaults.
        infer_config = {
            'model': self.model,
            '_type': self._inference_type,
            **self.init_config,
        }

        results = []
        for inputs, pre_info in zip(inputs_list, pre_infos):
            outputs = inference_(inputs, **infer_config)

            # Precedence (lowest to highest): init config, per-input
            # preprocessing info, call-time kwargs.
            post_config = dict(self.init_config)
            post_config.update(pre_info)
            post_config.update(kwargs)

            results.append(postprocess_hw(outputs, postprocess_, **post_config))
        return results