104 lines
3.3 KiB
C++
104 lines
3.3 KiB
C++
/**
|
|
* \brief An implementation of AbstractInferencer that uses Microsoft ONNX Runtime to run inference
|
|
* \author Nan Zhou, nanzhou at kneron dot us
|
|
* \copyright Kneron Inc. All rights reserved.
|
|
*/
|
|
|
|
#ifndef PIANO_DYNASTY_FLOATING_POINT_INCLUDE_MSFT_MSFTINFERENCER_H_
|
|
#define PIANO_DYNASTY_FLOATING_POINT_INCLUDE_MSFT_MSFTINFERENCER_H_
|
|
|
|
#include <memory>
|
|
#include <string>
|
|
#include <limits.h>
|
|
|
|
#include "AbstractInferencer.h"
|
|
// Forward declarations of the ONNX Runtime C++ API types used in this header,
// so that the ONNX Runtime headers do not have to be included here.
//
// They are declared with the `struct` class-key because that is how the
// ONNX Runtime C++ API (onnxruntime_cxx_api.h) defines them. Declaring them
// with `class` is legal C++, but it triggers -Wmismatched-tags (clang) /
// C4099 (MSVC) and can affect MSVC name mangling.
namespace Ort {
struct Env;
struct Session;
struct SessionOptions;
struct Value;
}  // namespace Ort
|
|
|
|
|
|
|
|
namespace dynasty {
|
|
namespace inferencer {
|
|
namespace msft {
|
|
|
|
// Forward declaration only: this header refers to the type solely through
// pointers, so the full definition is not needed here.
struct InputOutputNodeInfo;
|
|
template <typename T>
|
|
class Inferencer : public AbstractInferencer<T> {
|
|
public:
|
|
class Builder : public AbstractInferencer<T>::Builder {
|
|
friend class Inferencer;
|
|
protected:
|
|
uint32_t parallel_level_ = INT_MAX;
|
|
uint32_t graph_optimize_level_ = INT_MAX;
|
|
bool dump_enabled_ = false;
|
|
std::shared_ptr<Ort::Env> enviroment_ = nullptr;
|
|
std::shared_ptr<Ort::Session> session_ = nullptr;
|
|
InputOutputNodeInfo* inputOutputNodeInfo_ = nullptr;
|
|
Builder();
|
|
Builder *CreateSession(std::string const &model_file, bool isEncryted=false);
|
|
void SanityCheck(); // may throw
|
|
virtual void ConfigSessionOption(Ort::SessionOptions& option);
|
|
|
|
public:
|
|
Builder *WithGraphOptimization(uint32_t level);
|
|
Builder *WithParallelLevel(uint32_t parallel_level);
|
|
Builder *WithONNXModel(std::string const &model_file);
|
|
Builder *WithEncryptedModel(std::string const &model_file);
|
|
Builder *WithDumpEnable(bool dump_enabled) { dump_enabled_ = dump_enabled; return this;}
|
|
~Builder();
|
|
InferencerUniquePtr<T> Build() override;
|
|
|
|
/**
|
|
* @return {{operation name : bchw}}
|
|
*/
|
|
std::vector<std::pair<std::string, std::vector<int32_t >>> GetInputDimensions();
|
|
std::vector<std::pair<std::string, std::vector<int32_t >>> GetOutputDimensions();
|
|
InputOutputNodeInfo* GetInputOutputInfo() const { return inputOutputNodeInfo_; }
|
|
};
|
|
|
|
protected:
|
|
std::shared_ptr<Ort::Env> enviroment_ = nullptr;
|
|
std::shared_ptr<Ort::Session> session_ = nullptr;
|
|
|
|
std::vector<std::string> input_node_names_;
|
|
std::vector<std::vector<int64_t>> input_node_dims_;
|
|
|
|
std::vector<std::string> output_node_names_;
|
|
std::vector<std::vector<int64_t>> output_node_dims_;
|
|
bool dump_enabled_ = false;
|
|
|
|
protected:
|
|
void CheckInput(std::unordered_map<std::string, std::vector<T>> const &preprocess_input);
|
|
|
|
explicit Inferencer(std::shared_ptr<Ort::Env> enviroment, std::shared_ptr<Ort::Session> session);
|
|
|
|
std::unordered_map<std::string, std::vector<T>> Inference(
|
|
std::unordered_map<std::string, std::string> const &preprocess_input, bool only_output_layers) override;
|
|
|
|
std::unordered_map<std::string, std::vector<T>> Inference(
|
|
std::unordered_map<std::string, std::vector<T>> const &preprocess_input,
|
|
bool only_output_layers) override;
|
|
|
|
public:
|
|
~Inferencer() override;
|
|
|
|
static Builder* GetBuilder();
|
|
|
|
void SetDumpEnable(bool dump_enabled = false) { dump_enabled_ = dump_enabled; }
|
|
void dump(std::vector<Ort::Value>& result);
|
|
};
|
|
|
|
|
|
}
|
|
}
|
|
}
|
|
|
|
#endif //PIANO_DYNASTY_FLOATING_POINT_INCLUDE_MSFT_MSFTINFERENCER_H_
|