83 lines
3.0 KiB
C++
83 lines
3.0 KiB
C++
#ifndef PIANO_KNEREXINFERENCER_H
|
|
#define PIANO_KNEREXINFERENCER_H
|
|
|
|
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "AbstractInferencer.h"
|
|
|
|
class Graph;
|
|
namespace dynasty {
|
|
namespace inferencer {
|
|
|
|
template <typename T>
|
|
class KnerexInferencer : public AbstractInferencer<T> {
|
|
public:
|
|
class Builder : public AbstractInferencer<T>::Builder {
|
|
protected:
|
|
std::shared_ptr<Graph> graph_;
|
|
|
|
public:
|
|
Builder() = default;
|
|
virtual ~Builder() = default;
|
|
virtual Builder *WithONNXModel(std::string const &onnx_file) = 0;
|
|
virtual Builder *WithBIEModel(std::string const &bie_file) = 0;
|
|
virtual Builder *WithGraphOptimization(uint32_t level) = 0;
|
|
virtual Builder *WithSettingInfo(std::unordered_map<std::string, std::string> inf_config) = 0;
|
|
virtual Builder *WithGraph(Graph& graph) = 0;
|
|
virtual Builder *WithGraph(std::shared_ptr<Graph> const graph) = 0;
|
|
|
|
virtual std::shared_ptr<Graph> GetGraph();
|
|
virtual void SetGraph(Graph const &graph);
|
|
|
|
virtual InferencerUniquePtr<T> Build() = 0;
|
|
};
|
|
|
|
protected:
|
|
std::shared_ptr<Graph> graph_;
|
|
|
|
explicit KnerexInferencer(std::shared_ptr<Graph> piano_graph);
|
|
/**
|
|
* @brief Check
|
|
* 1. whether input provides enough entries to cover graph->input
|
|
* 2. whether entries in input have correct dimension
|
|
*/
|
|
void CheckInputCorrectness(std::unordered_map<std::string, std::vector<T>> const &preprocess_input);
|
|
|
|
virtual std::unordered_map<std::string, std::vector<T>>
|
|
GraphBasedInference(std::unordered_map<std::string, std::vector<T>> const &input, bool only_output_layers) = 0;
|
|
|
|
public:
|
|
virtual ~KnerexInferencer() = default;
|
|
|
|
public:
|
|
/**
|
|
* \param input: [{operation_node_name, 1d_vector}]
|
|
* \brief interface need to be implemented, pack output data path names and their float vectors then return
|
|
* \return name_value_pair: {operation node name: corresponding float vector}
|
|
*/
|
|
std::unordered_map<std::string, std::vector<T>>
|
|
Inference(std::unordered_map<std::string, std::vector<T>> const &input, bool only_output_layers) override;
|
|
|
|
/** -------- Added Interface for Knerex -------- */
|
|
/** ---- Manipulate Graph and Internal Buffer ---- */
|
|
std::shared_ptr<Graph> GetGraph();
|
|
void SetGraph(Graph const& new_graph);
|
|
|
|
virtual void InferenceNodeBufferSetup() = 0;
|
|
virtual void InferenceNode(std::string node_name, std::unordered_map<std::string, std::vector<std::vector<T>>>& data) = 0;
|
|
virtual void InferenceNode(std::string node_name, std::unordered_map<std::string, std::vector<std::vector<int>>>& data) = 0;
|
|
virtual std::unordered_map<std::string, std::vector<T>> GetFloatTensor() = 0;
|
|
virtual std::unordered_map<std::string, std::vector<T>> GetFloatScaledTensor() = 0 ;
|
|
virtual std::unordered_map<std::string, std::vector<int>> GetFixTensor() = 0;
|
|
virtual std::unordered_map<std::string, std::vector<T>> GetFixScaledTensor() = 0;
|
|
|
|
};
|
|
|
|
}
|
|
}
|
|
#endif // PIANO_KNEREXINFERENCER_H
|