/** * \brief ONNXVectorIO declarations and definitions * \author Nan Zhou, nanzhou at kneron dot us * \copyright 2019 Kneron Inc. All right reserved. */ #ifndef PIANO_DYNASTY_INCLUDE_IO_ONNXVECTORIO_H_ #define PIANO_DYNASTY_INCLUDE_IO_ONNXVECTORIO_H_ #include #include #include #include #include "InferenceLogger.h" namespace dynasty { namespace io { /** * \class ONNXVectorIO * \brief a class to save vectors to files, or load vectors from files, the order follows tensors in ONNX, which is [b, c, h, w] */ template class ONNXVectorIO { public: /** * \brief save the input vector to a file, the order is [b, c, h, w] */ static void Save1dVectorToFile(std::vector const &input_vector, std::string const &file_location) { size_t pos = file_location.find_last_of("/\\"); std::string dirname = file_location.substr(0, pos); std::string filename = file_location.substr(pos + 1); int result = chdir(dirname.c_str()); if (result == -1) { std::cerr << "Failed to change current working directory to dump directory" << dirname << std::endl; } std::ofstream f_out; // Set exceptions to be thrown on failure f_out.exceptions(std::ifstream::failbit | std::ifstream::badbit); try { f_out.open(filename); } catch (std::system_error &e) { SPDLOG_LOGGER_ERROR(dynasty::log::InferenceLogger::GetLogger(), "Failed to open file {}: {}", file_location, e.code().message().c_str()); exit(1); } for (auto const &i : input_vector) { f_out << std::setprecision(8) << i << std::endl; } } /** * \brief the file is supposed to be in order [b, h, w, c], * we convert it to [b, c, h, w] */ static std::vector Load1dVectorFromFile(std::string const &file_location, std::vector const &shape) { uint32_t c = 0, h = 0, w = 0, b = 0; b = shape.size() > 0 ? shape.at(0) : 1; c = shape.size() > 1 ? shape.at(1) : 1; h = shape.size() > 2 ? shape.at(2) : 1; w = shape.size() > 3 ? shape.at(3) : 1; if (shape.size() != 4) { // not bhwc format c = 1; h = 1; w = 1; b = 1; for (uint32_t i = 0; i < shape.size(); i++) { w *= shape[i]; } } std::vector result_vector(b * c * h * w); std::string line; std::ifstream input_file(file_location); if (!input_file) { SPDLOG_LOGGER_ERROR(dynasty::log::InferenceLogger::GetLogger(), "Can not open " + file_location); exit(1); } uint32_t index = 0; while (getline(input_file, line)) { uint32_t in = index; uint32_t k = in / (c*h*w); // batch_id in %= (c*h*w); uint32_t i = in / c; uint32_t j = in % c; uint32_t addr = k * (c * h * w) + j * (w * h) + i; if(addr >= b * c * w * h){ SPDLOG_LOGGER_ERROR(dynasty::log::InferenceLogger::GetLogger(), "Input file dimension is greater than expected: " + std::to_string(b*c*w*h) + ". It has reached: " + std::to_string(addr+1) + " lines. " + file_location); exit(1); } result_vector.at(addr) = DataConverter(line); index++; } if (index != b * c * w * h) { SPDLOG_LOGGER_ERROR(dynasty::log::InferenceLogger::GetLogger(), "Input file dimension is not correct: expected: " + std::to_string(b*c*w*h) + ", but got: " + std::to_string(index) + " lines. " + file_location); exit(1); } input_file.close(); return result_vector; } /** * \brief the file is supposed to be in order [b, c, h, w], * this function will just load these numbers line by line */ static std::vector Load1dVectorFromFile(std::string const &file_location) { std::vector result_vector; std::string line; std::ifstream input_file(file_location); if (!input_file) { SPDLOG_LOGGER_ERROR(dynasty::log::InferenceLogger::GetLogger(), "Can not open " + file_location); exit(1); } while (getline(input_file, line)) { result_vector.push_back(DataConverter(line)); } input_file.close(); return result_vector; } protected: static T DataConverter(std::string const &num); ~ONNXVectorIO() = default; }; } } #endif //PIANO_DYNASTY_INCLUDE_IO_ONNXVECTORIO_H_