#ifndef PIANO_KNEREXINFERENCER_H
#define PIANO_KNEREXINFERENCER_H

// NOTE(review): the five #include directives below carry no header name in
// this copy of the file, and several template argument lists further down
// (after `template`, `std::shared_ptr`, `std::unordered_map`, ...) look
// truncated as well. Presumably the angle-bracket text was lost in an export
// step -- restore it from the original source before compiling.
#include
#include
#include
#include
#include
#include "AbstractInferencer.h"

// Forward declaration: this header only names Graph by reference in its
// signatures, so the full definition is not required here.
class Graph;

namespace dynasty {
namespace inferencer {

/**
 * @brief Abstract inferencer that keeps a graph handle (graph_) and adds the
 *        Knerex-specific interface on top of AbstractInferencer.
 *
 * The class is abstract: GraphBasedInference(), InferenceNodeBufferSetup()
 * and InferenceNode() are pure virtual and must be supplied by a subclass.
 * The constructor is protected, so instances can only be created through
 * derived classes.
 */
template class KnerexInferencer : public AbstractInferencer {
public:
    /**
     * @brief Abstract builder for KnerexInferencer implementations.
     *
     * Every With*() method returns a Builder pointer (presumably `this`, to
     * allow call chaining -- confirm against the concrete builders).
     */
    class Builder : public AbstractInferencer::Builder {
    protected:
        // Graph handle held by the builder.
        std::shared_ptr graph_;

    public:
        Builder() = default;
        virtual ~Builder() = default;

        /** @param onnx_file presumably the path of an ONNX model file -- TODO confirm. */
        virtual Builder *WithONNXModel(std::string const &onnx_file) = 0;

        /** @param bie_file presumably the path of a BIE model file -- TODO confirm. */
        virtual Builder *WithBIEModel(std::string const &bie_file) = 0;

        /** @param level graph optimization level; the meaning of each value is not defined in this header. */
        virtual Builder *WithGraphOptimization(uint32_t level) = 0;

        /** @param inf_config inference settings map, taken by value. */
        virtual Builder *WithSettingInfo(std::unordered_map inf_config) = 0;

        /** @param graph graph supplied by reference. */
        virtual Builder *WithGraph(Graph& graph) = 0;

        /** @param graph graph supplied through a shared pointer. */
        virtual Builder *WithGraph(std::shared_ptr const graph) = 0;

        // Accessors for the builder's graph. They are virtual but not pure;
        // their definitions are not visible in this header.
        virtual std::shared_ptr GetGraph();
        virtual void SetGraph(Graph const &graph);

        /** @return the constructed inferencer, wrapped in InferencerUniquePtr. */
        virtual InferencerUniquePtr Build() = 0;
    };

protected:
    // Graph handle used by this inferencer.
    std::shared_ptr graph_;

    /**
     * @brief Construct from a graph handle; only callable by subclasses.
     * @param piano_graph graph handle (presumably stored in graph_ -- the
     *        definition is not visible here).
     */
    explicit KnerexInferencer(std::shared_ptr piano_graph);

    /**
     * @brief Validate an input map against the graph. Checks
     *        1. whether the input provides enough entries to cover graph->input;
     *        2. whether the entries in the input have the correct dimension.
     * @param preprocess_input input map to validate.
     * @note How a failed check is reported is not visible in this header.
     */
    void CheckInputCorrectness(std::unordered_map> const &preprocess_input);

    /**
     * @brief Inference hook that every concrete subclass must implement.
     * @param input same input-map type as accepted by Inference().
     * @param only_output_layers presumably limits the returned entries to the
     *        graph's output layers -- TODO confirm.
     * @return result map of the same type as returned by Inference().
     */
    virtual std::unordered_map> GraphBasedInference(std::unordered_map> const &input, bool only_output_layers) = 0;

public:
    virtual ~KnerexInferencer() = default;

public:
    /**
     * @brief Inference entry point overriding the base-class interface: packs
     *        the output data path names and their float vectors, then
     *        returns them.
     * @param input [{operation_node_name, 1d_vector}]
     * @param only_output_layers presumably limits the returned entries to the
     *        graph's output layers -- TODO confirm.
     * @return name_value_pair: {operation node name: corresponding float vector}
     */
    std::unordered_map> Inference(std::unordered_map> const &input, bool only_output_layers) override;

    /** -------- Added interface for Knerex -------- */
    /** ---- Manipulate the graph and the internal buffer ---- */

    // Non-virtual accessors for the inferencer's graph; definitions are not
    // visible in this header.
    std::shared_ptr GetGraph();
    void SetGraph(Graph const& new_graph);

    /**
     * @brief Buffer setup hook to be implemented by subclasses (presumably
     *        prepares the internal buffers used by InferenceNode() -- TODO
     *        confirm).
     */
    virtual void InferenceNodeBufferSetup() = 0;

    /**
     * @brief Per-node inference hook to be implemented by subclasses.
     * @param node_name name of the node, passed by value.
     * @param data map passed by non-const reference, so implementations may
     *        modify it.
     */
    virtual void InferenceNode(std::string node_name, std::unordered_map>>& data) = 0;
};

}  // namespace inferencer
}  // namespace dynasty

#endif // PIANO_KNEREXINFERENCER_H