2026-01-28 06:16:04 +00:00

125 lines
5.1 KiB
C++

/**
* \brief An abstract base class for all the graph-based inferencers; extends AbstractInferencer.
* Subclasses must implement the pure virtual function GraphBasedInference (and Builder::Build).
* \author Nan Zhou, nanzhou at kneron dot us
* \copyright 2019 Kneron Inc. All rights reserved.
*/
#ifndef PIANO_DYNASTY_INCLUDE_PIANOINFERENCER_H_
#define PIANO_DYNASTY_INCLUDE_PIANOINFERENCER_H_
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "AbstractInferencer.h"
class Graph;
namespace dynasty {
namespace inferencer {
/**
 * @brief Base class for inferencers that run on a Graph.
 *
 * Overrides the AbstractInferencer<T>::Inference / ConvertFloatToInt entry points.
 * Concrete subclasses must implement GraphBasedInference (std::vector<T> inputs)
 * and Builder::Build.
 * @tparam T element type of the per-node value vectors.
 */
template <typename T>
class PianoInferencer : public AbstractInferencer<T> {
 public:
  /**
   * @brief Builder shared by the graph-based inferencers.
   * Each With* setter returns a Builder pointer (WithSettingInfo returns `this`),
   * so calls can be chained; subclasses implement Build().
   */
  class Builder : public AbstractInferencer<T>::Builder {
   private:
    std::string bieFileVersion_;

   protected:
    // Default value for optimization_level_ and device_id_ below (declared first so the
    // initializers read top-down). NOTE(review): presumably a sentinel meaning
    // "not set by the caller" — confirm against the implementation file.
    static int const SOME_MAGIC_INTEGER = 666;
    std::shared_ptr<Graph const> graph_;
    uint32_t optimization_level_ = SOME_MAGIC_INTEGER;
    uint32_t device_id_ = SOME_MAGIC_INTEGER;
    bool dump_enabled_ = false;
    /**
     * @return device_id_ converted to int.
     * NOTE(review): device_id_ is uint32_t while WithDeviceID()/GetDeviceID() use int,
     * so values above INT_MAX (or negative ids passed in) do not round-trip — confirm
     * the intended range.
     */
    int GetDeviceID() const { return static_cast<int>(device_id_); }

   public:
    Builder();
    ~Builder() override;
    virtual Builder *WithONNXModel(std::string const &onnx_file);
    virtual Builder *WithBIEModel(std::string const &bie_file);
    virtual Builder *WithUnencryptedGraph(Graph const &graph);
    virtual Builder *WithGraphOptimization(uint32_t level);
    virtual Builder *WithDeviceID(int device_id);
    virtual Builder *WithDumpEnable(bool dump_enabled);
    /**
     * @brief Hook for backend-specific settings.
     * The base implementation ignores the map and just returns `this`;
     * subclasses may override it to consume the entries.
     */
    virtual Builder *WithSettingInfo(
        const std::unordered_map<std::string, std::string> & /*inf_config*/) {
      return this;
    }
    virtual std::shared_ptr<Graph const> GetGraph() const;
    InferencerUniquePtr<T> Build() override = 0;
    /**
     * @return {{operation name : bchw}}
     */
    std::vector<std::pair<std::string, std::vector<int32_t>>> GetInputDimensions();
    /**
     * @return {{operation name : dimensions}} for the graph outputs
     * (presumably the same bchw layout as GetInputDimensions — confirm).
     */
    std::vector<std::pair<std::string, std::vector<int32_t>>> GetOutputDimensions();
    std::string GetBieFileVersion();
  };

 private:
  std::shared_ptr<Graph const> graph_;

 protected:
  PianoInferencer() = default;
  /**
   * @brief Creates an inferencer over an already-built graph (loading the model into a
   * graph is done beforehand, e.g. by the Builder).
   * @param piano_graph: the graph this inferencer runs on.
   */
  explicit PianoInferencer(std::shared_ptr<Graph const> piano_graph);
  /**
   * @brief Check
   * 1. whether preprocess_input provides enough entries to cover graph->input
   * 2. whether entries in preprocess_input have correct dimension
   */
  void CheckInputCorrectness(std::unordered_map<std::string, std::vector<T>> const &preprocess_input);
  /// Overload of the check above for Tensor-valued inputs.
  void CheckInputCorrectness(
      std::unordered_map<std::string, std::shared_ptr<dynasty::common::Tensor>> const &preprocess_input);
  /// Backend-specific graph execution on {node name: 1d vector} inputs; must be implemented by subclasses.
  virtual std::unordered_map<std::string, std::vector<T>> GraphBasedInference(
      std::unordered_map<std::string, std::vector<T>> const &preprocess_input, bool only_output_layers) = 0;
  /// Graph execution on {node name: Tensor} inputs; overridable (base implementation lives out of line).
  virtual std::unordered_map<std::string, std::vector<T>> GraphBasedInference(
      std::unordered_map<std::string, std::shared_ptr<dynasty::common::Tensor>> const &preprocess_input,
      bool only_output_layers);

 public:
  virtual ~PianoInferencer();

 public:
  std::shared_ptr<Graph const> GetImmutableGraph();
  /**
   * \param preprocess_input: [{operation_node_name, 1d_vector}]
   * \brief interface need to be implemented, pack output data path names and their value vectors then return
   * \return name_value_pair: {operation node name: corresponding std::vector<T>}
   */
  std::unordered_map<std::string, std::vector<T>> Inference(
      std::unordered_map<std::string, std::vector<T>> const &preprocess_input, bool only_output_layers) override;
  /**
   * \param preprocess_input: [{operation_node_name, Tensor}]
   * \brief interface need to be implemented, pack output data path names and their Tensors then return
   * \return name_value_pair: {operation node name: corresponding std::vector<T>}
   */
  std::unordered_map<std::string, std::vector<T>> Inference(
      std::unordered_map<std::string, std::shared_ptr<dynasty::common::Tensor>> const &preprocess_input,
      bool only_output_layers) override;
  /**
   * \param preprocess_input: [{operation_node_name, path_to_1d_vector}]
   * \brief interface to inference from operation_name and txt pairs
   * \return name_value_pair: {operation node name: corresponding std::vector<T>}
   */
  std::unordered_map<std::string, std::vector<T>> Inference(
      std::unordered_map<std::string, std::string> const &preprocess_input, bool only_output_layers) override;
  /**
   * @brief Converts per-node outputs of type T into int vectors.
   * float_output is taken by non-const reference because the AbstractInferencer
   * signature requires it. NOTE(review): check whether the implementation mutates it.
   */
  std::unordered_map<std::string, std::vector<int>> ConvertFloatToInt(
      std::unordered_map<std::string, std::vector<T>> &float_output, bool only_output_layers) override;
};
}  // namespace inferencer
}  // namespace dynasty
#endif  // PIANO_DYNASTY_INCLUDE_PIANOINFERENCER_H_