# 23 lines, 786 B, Python
import os

# Environment variable consulted to pick the NPU rounding mode, and the
# values it is compared against.
ENVVAR_KEY_NPU_ROUNDING_MODE = "NPU_ROUNDING_MODE"
ENVVAR_VAL_NPU_ROUNDING_MODE_INF = "NPU_ROUNDING_MODE_INF"
ENVVAR_VAL_NPU_ROUNDING_MODE_ROUND2EVEN = "NPU_ROUNDING_MODE_ROUND2EVEN"

# Bit masks that are OR-ed together to form the run value.
# NOTE(review): the per-bit meanings below are inferred from the names only;
# confirm against the hardware register description.
RUN_VAL_GO = 0x1                        # bit 0: presumably the "go"/start bit
RUN_VAL_ROUND2INF = 0x0                 # no bits set: round-to-infinity mode
RUN_VAL_ROUND2EVEN_CONV = 0x20000000    # bit 29: round-to-even for CONV
RUN_VAL_ROUND2EVEN_PCONV = 0x10000000   # bit 28: round-to-even for PCONV
RUN_VAL_ROUND2EVEN_PFUNC = 0x8000000    # bit 27: round-to-even for PFUNC

# Round-to-even for all three units at once (== 0x38000000).
RUN_VAL_ROUND2EVEN = (
    RUN_VAL_ROUND2EVEN_CONV | RUN_VAL_ROUND2EVEN_PCONV | RUN_VAL_ROUND2EVEN_PFUNC
)


def get_round_val():
    """Return the rounding-mode bits selected by the environment, as a hex string.

    Reads the environment variable named by ``ENVVAR_KEY_NPU_ROUNDING_MODE``,
    defaulting to ``ENVVAR_VAL_NPU_ROUNDING_MODE_ROUND2EVEN`` when it is unset.
    Only an exact match with ``ENVVAR_VAL_NPU_ROUNDING_MODE_INF`` selects
    ``RUN_VAL_ROUND2INF``. Any other value, including the default, selects
    ``RUN_VAL_ROUND2EVEN``.

    Returns:
        str: ``hex()`` of the selected bits (``'0x0'`` for round-to-infinity,
        ``'0x38000000'`` for round-to-even).
    """
    rounding_mode = os.getenv(
        ENVVAR_KEY_NPU_ROUNDING_MODE, ENVVAR_VAL_NPU_ROUNDING_MODE_ROUND2EVEN
    )
    # getenv() is given a default, so rounding_mode is always a str here;
    # no separate None check is needed.
    if rounding_mode == ENVVAR_VAL_NPU_ROUNDING_MODE_INF:
        return hex(RUN_VAL_ROUND2INF)
    return hex(RUN_VAL_ROUND2EVEN)


def get_run_val():
    """Return the complete run value as a hex string.

    The result is the GO bit (``RUN_VAL_GO``) combined with the rounding-mode
    bits reported by ``get_round_val()``.
    """
    # get_round_val() yields a hex string, so decode it back to an int first.
    rounding_bits = int(get_round_val(), 16)
    combined = rounding_bits | RUN_VAL_GO
    return hex(combined)