/** * \brief ONNXVectorIO declarations and definitions * \author Nan Zhou, nanzhou at kneron dot us * \copyright 2019 Kneron Inc. All right reserved. */ #ifndef PIANO_DYNASTY_INCLUDE_IO_ONNXVECTORIO_H_ #define PIANO_DYNASTY_INCLUDE_IO_ONNXVECTORIO_H_ #include #include #include #include #include #include #include "InferenceLogger.h" #include "fmt/format.h" #include "fmt/ostream.h" #include "npy.hpp" namespace dynasty { namespace io { /** * \class ONNXVectorIO * \brief a class to save vectors to files, or load vectors from files, the order follows tensors in ONNX, which is [b, * c, h, w] */ template class ONNXVectorIO { public: /** * \brief save the input vector to a file, the order is [b, c, h, w] */ static void Save1dVectorToFile(std::vector const &input_vector, std::string const &file_location) { auto save_txt = [](std::vector const &input_vector, const std::string &file_location) { std::ofstream f_out(file_location); auto out = fmt::memory_buffer(); for (auto const &i : input_vector) { if (std::is_floating_point_v == true) fmt::format_to(std::back_inserter(out), "{:.8f}\n", i); else if (std::is_integral_v == true) fmt::format_to(std::back_inserter(out), "{}\n", i); else throw std::runtime_error("unsupported data type" + std::string(typeid(T).name())); } std::string s(out.begin(), out.end()); fmt::print(f_out, "{}", s); f_out.flush(); f_out.close(); }; auto save_npy = [](std::vector const &input_vector, const std::string &file_location) { const npy::npy_data_ptr data_ptr{input_vector.data(), {1, input_vector.size()}, false}; write_npy(file_location, data_ptr); }; std::filesystem::path path = file_location; std::string ext = path.extension(); if (ext == ".txt") { save_txt(input_vector, file_location); } else if (ext == ".npy") { save_npy(input_vector, file_location); } else { throw std::runtime_error("Not supported extention(only txt/npy supported): " + ext); } } /** * \brief the file is supposed to be in order [b, h, w, c], * we convert it to [b, c, h, w] */ static std::vector Load1dVectorFromFile(std::string const &file_location, std::vector const &shape) { std::vector raw_vector = Load1dVectorFromFile(file_location); size_t shape_size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); if (raw_vector.size() != shape_size) { throw std::runtime_error("size in file: " + std::to_string(raw_vector.size()) + "not match with size calced in shape: " + std::to_string(shape_size)); } if (shape.size() != 4) return raw_vector; std::vector result_vector(raw_vector.size()); uint32_t b, c, h, w; b = shape.at(0); c = shape.at(1); h = shape.at(2); w = shape.at(3); for (size_t index = 0; index < raw_vector.size(); index++) { size_t in = index; uint32_t k = in / (c * h * w); // batch_id in %= (c * h * w); uint32_t i = in / c; uint32_t j = in % c; uint32_t addr = k * (c * h * w) + j * (w * h) + i; result_vector.at(addr) = raw_vector[index]; } return result_vector; } template static std::vector transform_type(const std::string &file_location) { std::vector d = npy::read_npy(file_location).data; std::vector ret; std::transform(d.begin(), d.end(), std::back_inserter(ret), [](auto d_) { return T(d_); }); return ret; } static std::vector load_npy(const std::string &file_location) { std::string header_s; { std::ifstream stream(file_location, std::ifstream::binary); if (!stream) { throw std::runtime_error("io error: failed to open a file."); } header_s = npy::read_header(stream); } std::vector ret; // parse header npy::header_t header = npy::parse_header(header_s); if (header.dtype.tie() == npy::dtype_map.at(std::type_index(typeid(float))).tie()) { return transform_type(file_location); } if (header.dtype.tie() == npy::dtype_map.at(std::type_index(typeid(double))).tie()) { return transform_type(file_location); } if (header.dtype.tie() == npy::dtype_map.at(std::type_index(typeid(char))).tie()) { return transform_type(file_location); } if (header.dtype.tie() == npy::dtype_map.at(std::type_index(typeid(short))).tie()) { return transform_type(file_location); } if (header.dtype.tie() == npy::dtype_map.at(std::type_index(typeid(int))).tie()) { return transform_type(file_location); } if (header.dtype.tie() == npy::dtype_map.at(std::type_index(typeid(long))).tie()) { return transform_type(file_location); } if (header.dtype.tie() == npy::dtype_map.at(std::type_index(typeid(unsigned char))).tie()) { return transform_type(file_location); } if (header.dtype.tie() == npy::dtype_map.at(std::type_index(typeid(unsigned short))).tie()) { return transform_type(file_location); } if (header.dtype.tie() == npy::dtype_map.at(std::type_index(typeid(unsigned int))).tie()) { return transform_type(file_location); } if (header.dtype.tie() == npy::dtype_map.at(std::type_index(typeid(unsigned long))).tie()) { return transform_type(file_location); } /* if(header.dtype.tie() == npy::dtype_map.at(std::type_index(typeid(bool))).tie()) { return transform_type(stream); } */ throw std::runtime_error("not supported type in load_npy"); } /** * \brief the file is supposed to be in order [b, c, h, w], * this function will just load these numbers line by line */ static std::vector Load1dVectorFromFile(std::string const &file_location) { auto load_txt = [](const std::string &file_location) { std::vector result_vector; std::string line; std::ifstream input_file(file_location); if (!input_file) { SPDLOG_LOGGER_ERROR(dynasty::log::InferenceLogger::GetLogger(), "Can not open " + file_location); exit(1); } while (getline(input_file, line)) { result_vector.push_back(DataConverter(line)); } input_file.close(); return result_vector; }; std::filesystem::path path = file_location; std::string ext = path.extension(); std::vector res; if (ext == ".txt" || ext == ".golden") { res = load_txt(file_location); } else if (ext == ".npy") { res = load_npy(file_location); } else { throw std::runtime_error("Not supported extention(only txt/npy supported): " + ext); } return res; } protected: static T DataConverter(std::string const &num); ~ONNXVectorIO() = default; }; } // namespace io } // namespace dynasty #endif // PIANO_DYNASTY_INCLUDE_IO_ONNXVECTORIO_H_