210 lines
7.7 KiB
C++
210 lines
7.7 KiB
C++
/**
|
|
* \brief ONNXVectorIO declarations and definitions
|
|
* \author Nan Zhou, nanzhou at kneron dot us
|
|
* \copyright 2019 Kneron Inc. All right reserved.
|
|
*/
|
|
|
|
#ifndef PIANO_DYNASTY_INCLUDE_IO_ONNXVECTORIO_H_
|
|
#define PIANO_DYNASTY_INCLUDE_IO_ONNXVECTORIO_H_
|
|
|
|
#include <unistd.h>
|
|
|
|
#include <filesystem>
|
|
#include <fstream>
|
|
#include <iomanip>
|
|
#include <iostream>
|
|
#include <numeric>
|
|
|
|
#include "InferenceLogger.h"
|
|
#include "fmt/format.h"
|
|
#include "fmt/ostream.h"
|
|
#include "npy.hpp"
|
|
|
|
namespace dynasty {
|
|
namespace io {
|
|
/**
|
|
* \class ONNXVectorIO
|
|
* \brief a class to save vectors to files, or load vectors from files, the order follows tensors in ONNX, which is [b,
|
|
* c, h, w]
|
|
*/
|
|
template <typename T>
|
|
class ONNXVectorIO {
|
|
public:
|
|
/**
|
|
* \brief save the input vector to a file, the order is [b, c, h, w]
|
|
*/
|
|
static void Save1dVectorToFile(std::vector<T> const &input_vector, std::string const &file_location) {
|
|
auto save_txt = [](std::vector<T> const &input_vector, const std::string &file_location) {
|
|
std::ofstream f_out(file_location);
|
|
auto out = fmt::memory_buffer();
|
|
for (auto const &i : input_vector) {
|
|
if (std::is_floating_point_v<T> == true)
|
|
fmt::format_to(std::back_inserter(out), "{:.8f}\n", i);
|
|
else if (std::is_integral_v<T> == true)
|
|
fmt::format_to(std::back_inserter(out), "{}\n", i);
|
|
else
|
|
throw std::runtime_error("unsupported data type" + std::string(typeid(T).name()));
|
|
}
|
|
std::string s(out.begin(), out.end());
|
|
fmt::print(f_out, "{}", s);
|
|
f_out.flush();
|
|
f_out.close();
|
|
};
|
|
|
|
auto save_npy = [](std::vector<T> const &input_vector, const std::string &file_location) {
|
|
const npy::npy_data_ptr<T> data_ptr{input_vector.data(), {1, input_vector.size()}, false};
|
|
write_npy(file_location, data_ptr);
|
|
};
|
|
|
|
std::filesystem::path path = file_location;
|
|
std::string ext = path.extension();
|
|
|
|
if (ext == ".txt") {
|
|
save_txt(input_vector, file_location);
|
|
} else if (ext == ".npy") {
|
|
save_npy(input_vector, file_location);
|
|
} else {
|
|
throw std::runtime_error("Not supported extention(only txt/npy supported): " + ext);
|
|
}
|
|
}
|
|
|
|
/**
|
|
* \brief the file is supposed to be in order [b, h, w, c],
|
|
* we convert it to [b, c, h, w]
|
|
*/
|
|
static std::vector<T> Load1dVectorFromFile(std::string const &file_location, std::vector<int32_t> const &shape) {
|
|
std::vector<T> raw_vector = Load1dVectorFromFile(file_location);
|
|
size_t shape_size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int32_t>());
|
|
if (raw_vector.size() != shape_size) {
|
|
throw std::runtime_error("size in file: " + std::to_string(raw_vector.size()) +
|
|
"not match with size calced in shape: " + std::to_string(shape_size));
|
|
}
|
|
|
|
if (shape.size() != 4) return raw_vector;
|
|
std::vector<T> result_vector(raw_vector.size());
|
|
|
|
uint32_t b, c, h, w;
|
|
b = shape.at(0);
|
|
c = shape.at(1);
|
|
h = shape.at(2);
|
|
w = shape.at(3);
|
|
|
|
for (size_t index = 0; index < raw_vector.size(); index++) {
|
|
size_t in = index;
|
|
uint32_t k = in / (c * h * w); // batch_id
|
|
in %= (c * h * w);
|
|
uint32_t i = in / c;
|
|
uint32_t j = in % c;
|
|
uint32_t addr = k * (c * h * w) + j * (w * h) + i;
|
|
|
|
result_vector.at(addr) = raw_vector[index];
|
|
}
|
|
|
|
return result_vector;
|
|
}
|
|
|
|
template <typename ScalarFm>
|
|
static std::vector<T> transform_type(const std::string &file_location) {
|
|
std::vector<ScalarFm> d = npy::read_npy<ScalarFm>(file_location).data;
|
|
std::vector<T> ret;
|
|
std::transform(d.begin(), d.end(), std::back_inserter(ret), [](auto d_) { return T(d_); });
|
|
return ret;
|
|
}
|
|
|
|
static std::vector<T> load_npy(const std::string &file_location) {
|
|
std::string header_s;
|
|
{
|
|
std::ifstream stream(file_location, std::ifstream::binary);
|
|
if (!stream) {
|
|
throw std::runtime_error("io error: failed to open a file.");
|
|
}
|
|
header_s = npy::read_header(stream);
|
|
}
|
|
std::vector<T> ret;
|
|
|
|
// parse header
|
|
npy::header_t header = npy::parse_header(header_s);
|
|
|
|
if (header.dtype.tie() == npy::dtype_map.at(std::type_index(typeid(float))).tie()) {
|
|
return transform_type<float>(file_location);
|
|
}
|
|
|
|
if (header.dtype.tie() == npy::dtype_map.at(std::type_index(typeid(double))).tie()) {
|
|
return transform_type<double>(file_location);
|
|
}
|
|
|
|
if (header.dtype.tie() == npy::dtype_map.at(std::type_index(typeid(char))).tie()) {
|
|
return transform_type<char>(file_location);
|
|
}
|
|
if (header.dtype.tie() == npy::dtype_map.at(std::type_index(typeid(short))).tie()) {
|
|
return transform_type<short>(file_location);
|
|
}
|
|
if (header.dtype.tie() == npy::dtype_map.at(std::type_index(typeid(int))).tie()) {
|
|
return transform_type<int>(file_location);
|
|
}
|
|
if (header.dtype.tie() == npy::dtype_map.at(std::type_index(typeid(long))).tie()) {
|
|
return transform_type<long>(file_location);
|
|
}
|
|
if (header.dtype.tie() == npy::dtype_map.at(std::type_index(typeid(unsigned char))).tie()) {
|
|
return transform_type<unsigned char>(file_location);
|
|
}
|
|
if (header.dtype.tie() == npy::dtype_map.at(std::type_index(typeid(unsigned short))).tie()) {
|
|
return transform_type<unsigned short>(file_location);
|
|
}
|
|
if (header.dtype.tie() == npy::dtype_map.at(std::type_index(typeid(unsigned int))).tie()) {
|
|
return transform_type<unsigned int>(file_location);
|
|
}
|
|
if (header.dtype.tie() == npy::dtype_map.at(std::type_index(typeid(unsigned long))).tie()) {
|
|
return transform_type<unsigned long>(file_location);
|
|
}
|
|
/*
|
|
if(header.dtype.tie() == npy::dtype_map.at(std::type_index(typeid(bool))).tie()) {
|
|
return transform_type<bool>(stream);
|
|
}
|
|
*/
|
|
|
|
throw std::runtime_error("not supported type in load_npy");
|
|
}
|
|
|
|
/**
|
|
* \brief the file is supposed to be in order [b, c, h, w],
|
|
* this function will just load these numbers line by line
|
|
*/
|
|
static std::vector<T> Load1dVectorFromFile(std::string const &file_location) {
|
|
auto load_txt = [](const std::string &file_location) {
|
|
std::vector<T> result_vector;
|
|
std::string line;
|
|
std::ifstream input_file(file_location);
|
|
if (!input_file) {
|
|
SPDLOG_LOGGER_ERROR(dynasty::log::InferenceLogger::GetLogger(), "Can not open " + file_location);
|
|
exit(1);
|
|
}
|
|
while (getline(input_file, line)) {
|
|
result_vector.push_back(DataConverter(line));
|
|
}
|
|
input_file.close();
|
|
return result_vector;
|
|
};
|
|
|
|
std::filesystem::path path = file_location;
|
|
std::string ext = path.extension();
|
|
std::vector<T> res;
|
|
if (ext == ".txt" || ext == ".golden") {
|
|
res = load_txt(file_location);
|
|
} else if (ext == ".npy") {
|
|
res = load_npy(file_location);
|
|
} else {
|
|
throw std::runtime_error("Not supported extention(only txt/npy supported): " + ext);
|
|
}
|
|
return res;
|
|
}
|
|
|
|
protected:
|
|
static T DataConverter(std::string const &num);
|
|
~ONNXVectorIO() = default;
|
|
};
|
|
} // namespace io
|
|
} // namespace dynasty
|
|
|
|
#endif // PIANO_DYNASTY_INCLUDE_IO_ONNXVECTORIO_H_
|