2026-01-28 06:16:04 +00:00

210 lines
7.7 KiB
C++

/**
* \brief ONNXVectorIO declarations and definitions
* \author Nan Zhou, nanzhou at kneron dot us
* \copyright 2019 Kneron Inc. All right reserved.
*/
#ifndef PIANO_DYNASTY_INCLUDE_IO_ONNXVECTORIO_H_
#define PIANO_DYNASTY_INCLUDE_IO_ONNXVECTORIO_H_
#include <unistd.h>
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <filesystem>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <numeric>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <typeindex>
#include <typeinfo>
#include <vector>
#include "InferenceLogger.h"
#include "fmt/format.h"
#include "fmt/ostream.h"
#include "npy.hpp"
namespace dynasty {
namespace io {
/**
* \class ONNXVectorIO
* \brief a class to save vectors to files, or load vectors from files, the order follows tensors in ONNX, which is [b,
* c, h, w]
*/
template <typename T>
class ONNXVectorIO {
public:
/**
* \brief save the input vector to a file, the order is [b, c, h, w]
*/
static void Save1dVectorToFile(std::vector<T> const &input_vector, std::string const &file_location) {
auto save_txt = [](std::vector<T> const &input_vector, const std::string &file_location) {
std::ofstream f_out(file_location);
auto out = fmt::memory_buffer();
for (auto const &i : input_vector) {
if (std::is_floating_point_v<T> == true)
fmt::format_to(std::back_inserter(out), "{:.8f}\n", i);
else if (std::is_integral_v<T> == true)
fmt::format_to(std::back_inserter(out), "{}\n", i);
else
throw std::runtime_error("unsupported data type" + std::string(typeid(T).name()));
}
std::string s(out.begin(), out.end());
fmt::print(f_out, "{}", s);
f_out.flush();
f_out.close();
};
auto save_npy = [](std::vector<T> const &input_vector, const std::string &file_location) {
const npy::npy_data_ptr<T> data_ptr{input_vector.data(), {1, input_vector.size()}, false};
write_npy(file_location, data_ptr);
};
std::filesystem::path path = file_location;
std::string ext = path.extension();
if (ext == ".txt") {
save_txt(input_vector, file_location);
} else if (ext == ".npy") {
save_npy(input_vector, file_location);
} else {
throw std::runtime_error("Not supported extention(only txt/npy supported): " + ext);
}
}
/**
* \brief the file is supposed to be in order [b, h, w, c],
* we convert it to [b, c, h, w]
*/
static std::vector<T> Load1dVectorFromFile(std::string const &file_location, std::vector<int32_t> const &shape) {
std::vector<T> raw_vector = Load1dVectorFromFile(file_location);
size_t shape_size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int32_t>());
if (raw_vector.size() != shape_size) {
throw std::runtime_error("size in file: " + std::to_string(raw_vector.size()) +
"not match with size calced in shape: " + std::to_string(shape_size));
}
if (shape.size() != 4) return raw_vector;
std::vector<T> result_vector(raw_vector.size());
uint32_t b, c, h, w;
b = shape.at(0);
c = shape.at(1);
h = shape.at(2);
w = shape.at(3);
for (size_t index = 0; index < raw_vector.size(); index++) {
size_t in = index;
uint32_t k = in / (c * h * w); // batch_id
in %= (c * h * w);
uint32_t i = in / c;
uint32_t j = in % c;
uint32_t addr = k * (c * h * w) + j * (w * h) + i;
result_vector.at(addr) = raw_vector[index];
}
return result_vector;
}
template <typename ScalarFm>
static std::vector<T> transform_type(const std::string &file_location) {
std::vector<ScalarFm> d = npy::read_npy<ScalarFm>(file_location).data;
std::vector<T> ret;
std::transform(d.begin(), d.end(), std::back_inserter(ret), [](auto d_) { return T(d_); });
return ret;
}
static std::vector<T> load_npy(const std::string &file_location) {
std::string header_s;
{
std::ifstream stream(file_location, std::ifstream::binary);
if (!stream) {
throw std::runtime_error("io error: failed to open a file.");
}
header_s = npy::read_header(stream);
}
std::vector<T> ret;
// parse header
npy::header_t header = npy::parse_header(header_s);
if (header.dtype.tie() == npy::dtype_map.at(std::type_index(typeid(float))).tie()) {
return transform_type<float>(file_location);
}
if (header.dtype.tie() == npy::dtype_map.at(std::type_index(typeid(double))).tie()) {
return transform_type<double>(file_location);
}
if (header.dtype.tie() == npy::dtype_map.at(std::type_index(typeid(char))).tie()) {
return transform_type<char>(file_location);
}
if (header.dtype.tie() == npy::dtype_map.at(std::type_index(typeid(short))).tie()) {
return transform_type<short>(file_location);
}
if (header.dtype.tie() == npy::dtype_map.at(std::type_index(typeid(int))).tie()) {
return transform_type<int>(file_location);
}
if (header.dtype.tie() == npy::dtype_map.at(std::type_index(typeid(long))).tie()) {
return transform_type<long>(file_location);
}
if (header.dtype.tie() == npy::dtype_map.at(std::type_index(typeid(unsigned char))).tie()) {
return transform_type<unsigned char>(file_location);
}
if (header.dtype.tie() == npy::dtype_map.at(std::type_index(typeid(unsigned short))).tie()) {
return transform_type<unsigned short>(file_location);
}
if (header.dtype.tie() == npy::dtype_map.at(std::type_index(typeid(unsigned int))).tie()) {
return transform_type<unsigned int>(file_location);
}
if (header.dtype.tie() == npy::dtype_map.at(std::type_index(typeid(unsigned long))).tie()) {
return transform_type<unsigned long>(file_location);
}
/*
if(header.dtype.tie() == npy::dtype_map.at(std::type_index(typeid(bool))).tie()) {
return transform_type<bool>(stream);
}
*/
throw std::runtime_error("not supported type in load_npy");
}
/**
* \brief the file is supposed to be in order [b, c, h, w],
* this function will just load these numbers line by line
*/
static std::vector<T> Load1dVectorFromFile(std::string const &file_location) {
auto load_txt = [](const std::string &file_location) {
std::vector<T> result_vector;
std::string line;
std::ifstream input_file(file_location);
if (!input_file) {
SPDLOG_LOGGER_ERROR(dynasty::log::InferenceLogger::GetLogger(), "Can not open " + file_location);
exit(1);
}
while (getline(input_file, line)) {
result_vector.push_back(DataConverter(line));
}
input_file.close();
return result_vector;
};
std::filesystem::path path = file_location;
std::string ext = path.extension();
std::vector<T> res;
if (ext == ".txt" || ext == ".golden") {
res = load_txt(file_location);
} else if (ext == ".npy") {
res = load_npy(file_location);
} else {
throw std::runtime_error("Not supported extention(only txt/npy supported): " + ext);
}
return res;
}
protected:
static T DataConverter(std::string const &num);
~ONNXVectorIO() = default;
};
} // namespace io
} // namespace dynasty
#endif // PIANO_DYNASTY_INCLUDE_IO_ONNXVECTORIO_H_