/*
 * NOTE(review): this header looks damaged in its current form and is unlikely to
 * compile as-is. The code below is deliberately left untouched; restore it from
 * version control rather than hand-editing it. Observed problems:
 *
 *  1. The whole file sits on two physical lines. The first one starts the
 *     `#ifndef` include guard, and a preprocessor directive runs to the end of its
 *     physical line, so the `#define`, the `#include`s and the declarations that
 *     follow on that line are trailing tokens of the `#ifndef` instead of separate
 *     directives and declarations.
 *  2. The `//` in front of "may throw" starts a line comment that also runs to the
 *     end of the first physical line, so everything after it on that line (from
 *     ConfigSessionOption through the Inferencer member fields) is lexically
 *     comment text.
 *  3. The closing `#endif` is not the first token on its line, so it is not
 *     recognised as a directive.
 *  4. Text that would normally sit in angle brackets is missing: three `#include`
 *     directives have no target, `template` has no parameter list, and
 *     `std::shared_ptr`, `std::vector` and `std::unordered_map` appear without
 *     arguments (stray `>` and `>>` remain where nested argument lists presumably
 *     closed). The missing arguments cannot be recovered from this file alone.
 *
 * What the visible declarations describe (namespace dynasty::inferencer::msft):
 *
 *  - Ort::Env, Ort::Session, Ort::Value and Ort::SessionOptions are only
 *    forward-declared here (presumably to keep the ONNX Runtime headers out of
 *    this header; confirm against the implementation file).
 *  - InputOutputNodeInfo is forward-declared; Builder stores a raw pointer to it
 *    and returns that pointer from GetInputOutputInfo(). Who owns the pointee is
 *    not visible here; check ~Builder() in the implementation.
 *  - Inferencer derives from AbstractInferencer. It holds the shared_ptr members
 *    `enviroment_` and `session_`, input/output node name and dimension vectors,
 *    and `dump_enabled_` (default false). It declares CheckInput(), an explicit
 *    two-argument constructor, two Inference() overrides that both take
 *    `only_output_layers`, an overriding destructor, static GetBuilder(),
 *    SetDumpEnable() (parameter defaults to false) and dump().
 *  - Inferencer::Builder derives from AbstractInferencer::Builder and befriends
 *    Inferencer. `parallel_level_` and `graph_optimize_level_` default to INT_MAX
 *    (presumably a "not set" sentinel; confirm). WithDumpEnable() stores the flag
 *    and returns `this`; the other With*() functions are declared to return
 *    `Builder *`, presumably for the same call chaining. SanityCheck() may throw
 *    according to the original comment. Build() overrides the base Builder.
 *    GetInputDimensions() and GetOutputDimensions() are documented by the original
 *    author as returning {{operation name : bchw}}.
 *  - `enviroment_` and `isEncryted` are spelled as found; renaming them would
 *    touch code outside this file.
 */
/** * \brief An implementation of {@see AbstractInferencer} using Microsoft ONNX Runtime to do inference * \author Nan Zhou, nanzhou at kneron dot us * \copyright Kneron Inc. All right reserved. */ #ifndef PIANO_DYNASTY_FLOATING_POINT_INCLUDE_MSFT_MSFTINFERENCER_H_ #define PIANO_DYNASTY_FLOATING_POINT_INCLUDE_MSFT_MSFTINFERENCER_H_ #include #include #include #include "AbstractInferencer.h" namespace Ort { class Env; } namespace Ort { class Session; class Value; class SessionOptions; } namespace dynasty { namespace inferencer { namespace msft { struct InputOutputNodeInfo ; template class Inferencer : public AbstractInferencer { public: class Builder : public AbstractInferencer::Builder { friend class Inferencer; protected: uint32_t parallel_level_ = INT_MAX; uint32_t graph_optimize_level_ = INT_MAX; bool dump_enabled_ = false; std::shared_ptr enviroment_ = nullptr; std::shared_ptr session_ = nullptr; InputOutputNodeInfo* inputOutputNodeInfo_ = nullptr; Builder(); Builder *CreateSession(std::string const &model_file, bool isEncryted=false); void SanityCheck(); // may throw virtual void ConfigSessionOption(Ort::SessionOptions& option); public: Builder *WithGraphOptimization(uint32_t level); Builder *WithParallelLevel(uint32_t parallel_level); Builder *WithONNXModel(std::string const &model_file); Builder *WithEncryptedModel(std::string const &model_file); Builder *WithDumpEnable(bool dump_enabled) { dump_enabled_ = dump_enabled; return this;} ~Builder(); InferencerUniquePtr Build() override; /** * @return {{operation name : bchw}} */ std::vector>> GetInputDimensions(); std::vector>> GetOutputDimensions(); InputOutputNodeInfo* GetInputOutputInfo() const { return inputOutputNodeInfo_; } }; protected: std::shared_ptr enviroment_ = nullptr; std::shared_ptr session_ = nullptr; std::vector input_node_names_; std::vector> input_node_dims_; std::vector output_node_names_; std::vector> output_node_dims_; bool dump_enabled_ = false; protected: void 
CheckInput(std::unordered_map> const &preprocess_input); explicit Inferencer(std::shared_ptr enviroment, std::shared_ptr session); std::unordered_map> Inference( std::unordered_map const &preprocess_input, bool only_output_layers) override; std::unordered_map> Inference( std::unordered_map> const &preprocess_input, bool only_output_layers) override; public: ~Inferencer() override; static Builder* GetBuilder(); void SetDumpEnable(bool dump_enabled = false) { dump_enabled_ = dump_enabled; } void dump(std::vector& result); }; } } } #endif //PIANO_DYNASTY_FLOATING_POINT_INCLUDE_MSFT_MSFTINFERENCER_H_