/**
 * \brief An implementation of AbstractInferencer that uses Microsoft ONNX Runtime for inference.
 * \see AbstractInferencer
 * \author Nan Zhou, nanzhou at kneron dot us
 * \copyright Kneron Inc. All rights reserved.
 */
#ifndef PIANO_DYNASTY_FLOATING_POINT_INCLUDE_MSFT_MSFTINFERENCER_H_
#define PIANO_DYNASTY_FLOATING_POINT_INCLUDE_MSFT_MSFTINFERENCER_H_
#include <climits>
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "AbstractInferencer.h"
namespace Ort {
class Env;
class Session;
class SessionOptions;
class Value;
}  // namespace Ort
namespace dynasty {
namespace inferencer {
namespace msft {
struct InputOutputNodeInfo;
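/**
 * \brief ONNX Runtime (Ort) backed implementation of AbstractInferencer.
 *
 * Typical usage, as a sketch (the model path is a placeholder and the public
 * run entry point is inherited from AbstractInferencer):
 * \code
 * auto *builder = dynasty::inferencer::msft::Inferencer<float>::GetBuilder();
 * auto inferencer = builder->WithONNXModel("model.onnx")
 *     ->WithParallelLevel(2)
 *     ->WithDumpEnable(false)
 *     ->Build();
 * \endcode
 */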
template <typename T>
class Inferencer : public AbstractInferencer<T> {
public:
class Builder : public AbstractInferencer<T>::Builder {
friend class Inferencer;
protected:
uint32_t parallel_level_ = INT_MAX;
uint32_t graph_optimize_level_ = INT_MAX;
bool dump_enabled_ = false;
std::shared_ptr<Ort::Env> enviroment_ = nullptr;
std::shared_ptr<Ort::Session> session_ = nullptr;
InputOutputNodeInfo* inputOutputNodeInfo_ = nullptr;
Builder();
Builder *CreateSession(std::string const &model_file, bool is_encrypted = false);
void SanityCheck(); // may throw
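// Hook for subclasses to adjust the Ort::SessionOptions (threading, graph
// optimization, execution providers, ...); presumably invoked while the
// session is being created, but the exact behavior is defined by the
// implementation.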
virtual void ConfigSessionOption(Ort::SessionOptions& option);
public:
Builder *WithGraphOptimization(uint32_t level);
Builder *WithParallelLevel(uint32_t parallel_level);
Builder *WithONNXModel(std::string const &model_file);
Builder *WithEncryptedModel(std::string const &model_file);
Builder *WithDumpEnable(bool dump_enabled) { dump_enabled_ = dump_enabled; return this; }
~Builder();
InferencerUniquePtr<T> Build() override;
/**
 * @return pairs of {operation name, BCHW dimensions}, one per model input/output node
 */
std::vector<std::pair<std::string, std::vector<int32_t>>> GetInputDimensions();
std::vector<std::pair<std::string, std::vector<int32_t>>> GetOutputDimensions();
InputOutputNodeInfo* GetInputOutputInfo() const { return inputOutputNodeInfo_; }
};
protected:
std::shared_ptr<Ort::Env> enviroment_ = nullptr;
std::shared_ptr<Ort::Session> session_ = nullptr;
std::vector<const char*> input_node_names_;
std::vector<std::vector<int64_t>> input_node_dims_;
std::vector<const char*> output_node_names_;
std::vector<std::vector<int64_t>> output_node_dims_;
bool dump_enabled_ = false;
protected:
void CheckInput(std::unordered_map<std::string, std::vector<T>> const &preprocess_input);
explicit Inferencer(std::shared_ptr<Ort::Env> enviroment, std::shared_ptr<Ort::Session> session);
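// Two Inference overloads are declared below: one takes per-node string
// inputs (e.g. file paths or serialized buffers, depending on the
// implementation), the other takes already preprocessed numeric tensors keyed
// by input node name. Both return a map from output node name to a flattened
// result vector; when only_output_layers is false, intermediate layers may be
// included as well (implementation dependent).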
std::unordered_map<std::string, std::vector<T>> Inference(
std::unordered_map<std::string, std::string> const &preprocess_input, bool only_output_layers) override;
std::unordered_map<std::string, std::vector<T>> Inference(
std::unordered_map<std::string, std::vector<T>> const &preprocess_input,
bool only_output_layers) override;
public:
~Inferencer() override;
static Builder* GetBuilder();
void SetDumpEnable(bool dump_enabled = false) { dump_enabled_ = dump_enabled; }
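// Writes the raw Ort::Value results of a run for debugging; only expected to
// be called when dump_enabled_ is set (output format is implementation defined).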
void dump(std::vector<Ort::Value>& result);
};
}  // namespace msft
}  // namespace inferencer
}  // namespace dynasty
#endif //PIANO_DYNASTY_FLOATING_POINT_INCLUDE_MSFT_MSFTINFERENCER_H_