"""Compute NPU RUN values (GO bit plus rounding-mode bits) from NPU_ROUNDING_MODE."""
import os

# Environment variable that selects the rounding mode, and the values it recognises.
ENVVAR_KEY_NPU_ROUNDING_MODE = "NPU_ROUNDING_MODE"
ENVVAR_VAL_NPU_ROUNDING_MODE_INF = "NPU_ROUNDING_MODE_INF"
ENVVAR_VAL_NPU_ROUNDING_MODE_ROUND2EVEN = "NPU_ROUNDING_MODE_ROUND2EVEN"

# Bit fields that make up the RUN value.
RUN_VAL_GO = 0x1
RUN_VAL_ROUND2INF = 0x0  # round-to-infinity: none of the round-to-even bits set
RUN_VAL_ROUND2EVEN_CONV = 0x20000000
RUN_VAL_ROUND2EVEN_PCONV = 0x10000000
RUN_VAL_ROUND2EVEN_PFUNC = 0x8000000
RUN_VAL_ROUND2EVEN = (
    RUN_VAL_ROUND2EVEN_CONV | RUN_VAL_ROUND2EVEN_PCONV | RUN_VAL_ROUND2EVEN_PFUNC
)


def _round_bits():
    """Return the rounding-mode bits selected by the environment, as an int.

    Only the exact value ``NPU_ROUNDING_MODE_INF`` selects round-to-infinity
    (``RUN_VAL_ROUND2INF``). An unset variable or any other value selects
    round-to-even (``RUN_VAL_ROUND2EVEN``).
    """
    # getenv is given a non-None default, so rounding_mode is always a str.
    rounding_mode = os.getenv(
        ENVVAR_KEY_NPU_ROUNDING_MODE, ENVVAR_VAL_NPU_ROUNDING_MODE_ROUND2EVEN
    )
    if rounding_mode == ENVVAR_VAL_NPU_ROUNDING_MODE_INF:
        return RUN_VAL_ROUND2INF
    return RUN_VAL_ROUND2EVEN


def get_round_val():
    """Return the rounding-mode bits as a hex string (e.g. ``'0x38000000'``).

    See ``_round_bits`` for how ``NPU_ROUNDING_MODE`` is interpreted.
    """
    return hex(_round_bits())


def get_run_val():
    """Return the RUN value as a hex string: the GO bit OR'd with the rounding bits."""
    # Combine as ints directly rather than round-tripping through get_round_val()'s string.
    return hex(RUN_VAL_GO | _round_bits())