/**
 * \brief main function to run Inferencer
 *
 * \author Nan Zhou, nanzhou at kneron dot us
 *
 * \copyright 2019 Kneron Inc. All right reserved.
 */
|
|
|
|
#include <algorithm>
|
|
#include <iostream>
|
|
#include <memory>
|
|
#include <string>
|
|
#include <unordered_map>
|
|
#include <unordered_set>
|
|
|
|
|
|
#include "AbstractInferencer.h"
|
|
#include "cpu/InferencerPImpl.h"
|
|
|
|
|
|
#ifdef CUDA_FOUND
|
|
#include "msft-gpu/MSFTInferencer.h"
|
|
#include "msft-gpu/InferencerPImpl.h"
|
|
#else
|
|
#include "msft-gpu/dummyInferencerPImpl.h"
|
|
#endif
|
|
|
|
#include "msft/MSFTInferencer.h"
|
|
|
|
#include "cxxopts.hpp"
|
|
#include "ONNXVectorIO.h"
|
|
#include "UnixFileManager.h"
|
|
#include "JsonKeyConstants.h"
|
|
#include "JsonIO.h"
|
|
|
|
using std::string;
|
|
using std::vector;
|
|
using cxxopts::Options;
|
|
using std::cout;
|
|
using std::endl;
|
|
using std::cerr;
|
|
using std::pair;
|
|
using std::unordered_map;
|
|
using std::unordered_set;
|
|
using std::unique_ptr;
|
|
using std::shared_ptr;
|
|
using std::replace_if;
|
|
using dynasty::io::JsonIO;
|
|
using dynasty::io::UnixFileManager;
|
|
using dynasty::io::ONNXVectorIO;
|
|
using namespace dynasty::inferencer;
|
|
|
|
/*
|
|
* Parse and validate the arguments
|
|
*/
|
|
cxxopts::ParseResult parse(int argc, char *argv[]) {
|
|
try {
|
|
unordered_set<string> type_set = {"CPU", "MSFT", "MSFT-CUDA"};
|
|
|
|
cxxopts::Options options("Inferencer", "Kneron's Neural Network Inferencer");
|
|
options.positional_help("[optional args]")
|
|
.show_positional_help();
|
|
|
|
options.allow_unrecognised_options()
|
|
.add_options()
|
|
("i, input", "input config json file", cxxopts::value<string>())
|
|
("e, encrypt", "use encrypted model or not", cxxopts::value<bool>())
|
|
("t, type", "inferencer type; AVAILABLE CHOICES: CPU, MSFT", cxxopts::value<string>())
|
|
|
|
("h, help", "optional, print help")
|
|
("d, device", "optional, gpu device number", cxxopts::value<int>())
|
|
("o, output",
|
|
"optional, path_to_folder to save the outputs; CREATE THE FOLDER BEFORE CALLING THE BINARY",
|
|
cxxopts::value<string>());
|
|
auto result = options.parse(argc, argv);
|
|
|
|
if (result.count("help")) {
|
|
cout << options.help() << endl;
|
|
exit(0);
|
|
}
|
|
|
|
// parse input json
|
|
if (result.count("input") == 0) {
|
|
cout
|
|
<< "*******************************\nInput is not specified but it is required\n*******************************\n"
|
|
<< endl;
|
|
cout << options.help() << endl;
|
|
exit(1);
|
|
} else {
|
|
Json::Value config_root = JsonIO::LoadJson(result["input"].as<string>());
|
|
|
|
bool is_config_valid = true;
|
|
if (!(config_root.isMember(dynasty::common::MODEL_PATH_STR) &&
|
|
config_root.isMember(dynasty::common::MODEL_INPUT_TXTS_STR))) {
|
|
is_config_valid = false;
|
|
}
|
|
if (is_config_valid) {
|
|
for (auto const &per_input : config_root[dynasty::common::MODEL_INPUT_TXTS_STR]) {
|
|
if (!(per_input.isMember(dynasty::common::DATA_VECTOR_STR)
|
|
&& per_input.isMember(dynasty::common::OPERATION_NAME_STR))) {
|
|
is_config_valid = false;
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
if (!is_config_valid) {
|
|
cout
|
|
<< "*******************************\nInput json is not valid. Please refer to examples in dynasty/conf \n*******************************\n"
|
|
<< endl;
|
|
cout << options.help() << endl;
|
|
exit(1);
|
|
}
|
|
}
|
|
cout << "input = " << result["input"].as<string>() << endl;
|
|
|
|
// parse encryption
|
|
if (result.count("encrypt") == 0) {
|
|
cout
|
|
<< "*******************************\nEncrypt is not specified but it is required or Encrypt can not be recognized\n*******************************\n"
|
|
<< endl;
|
|
cout << options.help() << endl;
|
|
exit(1);
|
|
}
|
|
cout << "encrypt = " << (result["encrypt"].as<bool>() ? "true" : "false" )<< endl;
|
|
|
|
// parse inferencer type
|
|
if (result.count("type") == 0 || type_set.count(result["type"].as<string>()) == 0) {
|
|
cout
|
|
<< "*******************************\nType is not specified but it is required or Type can not be recognized\n*******************************\n"
|
|
<< endl;
|
|
cout << options.help() << endl;
|
|
exit(1);
|
|
}
|
|
cout << "type = " << result["type"].as<string>() << endl;
|
|
|
|
|
|
// parse output directory
|
|
if (result.count("output")) {
|
|
if (!UnixFileManager::IsDirectory(result["output"].as<string>())) {
|
|
cout
|
|
<< "*******************************\nOutput " << result["output"].as<string>()
|
|
<< " is not a existing directory\n*******************************\n"
|
|
<< endl;
|
|
cout << options.help() << endl;
|
|
exit(1);
|
|
}
|
|
cout << "output = " << result["output"].as<string>() << endl;
|
|
}
|
|
return result;
|
|
} catch (const cxxopts::OptionException &e) {
|
|
cout << "error parsing options: " << e.what() << endl;
|
|
exit(1);
|
|
}
|
|
}
|
|
|
|
int main(int argc, char *argv[]) {
|
|
{
|
|
auto result = parse(argc, argv);
|
|
|
|
Json::Value config_root = JsonIO::LoadJson(result["input"].as<string>());
|
|
|
|
string model_file = config_root[dynasty::common::MODEL_PATH_STR].asString();
|
|
string model_type = result["type"].as<string>();
|
|
|
|
InferencerUniquePtr<float> inferencer = nullptr;
|
|
|
|
if (!result["encrypt"].as<bool>()) {
|
|
if (model_type == "CPU") {
|
|
inferencer = cpu::Inferencer<float>::GetBuilder()->
|
|
WithGraphOptimization(1)->
|
|
WithONNXModel(model_file)->
|
|
Build();
|
|
} else if (model_type == "MSFT") {
|
|
inferencer = msft::Inferencer<float>::GetBuilder()->
|
|
WithGraphOptimization(1)->
|
|
WithParallelLevel(1)->
|
|
WithONNXModel(model_file)->
|
|
Build();
|
|
} else if (model_type == "MSFT-CUDA") {
|
|
inferencer = msftgpu::Inferencer<float>::GetBuilder()->
|
|
WithDeviceID(0)->
|
|
WithGraphOptimization(1)->
|
|
WithParallelLevel(1)->
|
|
WithONNXModel(model_file)->
|
|
Build();
|
|
} else {
|
|
cout << "error, unsupported model type: " << model_type << endl;
|
|
exit(1);
|
|
}
|
|
} else {
|
|
if (model_type == "CPU") {
|
|
inferencer = cpu::Inferencer<float>::GetBuilder()->
|
|
WithGraphOptimization(1)->
|
|
WithBIEModel(model_file)->
|
|
Build();
|
|
} else if ( model_type == "MSFT") {
|
|
cout << model_type << " does not support BIE model" << endl;
|
|
exit(1);
|
|
} else {
|
|
cout << "error, unsupported model type: " << model_type << endl;
|
|
exit(1);
|
|
}
|
|
}
|
|
|
|
unordered_map<string, vector<float>> inference_output;
|
|
inference_output = inferencer->Inference(result["input"].as<string>());
|
|
|
|
if (result.count("output")) {
|
|
for (auto const &operation_name_vector_pair : inference_output) {
|
|
string operation_name = operation_name_vector_pair.first;
|
|
replace_if(operation_name.begin(), operation_name.end(), [](char ch) { return ch == '/'; }, '_');
|
|
ONNXVectorIO<float>::Save1dVectorToFile(operation_name_vector_pair.second,
|
|
UnixFileManager::PathJoin(result["output"].as<string>(),
|
|
operation_name));
|
|
}
|
|
}
|
|
cout << "Done!" << endl;
|
|
}
|
|
|
|
return 0;
|
|
}
|