2026-01-28 06:16:04 +00:00

135 lines
4.5 KiB
C++

//
// Created by Zhou Xiang on 10/4/21.
//
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "AbstractInferencer.h"
extern "C" {
#include "kplus/kp_struct.h"
}
namespace dynasty {
namespace inferencer {
namespace kplus {
using std::shared_ptr;
extern const std::string MODEL_OUTPUT_NODENAME_PREFIX;
template <typename T>
class InferencerImpl : public AbstractInferencer<T> {
 public:
  /**
   * @brief Step-wise builder: supply the NEF descriptor, model id, opened
   * device group and product id, then call Build() to obtain the inferencer.
   */
  class Builder : public AbstractInferencer<T>::Builder {
   protected:
    kp_model_nef_descriptor_t nef_;
    uint32_t model_id_;
    std::shared_ptr<kp_device_group_s> device_group_;
    kp_product_id_t product_;

   public:
    Builder() = default;
    ~Builder() override = default;
    Builder* WithNef(kp_model_nef_descriptor_t const& nef);
    Builder* WithModelId(uint32_t model_id);
    Builder* WithDeviceGroup(const shared_ptr<kp_device_group_s>& deviceGroup);
    Builder* WithProduct(const kp_product_id_t& product);
    InferencerUniquePtr<T> Build() override;
    // {node_name, dimensions} for each model input/output node.
    // Only valid after Build().
    std::vector<std::pair<std::string, std::vector<int32_t>>> GetInputDimensions();
    std::vector<std::pair<std::string, std::vector<int32_t>>> GetOutputDimensions();

   private:
    // Resolves the single-model descriptor (selected via model_id_) from nef_.
    kp_single_model_descriptor_t GetModel();
  };

 private:
  const kp_single_model_descriptor_t model_;
  const kp_product_id_t product_;
  const shared_ptr<kp_device_group_s> device_group_;
  uint32_t raw_buf_size_;
  uint8_t* raw_buf_;  // owning byte buffer (raw_buf_size_ bytes), freed in the destructor

 protected:
  /**
   * @brief Constructs an inferencer bound to one model on an opened Kneron
   * device group. (The previous comment described an ONNX "model_file" path
   * parameter, which does not match this signature.)
   * @param inf: descriptor of the single model (from the NEF) to run;
   * @param deviceGroup: opened kp device group that executes the inference;
   * @param product: product id (KL520/KL720, cf. transformFloat2RGBA) used to
   *                 select the on-device data layout.
   */
  InferencerImpl(const kp_single_model_descriptor_t& inf, const shared_ptr<kp_device_group_s>& deviceGroup,
                 const kp_product_id_t& product);
  // BUG FIX: raw_buf_ is a byte array, so it must be released with the array
  // form — plain `delete` on new[]-allocated memory is undefined behavior.
  // NOTE(review): assumes the .cpp allocates it with new uint8_t[...] — confirm.
  virtual ~InferencerImpl() { delete[] raw_buf_; }

 public:
  // Non-copyable: this object uniquely owns raw_buf_; the implicitly generated
  // copy constructor would alias the buffer and double-free it. (Copy-assign
  // was already implicitly deleted by the const members.)
  InferencerImpl(const InferencerImpl&) = delete;
  InferencerImpl& operator=(const InferencerImpl&) = delete;

  /**
   * @brief Runs one synchronous inference round-trip on the device.
   * \param preprocess_input: {operation_node_name -> 1-D value vector}
   * \param only_output_layers: presumably restricts the result to the model's
   *        output nodes (cf. MODEL_OUTPUT_NODENAME_PREFIX) — confirm in .cpp
   * \return name_value_pair: {operation node name -> corresponding float vector}
   */
  std::unordered_map<std::string, std::vector<T>> Inference(
      std::unordered_map<std::string, std::vector<T>> const& preprocess_input, bool only_output_layers) override;
  /**
   * @brief Same as above, but the inputs arrive as Tensors.
   * \param preprocess_input: {operation_node_name -> Tensor}
   * \return name_value_pair: {operation node name -> corresponding float vector}
   */
  std::unordered_map<std::string, std::vector<T>> Inference(
      std::unordered_map<std::string, std::shared_ptr<dynasty::common::Tensor>> const& preprocess_input,
      bool only_output_layers) override;
  // Asynchronous half-duplex API: InferenceSend() queues one request,
  // InferenceRecv() collects one result (they are paired via inf_send_index_ /
  // inf_recv_index_ — see below).
  void InferenceSend(
      std::unordered_map<std::string, std::vector<T>> const& preprocess_input, bool only_output_layers);
  void InferenceSend(
      std::unordered_map<std::string, std::shared_ptr<dynasty::common::Tensor>> const& preprocess_input,
      bool only_output_layers);
  std::unordered_map<std::string, std::vector<T>> InferenceRecv();
  // Inference overload taking string payloads per node.
  // NOTE(review): the semantics of the string values (file path vs. raw bytes)
  // are not visible in this header — confirm against the implementation.
  std::unordered_map<std::string, std::vector<T>> Inference(
      std::unordered_map<std::string, std::string> const& preprocess_input, bool only_output_layers);

 private:
  /**
   * transformFloat2RGBA
   * Transforms preprocess output (normalized float, CHW) into the device
   * layout for KL520/KL720: HWC, with C expanded from 3-channel RGB to
   * 4-channel RGBA. Returns a raw byte buffer.
   */
  uint8_t* transformFloat2RGBA(const T* normalized, const kp_product_id_t prod_id);
  /**
   * getFmOutputNode
   * Converts a dongle output node into a flat vector;
   * KL520: reorders from HWC back to CHW.
   */
  std::vector<T> getFmOutputNode(kp_inf_float_node_output_t* node, kp_product_id_t prod_id);
  // Sequence counters pairing InferenceSend calls with InferenceRecv results.
  // NOTE(review): static ⇒ shared by every InferencerImpl<T> instance, and not
  // synchronized — confirm single-threaded / single-instance usage.
  static uint32_t inf_send_index_;
  static uint32_t inf_recv_index_;
  // Helper for Inference: runs on an already-packed raw input buffer.
  std::unordered_map<std::string, std::vector<T>> Inference(const T* preprocess_input, bool only_output_layers);
  // Helper for InferenceSend: queues an already-packed raw input buffer.
  void InferenceSend(const T* preprocess_input, bool only_output_layers);
};
// Out-of-class definitions of the static send/receive sequence counters.
// Being template members, they may live in this header: each instantiation
// InferencerImpl<T> gets exactly one definition across translation units.
// NOTE(review): shared by all instances of a given T and not atomic — assumes
// single-threaded use of the Send/Recv API; confirm with callers.
template<typename T>
uint32_t InferencerImpl<T>::inf_send_index_ = 0;
template<typename T>
uint32_t InferencerImpl<T>::inf_recv_index_ = 0;
} // namespace kplus
} // namespace inferencer
} // namespace dynasty