2026-01-28 06:16:04 +00:00

220 lines
8.0 KiB
C++

/**
* \brief main function to run Inferencer
* \author Nan Zhou, nanzhou at kneron dot us
* \copyright 2019 Kneron Inc. All right reserved.
*/
#include <algorithm>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "AbstractInferencer.h"
#include "cpu/InferencerPImpl.h"
#ifdef CUDA_FOUND
#include "msft-gpu/MSFTInferencer.h"
#include "msft-gpu/InferencerPImpl.h"
#else
#include "msft-gpu/dummyInferencerPImpl.h"
#endif
#include "msft/MSFTInferencer.h"
#include "cxxopts.hpp"
#include "ONNXVectorIO.h"
#include "UnixFileManager.h"
#include "JsonKeyConstants.h"
#include "JsonIO.h"
using std::string;
using std::vector;
using cxxopts::Options;
using std::cout;
using std::endl;
using std::cerr;
using std::pair;
using std::unordered_map;
using std::unordered_set;
using std::unique_ptr;
using std::shared_ptr;
using std::replace_if;
using dynasty::io::JsonIO;
using dynasty::io::UnixFileManager;
using dynasty::io::ONNXVectorIO;
using namespace dynasty::inferencer;
/*
 * Parse and validate the command-line arguments.
 *
 * Required options: --input (config json), --encrypt (bool), --type
 * (one of CPU, MSFT, MSFT-CUDA). Optional: --device, --output, --help.
 *
 * The function never returns on failure: on a validation error it prints a
 * banner with the reason followed by the usage text and exits with status 1;
 * on --help it prints the usage text and exits with status 0.
 *
 * \param argc argument count as received by main
 * \param argv argument vector as received by main
 * \return the cxxopts parse result for a fully validated command line
 */
cxxopts::ParseResult parse(int argc, char *argv[]) {
    try {
        const unordered_set<string> type_set = {"CPU", "MSFT", "MSFT-CUDA"};
        cxxopts::Options options("Inferencer", "Kneron's Neural Network Inferencer");
        options.positional_help("[optional args]")
            .show_positional_help();
        options.allow_unrecognised_options()
            .add_options()
                ("i, input", "input config json file", cxxopts::value<string>())
                ("e, encrypt", "use encrypted model or not", cxxopts::value<bool>())
                // Keep this list in sync with type_set above.
                ("t, type", "inferencer type; AVAILABLE CHOICES: CPU, MSFT, MSFT-CUDA", cxxopts::value<string>())
                ("h, help", "optional, print help")
                ("d, device", "optional, gpu device number", cxxopts::value<int>())
                ("o, output",
                 "optional, path_to_folder to save the outputs; CREATE THE FOLDER BEFORE CALLING THE BINARY",
                 cxxopts::value<string>());
        auto result = options.parse(argc, argv);

        // Shared failure path: print the reason, then the usage text, then exit(1).
        auto fail_with_help = [&options](const string &message) {
            cout << message << endl;
            cout << options.help() << endl;
            exit(1);
        };

        if (result.count("help")) {
            cout << options.help() << endl;
            exit(0);
        }

        // parse input json
        if (result.count("input") == 0) {
            fail_with_help(
                "*******************************\nInput is not specified but it is required\n*******************************\n");
        }
        const string input_path = result["input"].as<string>();
        Json::Value config_root = JsonIO::LoadJson(input_path);
        // isMember is only valid on JSON objects, so check the value type first;
        // otherwise a malformed config would fail inside jsoncpp instead of
        // reaching the diagnostic below.
        bool is_config_valid = config_root.isObject()
                               && config_root.isMember(dynasty::common::MODEL_PATH_STR)
                               && config_root.isMember(dynasty::common::MODEL_INPUT_TXTS_STR);
        if (is_config_valid) {
            for (auto const &per_input : config_root[dynasty::common::MODEL_INPUT_TXTS_STR]) {
                if (!(per_input.isObject()
                      && per_input.isMember(dynasty::common::DATA_VECTOR_STR)
                      && per_input.isMember(dynasty::common::OPERATION_NAME_STR))) {
                    is_config_valid = false;
                    break;
                }
            }
        }
        if (!is_config_valid) {
            fail_with_help(
                "*******************************\nInput json is not valid. Please refer to examples in dynasty/conf \n*******************************\n");
        }
        cout << "input = " << input_path << endl;

        // parse encryption
        if (result.count("encrypt") == 0) {
            fail_with_help(
                "*******************************\nEncrypt is not specified but it is required or Encrypt can not be recognized\n*******************************\n");
        }
        cout << "encrypt = " << (result["encrypt"].as<bool>() ? "true" : "false") << endl;

        // parse inferencer type
        if (result.count("type") == 0 || type_set.count(result["type"].as<string>()) == 0) {
            fail_with_help(
                "*******************************\nType is not specified but it is required or Type can not be recognized\n*******************************\n");
        }
        cout << "type = " << result["type"].as<string>() << endl;

        // parse output directory
        if (result.count("output")) {
            const string output_dir = result["output"].as<string>();
            if (!UnixFileManager::IsDirectory(output_dir)) {
                fail_with_help("*******************************\nOutput " + output_dir
                               + " is not a existing directory\n*******************************\n");
            }
            cout << "output = " << output_dir << endl;
        }
        return result;
    } catch (const cxxopts::OptionException &e) {
        cout << "error parsing options: " << e.what() << endl;
        exit(1);
    }
}
int main(int argc, char *argv[]) {
{
auto result = parse(argc, argv);
Json::Value config_root = JsonIO::LoadJson(result["input"].as<string>());
string model_file = config_root[dynasty::common::MODEL_PATH_STR].asString();
string model_type = result["type"].as<string>();
InferencerUniquePtr<float> inferencer = nullptr;
if (!result["encrypt"].as<bool>()) {
if (model_type == "CPU") {
inferencer = cpu::Inferencer<float>::GetBuilder()->
WithGraphOptimization(1)->
WithONNXModel(model_file)->
Build();
} else if (model_type == "MSFT") {
inferencer = msft::Inferencer<float>::GetBuilder()->
WithGraphOptimization(1)->
WithParallelLevel(1)->
WithONNXModel(model_file)->
Build();
} else if (model_type == "MSFT-CUDA") {
inferencer = msftgpu::Inferencer<float>::GetBuilder()->
WithDeviceID(0)->
WithGraphOptimization(1)->
WithParallelLevel(1)->
WithONNXModel(model_file)->
Build();
} else {
cout << "error, unsupported model type: " << model_type << endl;
exit(1);
}
} else {
if (model_type == "CPU") {
inferencer = cpu::Inferencer<float>::GetBuilder()->
WithGraphOptimization(1)->
WithBIEModel(model_file)->
Build();
} else if ( model_type == "MSFT") {
cout << model_type << " does not support BIE model" << endl;
exit(1);
} else {
cout << "error, unsupported model type: " << model_type << endl;
exit(1);
}
}
unordered_map<string, vector<float>> inference_output;
inference_output = inferencer->Inference(result["input"].as<string>());
if (result.count("output")) {
for (auto const &operation_name_vector_pair : inference_output) {
string operation_name = operation_name_vector_pair.first;
replace_if(operation_name.begin(), operation_name.end(), [](char ch) { return ch == '/'; }, '_');
ONNXVectorIO<float>::Save1dVectorToFile(operation_name_vector_pair.second,
UnixFileManager::PathJoin(result["output"].as<string>(),
operation_name));
}
}
cout << "Done!" << endl;
}
return 0;
}