137 lines
4.6 KiB
C++
137 lines
4.6 KiB
C++
/**
 * \brief ONNXVectorIO declarations and definitions
 * \author Nan Zhou, nanzhou at kneron dot us
 * \copyright 2019 Kneron Inc. All rights reserved.
 */
|
|
|
|
#ifndef PIANO_DYNASTY_INCLUDE_IO_ONNXVECTORIO_H_
|
|
#define PIANO_DYNASTY_INCLUDE_IO_ONNXVECTORIO_H_
|
|
|
|
#include <unistd.h>

#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <string>
#include <system_error>
#include <vector>

#include "InferenceLogger.h"
|
|
|
|
namespace dynasty {
|
|
namespace io {
|
|
/**
|
|
* \class ONNXVectorIO
|
|
* \brief a class to save vectors to files, or load vectors from files, the order follows tensors in ONNX, which is [b, c, h, w]
|
|
*/
|
|
template<typename T>
|
|
class ONNXVectorIO {
|
|
public:
|
|
/**
|
|
* \brief save the input vector to a file, the order is [b, c, h, w]
|
|
*/
|
|
static void Save1dVectorToFile(std::vector<T> const &input_vector, std::string const &file_location) {
|
|
size_t pos = file_location.find_last_of("/\\");
|
|
std::string dirname = file_location.substr(0, pos);
|
|
std::string filename = file_location.substr(pos + 1);
|
|
|
|
int result = chdir(dirname.c_str());
|
|
if (result == -1) {
|
|
std::cerr << "Failed to change current working directory to dump directory" << dirname << std::endl;
|
|
}
|
|
|
|
std::ofstream f_out;
|
|
// Set exceptions to be thrown on failure
|
|
f_out.exceptions(std::ifstream::failbit | std::ifstream::badbit);
|
|
try {
|
|
f_out.open(filename);
|
|
} catch (std::system_error &e) {
|
|
SPDLOG_LOGGER_ERROR(dynasty::log::InferenceLogger::GetLogger(), "Failed to open file {}: {}",
|
|
file_location, e.code().message().c_str());
|
|
exit(1);
|
|
}
|
|
for (auto const &i : input_vector) {
|
|
f_out << std::setprecision(8) << i << std::endl;
|
|
}
|
|
}
|
|
|
|
/**
|
|
* \brief the file is supposed to be in order [b, h, w, c],
|
|
* we convert it to [b, c, h, w]
|
|
*/
|
|
static std::vector<T> Load1dVectorFromFile(std::string const &file_location, std::vector<int32_t > const &shape) {
|
|
uint32_t c = 0, h = 0, w = 0, b = 0;
|
|
b = shape.size() > 0 ? shape.at(0) : 1;
|
|
c = shape.size() > 1 ? shape.at(1) : 1;
|
|
h = shape.size() > 2 ? shape.at(2) : 1;
|
|
w = shape.size() > 3 ? shape.at(3) : 1;
|
|
if (shape.size() != 4) { // not bhwc format
|
|
c = 1;
|
|
h = 1;
|
|
w = 1;
|
|
b = 1;
|
|
for (uint32_t i = 0; i < shape.size(); i++) {
|
|
w *= shape[i];
|
|
}
|
|
}
|
|
std::vector<T> result_vector(b * c * h * w);
|
|
std::string line;
|
|
std::ifstream input_file(file_location);
|
|
if (!input_file) {
|
|
SPDLOG_LOGGER_ERROR(dynasty::log::InferenceLogger::GetLogger(), "Can not open " + file_location);
|
|
exit(1);
|
|
}
|
|
uint32_t index = 0;
|
|
while (getline(input_file, line)) {
|
|
uint32_t in = index;
|
|
uint32_t k = in / (c*h*w); // batch_id
|
|
in %= (c*h*w);
|
|
uint32_t i = in / c;
|
|
uint32_t j = in % c;
|
|
uint32_t addr = k * (c * h * w) +
|
|
j * (w * h) +
|
|
i;
|
|
if(addr >= b * c * w * h){
|
|
SPDLOG_LOGGER_ERROR(dynasty::log::InferenceLogger::GetLogger(), "Input file dimension is greater than expected: "
|
|
+ std::to_string(b*c*w*h) + ". It has reached: " + std::to_string(addr+1) + " lines. " + file_location);
|
|
exit(1);
|
|
|
|
}
|
|
|
|
result_vector.at(addr) = DataConverter(line);
|
|
index++;
|
|
}
|
|
|
|
if (index != b * c * w * h) {
|
|
SPDLOG_LOGGER_ERROR(dynasty::log::InferenceLogger::GetLogger(), "Input file dimension is not correct: expected: "
|
|
+ std::to_string(b*c*w*h) + ", but got: " + std::to_string(index) + " lines. " + file_location);
|
|
exit(1);
|
|
}
|
|
input_file.close();
|
|
return result_vector;
|
|
}
|
|
|
|
/**
|
|
* \brief the file is supposed to be in order [b, c, h, w],
|
|
* this function will just load these numbers line by line
|
|
*/
|
|
static std::vector<T> Load1dVectorFromFile(std::string const &file_location) {
|
|
std::vector<T> result_vector;
|
|
std::string line;
|
|
std::ifstream input_file(file_location);
|
|
if (!input_file) {
|
|
SPDLOG_LOGGER_ERROR(dynasty::log::InferenceLogger::GetLogger(), "Can not open " + file_location);
|
|
exit(1);
|
|
}
|
|
while (getline(input_file, line)) {
|
|
result_vector.push_back(DataConverter(line));
|
|
}
|
|
input_file.close();
|
|
return result_vector;
|
|
}
|
|
|
|
protected:
|
|
static T DataConverter(std::string const &num);
|
|
~ONNXVectorIO() = default;
|
|
};
|
|
}
|
|
}
|
|
|
|
#endif //PIANO_DYNASTY_INCLUDE_IO_ONNXVECTORIO_H_
|