2026-01-28 06:16:04 +00:00

47 lines
1.3 KiB
C++

#ifndef PIANO_DYNASTY_FLOATING_POINT_INCLUDE_MSFT_GPU_INFERENCERPIMPL_H_
#define PIANO_DYNASTY_FLOATING_POINT_INCLUDE_MSFT_GPU_INFERENCERPIMPL_H_
#include <limits.h>

#include <cstdint>
#include <memory>
#include <string>
#include <utility>

#include "msft/MSFTInferencer.h"
namespace dynasty {
namespace inferencer {
namespace msftgpu {

/**
 * @brief ONNX Runtime inferencer variant declared in the `msftgpu` namespace.
 *
 * Extends msft::Inferencer<T>. The constructor is protected, so instances are
 * meant to be produced through the nested Builder (obtained via GetBuilder()).
 * NOTE(review): the GPU-specific setup presumably lives in the out-of-line
 * Builder::ConfigSessionOption override -- confirm in the implementation file.
 *
 * @tparam T element type forwarded to msft::Inferencer<T>.
 */
template <typename T>
class Inferencer : public msft::Inferencer<T> {
 public:
  /// Fluent builder that collects configuration and produces an Inferencer.
  class Builder : public msft::Inferencer<T>::Builder {
    // Grants the enclosing Inferencer access to Builder's private members
    // (notably the private constructor used when handing out builders).
    friend class Inferencer;

   private:
    Builder() = default;

    // Target device index. Defaults to INT_MAX (same value as before, now cast
    // explicitly because the field is unsigned); presumably a "not set"
    // sentinel -- TODO confirm how the out-of-line code interprets it.
    uint32_t device_id_ = static_cast<uint32_t>(INT_MAX);

    // Overrides the base builder's hook for configuring Ort session options.
    void ConfigSessionOption(Ort::SessionOptions &option) override;

   public:
    // Configuration setters. Each returns a Builder*; presumably `this`, to
    // allow call chaining -- confirm in the implementation file.
    Builder *WithGraphOptimization(uint32_t level);
    Builder *WithParallelLevel(uint32_t parallel_level);
    Builder *WithDeviceID(uint32_t device_id);
    Builder *WithONNXModel(std::string const &model_file);
    Builder *WithEncryptedModel(std::string const &model_file);

    ~Builder() override = default;

    /// Produces the configured inferencer.
    InferencerUniquePtr<T> Build() override;
  };

 protected:
  /// Wraps an already-created ORT environment and session. Protected: only
  /// derived classes and the nested Builder can construct an Inferencer.
  ///
  /// The shared_ptr parameters are taken by value and moved into the base, so
  /// each handle costs no extra atomic reference-count increment/decrement.
  explicit Inferencer(std::shared_ptr<Ort::Env> environment,
                      std::shared_ptr<Ort::Session> session)
      : msft::Inferencer<T>(std::move(environment), std::move(session)) {}

 public:
  ~Inferencer() override;

  /// Returns a new Builder. NOTE(review): returned as a raw pointer, so the
  /// type does not express ownership -- confirm whether the caller must delete it.
  static Builder *GetBuilder();
};

}  // namespace msftgpu
}  // namespace inferencer
}  // namespace dynasty
#endif //PIANO_DYNASTY_FLOATING_POINT_INCLUDE_MSFT_GPU_INFERENCERPIMPL_H_