// Language: C++ (source listing: 47 lines, 1.3 KiB).
#ifndef PIANO_DYNASTY_FLOATING_POINT_INCLUDE_MSFT_GPU_INFERENCERPIMPL_H_
|
|
#define PIANO_DYNASTY_FLOATING_POINT_INCLUDE_MSFT_GPU_INFERENCERPIMPL_H_
|
|
|
|
#include <limits.h>

#include <cstdint>
#include <memory>
#include <string>
#include <utility>

#include "msft/MSFTInferencer.h"

namespace dynasty {
|
|
namespace inferencer {
|
|
namespace msftgpu {
|
|
|
|
template <typename T>
|
|
class Inferencer : public msft::Inferencer<T> {
|
|
public:
|
|
class Builder : public msft::Inferencer<T>::Builder {
|
|
friend class Inferencer;
|
|
private:
|
|
Builder() = default;
|
|
uint32_t device_id_ = INT_MAX;
|
|
void ConfigSessionOption(Ort::SessionOptions& option) override;
|
|
|
|
public:
|
|
Builder *WithGraphOptimization(uint32_t level);
|
|
Builder *WithParallelLevel(uint32_t parallel_level);
|
|
Builder *WithDeviceID(uint32_t device_id);
|
|
Builder *WithONNXModel(std::string const &model_file);
|
|
Builder *WithEncryptedModel(std::string const &model_file);
|
|
~Builder() override = default;
|
|
InferencerUniquePtr<T> Build() override;
|
|
};
|
|
|
|
protected:
|
|
explicit Inferencer(std::shared_ptr<Ort::Env> enviroment, std::shared_ptr<Ort::Session> session) :
|
|
msft::Inferencer<T>(enviroment, session) {}
|
|
|
|
public:
|
|
~Inferencer() override;
|
|
static Inferencer<T>::Builder* GetBuilder();
|
|
};
|
|
|
|
}
|
|
}
|
|
}
|
|
|
|
#endif //PIANO_DYNASTY_FLOATING_POINT_INCLUDE_MSFT_GPU_INFERENCERPIMPL_H_
|