2026-01-28 06:16:04 +00:00

83 lines
3.0 KiB
C++

#ifndef PIANO_KNEREXINFERENCER_H
#define PIANO_KNEREXINFERENCER_H
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "AbstractInferencer.h"
class Graph;
namespace dynasty {
namespace inferencer {
/**
 * @brief Inferencer interface used by Knerex, built around a Piano `Graph`.
 *
 * Extends AbstractInferencer<T> with graph access, per-node inference and
 * accessors for internal float / fixed-point tensor buffers. Concrete
 * backends derive from this class and implement its pure virtual members.
 *
 * @tparam T element type of the tensor vectors exchanged with callers
 *           (documented below as "float" vectors).
 */
template <typename T>
class KnerexInferencer : public AbstractInferencer<T> {
 public:
  /**
   * @brief Abstract builder for KnerexInferencer backends.
   *
   * Each `With*` configuration call returns a `Builder *` (presumably `this`,
   * to allow call chaining — confirm in implementations); Build() creates the
   * configured inferencer.
   */
  class Builder : public AbstractInferencer<T>::Builder {
   protected:
    /// Graph held by the builder; accessed via GetGraph() / SetGraph().
    std::shared_ptr<Graph> graph_;

   public:
    Builder() = default;
    virtual ~Builder() = default;

    /// Configure the builder with the ONNX model file at @p onnx_file.
    virtual Builder *WithONNXModel(std::string const &onnx_file) = 0;
    /// Configure the builder with the BIE model file at @p bie_file.
    virtual Builder *WithBIEModel(std::string const &bie_file) = 0;
    /// Select the graph optimization @p level (level semantics are backend-defined).
    virtual Builder *WithGraphOptimization(std::uint32_t level) = 0;
    /// Pass inference settings as string key/value pairs (taken by value).
    virtual Builder *WithSettingInfo(std::unordered_map<std::string, std::string> inf_config) = 0;
    /// Configure the builder with an existing graph, passed by non-const reference.
    virtual Builder *WithGraph(Graph &graph) = 0;
    /// Configure the builder with an existing graph, shared via shared_ptr.
    /// (A top-level `const` on this by-value parameter would have no effect in a declaration.)
    virtual Builder *WithGraph(std::shared_ptr<Graph> graph) = 0;
    /// Return the builder's graph (presumably `graph_` — defined out of line).
    virtual std::shared_ptr<Graph> GetGraph();
    /// Set the builder's graph from @p graph (defined out of line).
    virtual void SetGraph(Graph const &graph);
    /// Create the inferencer described by the previous `With*` calls.
    virtual InferencerUniquePtr<T> Build() = 0;
  };

 protected:
  /// Graph this inferencer operates on; accessed via GetGraph() / SetGraph().
  std::shared_ptr<Graph> graph_;

  /// Construct around @p piano_graph. Protected: only derived backends construct this type.
  explicit KnerexInferencer(std::shared_ptr<Graph> piano_graph);

  /**
   * @brief Check
   * 1. whether @p preprocess_input provides enough entries to cover graph->input
   * 2. whether entries in @p preprocess_input have correct dimension
   */
  void CheckInputCorrectness(std::unordered_map<std::string, std::vector<T>> const &preprocess_input);

  /**
   * @brief Backend hook performing inference over the graph.
   * @param input map of {operation node name: 1-D input vector}.
   * @param only_output_layers NOTE(review): presumably restricts the result to
   *        the graph's output layers — confirm in implementations.
   * @return map of {operation node name: result vector}.
   */
  virtual std::unordered_map<std::string, std::vector<T>>
  GraphBasedInference(std::unordered_map<std::string, std::vector<T>> const &input, bool only_output_layers) = 0;

 public:
  virtual ~KnerexInferencer() = default;

  /**
   * @brief Run inference and return named output vectors.
   *
   * Implements AbstractInferencer<T>::Inference: packs output data path names
   * together with their float vectors and returns them.
   *
   * @param input map of {operation node name: 1-D input vector}.
   * @param only_output_layers NOTE(review): presumably limits the result to
   *        the graph's output layers — confirm against the implementation.
   * @return map of {operation node name: corresponding float vector}.
   */
  std::unordered_map<std::string, std::vector<T>>
  Inference(std::unordered_map<std::string, std::vector<T>> const &input, bool only_output_layers) override;

  // -------- Added interface for Knerex --------
  // ---- Manipulate graph and internal buffers ----

  /// Return the graph this inferencer operates on (defined out of line).
  std::shared_ptr<Graph> GetGraph();
  /// Set the inferencer's graph from @p new_graph (defined out of line).
  void SetGraph(Graph const &new_graph);

  /// Backend hook; presumably prepares the buffers used by InferenceNode() — confirm.
  virtual void InferenceNodeBufferSetup() = 0;
  /// Run a single node @p node_name; @p data holds named lists of T (float) vectors.
  virtual void InferenceNode(std::string node_name, std::unordered_map<std::string, std::vector<std::vector<T>>> &data) = 0;
  /// Run a single node @p node_name; @p data holds named lists of int (fixed-point) vectors.
  virtual void InferenceNode(std::string node_name, std::unordered_map<std::string, std::vector<std::vector<int>>> &data) = 0;

  /// Return the internal float tensors, keyed by name.
  virtual std::unordered_map<std::string, std::vector<T>> GetFloatTensor() = 0;
  /// Return the internal "float scaled" tensors (scaling semantics are backend-defined).
  virtual std::unordered_map<std::string, std::vector<T>> GetFloatScaledTensor() = 0;
  /// Return the internal fixed-point tensors as int vectors, keyed by name.
  virtual std::unordered_map<std::string, std::vector<int>> GetFixTensor() = 0;
  /// Return the internal "fix scaled" tensors as T vectors (scaling semantics are backend-defined).
  virtual std::unordered_map<std::string, std::vector<T>> GetFixScaledTensor() = 0;
};
}
}
#endif // PIANO_KNEREXINFERENCER_H