2026-01-28 06:16:04 +00:00

137 lines
4.6 KiB
C++

/**
* \brief ONNXVectorIO declarations and definitions
* \author Nan Zhou, nanzhou at kneron dot us
* \copyright 2019 Kneron Inc. All rights reserved.
*/
#ifndef PIANO_DYNASTY_INCLUDE_IO_ONNXVECTORIO_H_
#define PIANO_DYNASTY_INCLUDE_IO_ONNXVECTORIO_H_
#include <unistd.h>

#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <string>
#include <system_error>
#include <vector>

#include "InferenceLogger.h"
namespace dynasty {
namespace io {
/**
 * \class ONNXVectorIO
 * \brief Static helpers that save a flat vector to a text file (one value per line)
 *        or load such a file back into a vector. In-memory data follows the ONNX
 *        tensor order [b, c, h, w].
 * \tparam T element type. DataConverter is only declared here, so a definition for
 *           each T that is loaded must be provided elsewhere.
 */
template<typename T>
class ONNXVectorIO {
public:
    /**
     * \brief Write every element of input_vector to file_location, one value per line,
     *        with a stream precision of 8. Elements are written in the order stored.
     * \param input_vector values to dump
     * \param file_location path of the file to create or truncate
     * \note Logs an error and terminates the process with exit(1) if the file cannot
     *       be opened. A failure while writing surfaces as an exception thrown by the
     *       stream, because failbit/badbit exceptions are enabled.
     */
    static void Save1dVectorToFile(std::vector<T> const &input_vector, std::string const &file_location) {
        std::ofstream f_out;
        // Make open and write failures throw instead of being silently ignored.
        f_out.exceptions(std::ofstream::failbit | std::ofstream::badbit);
        try {
            // Open the full path directly. Changing the working directory to reach the
            // file would alter process-wide state and break later relative paths.
            f_out.open(file_location);
        } catch (std::system_error const &e) {
            SPDLOG_LOGGER_ERROR(dynasty::log::InferenceLogger::GetLogger(), "Failed to open file {}: {}",
                                file_location, e.code().message().c_str());
            std::exit(1);
        }
        // The precision setting is sticky, so it only has to be applied once.
        f_out << std::setprecision(8);
        for (auto const &value : input_vector) {
            f_out << value << '\n';
        }
        // Flush explicitly: the destructor would swallow an error raised while
        // writing out the buffered data.
        f_out.flush();
    }

    /**
     * \brief Load a text file holding one value per line in [b, h, w, c] order and
     *        return the values rearranged into [b, c, h, w] order.
     * \param file_location path of the file to read
     * \param shape target shape. When it has exactly four entries they are read as
     *              [b, c, h, w]; any other rank is treated as a flat vector whose
     *              length is the product of all entries, so no reordering happens.
     * \return vector of b * c * h * w elements in [b, c, h, w] order
     * \note Logs an error and terminates the process with exit(1) when a dimension is
     *       negative, the file cannot be opened, or the number of lines differs from
     *       the number of elements implied by shape.
     */
    static std::vector<T> Load1dVectorFromFile(std::string const &file_location, std::vector<int32_t > const &shape) {
        // Reject negative dimensions up front; converting them to an unsigned type
        // would silently turn them into enormous sizes.
        for (int32_t const dim : shape) {
            if (dim < 0) {
                SPDLOG_LOGGER_ERROR(dynasty::log::InferenceLogger::GetLogger(),
                                    "Invalid negative dimension {} in shape for {}", dim, file_location);
                std::exit(1);
            }
        }

        std::size_t b = 1, c = 1, h = 1, w = 1;
        if (shape.size() == 4) {
            b = static_cast<std::size_t>(shape[0]);
            c = static_cast<std::size_t>(shape[1]);
            h = static_cast<std::size_t>(shape[2]);
            w = static_cast<std::size_t>(shape[3]);
        } else {
            // Not a 4-D shape: collapse everything into a single axis.
            for (int32_t const dim : shape) {
                w *= static_cast<std::size_t>(dim);
            }
        }
        std::size_t const plane = h * w;          // elements per channel
        std::size_t const per_batch = c * plane;  // elements per batch entry
        std::size_t const total = b * per_batch;

        std::ifstream input_file(file_location);
        if (!input_file) {
            SPDLOG_LOGGER_ERROR(dynasty::log::InferenceLogger::GetLogger(), "Can not open {}", file_location);
            std::exit(1);
        }

        std::vector<T> result_vector(total);
        std::string line;
        std::size_t index = 0;  // number of lines consumed so far
        while (std::getline(input_file, line)) {
            // Check the bound before any division: per_batch or c may be zero when
            // the shape contains a zero dimension.
            if (index >= total) {
                SPDLOG_LOGGER_ERROR(dynasty::log::InferenceLogger::GetLogger(),
                                    "Input file dimension is greater than expected: {}. It has reached: {} lines. {}",
                                    total, index + 1, file_location);
                std::exit(1);
            }
            std::size_t const batch = index / per_batch;
            std::size_t const within_batch = index % per_batch;
            std::size_t const pixel = within_batch / c;    // flattened (h, w) position
            std::size_t const channel = within_batch % c;  // channel varies fastest in the file
            result_vector[batch * per_batch + channel * plane + pixel] = DataConverter(line);
            ++index;
        }
        if (index != total) {
            SPDLOG_LOGGER_ERROR(dynasty::log::InferenceLogger::GetLogger(),
                                "Input file dimension is not correct: expected: {}, but got: {} lines. {}",
                                total, index, file_location);
            std::exit(1);
        }
        return result_vector;
    }

    /**
     * \brief Load a text file holding one value per line and return the values in
     *        file order (the file is expected to already be in [b, c, h, w] order).
     * \param file_location path of the file to read
     * \return one element per line of the file
     * \note Logs an error and terminates the process with exit(1) if the file cannot
     *       be opened.
     */
    static std::vector<T> Load1dVectorFromFile(std::string const &file_location) {
        std::ifstream input_file(file_location);
        if (!input_file) {
            SPDLOG_LOGGER_ERROR(dynasty::log::InferenceLogger::GetLogger(), "Can not open {}", file_location);
            std::exit(1);
        }
        std::vector<T> result_vector;
        std::string line;
        while (std::getline(input_file, line)) {
            result_vector.push_back(DataConverter(line));
        }
        return result_vector;
    }

protected:
    /// Convert one line of text into a value of type T; defined outside this header section.
    static T DataConverter(std::string const &num);
    /// Protected, non-virtual destructor.
    ~ONNXVectorIO() = default;
};
}
}
#endif //PIANO_DYNASTY_INCLUDE_IO_ONNXVECTORIO_H_