4 changes: 0 additions & 4 deletions paddle/fluid/inference/api/details/zero_copy_tensor.cc
@@ -720,10 +720,6 @@ void Tensor::SetOrtBinding(const std::shared_ptr<Ort::IoBinding> binding) {
binding_ = binding;
}

void Tensor::SetOrtBuffer(const std::shared_ptr<std::vector<int8_t>> buffer) {
buffer_ = buffer;
}

template <typename T>
void Tensor::ORTCopyToCpu(T *data) const {
auto binding = binding_.lock();
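Note: with `buffer_` and `SetOrtBuffer` gone, a `Tensor` only keeps its `weak_ptr` to the predictor's `Ort::IoBinding` and reads output data back through it in `ORTCopyToCpu`. A minimal sketch of the weak_ptr guard pattern this relies on, using hypothetical placeholder types (`FakeBinding`, `FakeTensor`) rather than the real `Ort::` classes:

```cpp
#include <algorithm>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <vector>

// Stand-in for Ort::IoBinding: after Run() it owns the output buffers.
struct FakeBinding {
  std::vector<float> outputs{1.f, 2.f, 3.f};
};

// Mirrors how the tensor holds the binding: a weak_ptr that is locked before
// every access, so a destroyed predictor is detected instead of dereferencing
// a dangling pointer.
class FakeTensor {
 public:
  explicit FakeTensor(std::weak_ptr<FakeBinding> binding) : binding_(binding) {}

  void CopyToCpu(float *dst) const {
    auto binding = binding_.lock();  // same guard used in Tensor::ORTCopyToCpu
    if (!binding) {
      throw std::runtime_error("IoBinding is gone; keep the predictor alive");
    }
    std::copy(binding->outputs.begin(), binding->outputs.end(), dst);
  }

 private:
  std::weak_ptr<FakeBinding> binding_;
};

int main() {
  auto binding = std::make_shared<FakeBinding>();
  FakeTensor tensor{binding};
  float out[3];
  tensor.CopyToCpu(out);
  std::cout << out[0] << " " << out[1] << " " << out[2] << "\n";
  return 0;
}
```

Locking the weak_ptr turns a potential use-after-free into a recoverable error when the predictor that owns the binding has already been destroyed.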
102 changes: 54 additions & 48 deletions paddle/fluid/inference/api/onnxruntime_predictor.cc
@@ -86,9 +86,7 @@ bool CheckConvertToONNX(const AnalysisConfig &config) {
}
}

bool ONNXRuntimePredictor::Init() {
VLOG(3) << "ONNXRuntime Predictor::init()";

bool ONNXRuntimePredictor::InitBinding() {
// Now ONNXRuntime only support CPU
const char *device_name = config_.use_gpu() ? "Cuda" : "Cpu";
if (config_.use_gpu()) {
@@ -98,6 +96,53 @@ bool ONNXRuntimePredictor::Init() {
}
scope_.reset(new paddle::framework::Scope());

binding_ = std::make_shared<Ort::IoBinding>(*session_);
Ort::MemoryInfo memory_info(
device_name, OrtDeviceAllocator, place_.GetDeviceId(), OrtMemTypeDefault);
Ort::Allocator allocator(*session_, memory_info);

size_t n_inputs = session_->GetInputCount();
framework::proto::VarType::Type proto_type =
framework::proto::VarType::LOD_TENSOR;
for (size_t i = 0; i < n_inputs; ++i) {
auto input_name = session_->GetInputName(i, allocator);
auto type_info = session_->GetInputTypeInfo(i);
std::vector<int64_t> shape =
type_info.GetTensorTypeAndShapeInfo().GetShape();
ONNXTensorElementDataType data_type =
type_info.GetTensorTypeAndShapeInfo().GetElementType();
input_desc_.emplace_back(ONNXDesc{input_name, shape, data_type});

auto *ptr = scope_->Var(input_name);
framework::InitializeVariable(ptr, proto_type);

allocator.Free(input_name);
}

size_t n_outputs = session_->GetOutputCount();
for (size_t i = 0; i < n_outputs; ++i) {
auto output_name = session_->GetOutputName(i, allocator);
auto type_info = session_->GetOutputTypeInfo(i);
std::vector<int64_t> shape =
type_info.GetTensorTypeAndShapeInfo().GetShape();
ONNXTensorElementDataType data_type =
type_info.GetTensorTypeAndShapeInfo().GetElementType();
output_desc_.emplace_back(ONNXDesc{output_name, shape, data_type});

Ort::MemoryInfo out_memory_info(device_name,
OrtDeviceAllocator,
place_.GetDeviceId(),
OrtMemTypeDefault);
binding_->BindOutput(output_name, out_memory_info);

allocator.Free(output_name);
}
return true;
}

bool ONNXRuntimePredictor::Init() {
VLOG(3) << "ONNXRuntime Predictor::init()";

char *onnx_proto = nullptr;
int out_size;
if (config_.model_from_memory()) {
@@ -139,49 +184,10 @@ bool ONNXRuntimePredictor::Init() {
"will be "
"generated.";
}
session_ = {env_, onnx_proto, static_cast<size_t>(out_size), session_options};
binding_ = std::make_shared<Ort::IoBinding>(session_);

Ort::MemoryInfo memory_info(
device_name, OrtDeviceAllocator, place_.GetDeviceId(), OrtMemTypeDefault);
Ort::Allocator allocator(session_, memory_info);

size_t n_inputs = session_.GetInputCount();
framework::proto::VarType::Type proto_type =
framework::proto::VarType::LOD_TENSOR;
for (size_t i = 0; i < n_inputs; ++i) {
auto input_name = session_.GetInputName(i, allocator);
auto type_info = session_.GetInputTypeInfo(i);
std::vector<int64_t> shape =
type_info.GetTensorTypeAndShapeInfo().GetShape();
ONNXTensorElementDataType data_type =
type_info.GetTensorTypeAndShapeInfo().GetElementType();
input_desc_.emplace_back(ONNXDesc{input_name, shape, data_type});

auto *ptr = scope_->Var(input_name);
framework::InitializeVariable(ptr, proto_type);
session_ = std::make_shared<Ort::Session>(
*env_, onnx_proto, static_cast<size_t>(out_size), session_options);
InitBinding();

allocator.Free(input_name);
}

size_t n_outputs = session_.GetOutputCount();
for (size_t i = 0; i < n_outputs; ++i) {
auto output_name = session_.GetOutputName(i, allocator);
auto type_info = session_.GetOutputTypeInfo(i);
std::vector<int64_t> shape =
type_info.GetTensorTypeAndShapeInfo().GetShape();
ONNXTensorElementDataType data_type =
type_info.GetTensorTypeAndShapeInfo().GetElementType();
output_desc_.emplace_back(ONNXDesc{output_name, shape, data_type});

Ort::MemoryInfo out_memory_info(device_name,
OrtDeviceAllocator,
place_.GetDeviceId(),
OrtMemTypeDefault);
binding_->BindOutput(output_name, out_memory_info);

allocator.Free(output_name);
}
delete onnx_proto;
onnx_proto = nullptr;
return true;
@@ -343,7 +349,7 @@ bool ONNXRuntimePredictor::ZeroCopyRun() {
OrtMemTypeDefault);
binding_->BindOutput(output.name.c_str(), out_memory_info);
}
session_.Run({}, *(binding_.get()));
session_->Run({}, *(binding_.get()));
} catch (const std::exception &e) {
LOG(ERROR) << e.what();
return false;
@@ -354,8 +360,8 @@ bool ONNXRuntimePredictor::ZeroCopyRun() {

std::unique_ptr<PaddlePredictor> ONNXRuntimePredictor::Clone(void *stream) {
std::lock_guard<std::mutex> lk(clone_mutex_);
auto *x = new ONNXRuntimePredictor(config_);
x->Init();
auto *x = new ONNXRuntimePredictor(config_, env_, session_);
x->InitBinding();
return std::unique_ptr<PaddlePredictor>(x);
}

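Note: the net effect of this file's changes is that session creation and I/O binding setup are now separate steps: `Init()` builds the `Ort::Session` once, and `InitBinding()` only wires up the input/output descriptions, so `Clone()` can reuse the existing env and session instead of re-parsing the model. A minimal sketch of that sharing pattern under the same ownership assumptions, with placeholder types instead of the real `Ort::` API:

```cpp
#include <iostream>
#include <memory>
#include <string>

// Placeholders standing in for Ort::Env, Ort::Session and Ort::IoBinding.
struct Env {};
struct Session {
  explicit Session(const std::string &model) {
    std::cout << "loading " << model << std::endl;  // expensive step, done once
  }
};
struct IoBinding {
  explicit IoBinding(Session &) {}
};

class Predictor {
 public:
  // Primary path: create env + session, then set up the binding.
  explicit Predictor(const std::string &model)
      : env_(std::make_shared<Env>()),
        session_(std::make_shared<Session>(model)) {
    InitBinding();
  }

  // Clone path: adopt the already-loaded env/session, rebuild only the binding.
  Predictor(std::shared_ptr<Env> env, std::shared_ptr<Session> session)
      : env_(std::move(env)), session_(std::move(session)) {
    InitBinding();
  }

  std::unique_ptr<Predictor> Clone() const {
    return std::unique_ptr<Predictor>(new Predictor(env_, session_));
  }

 private:
  void InitBinding() { binding_ = std::make_shared<IoBinding>(*session_); }

  std::shared_ptr<Env> env_;
  std::shared_ptr<Session> session_;
  std::shared_ptr<IoBinding> binding_;
};

int main() {
  Predictor original("model.onnx");  // prints "loading model.onnx" exactly once
  auto clone = original.Clone();     // no second load: the session is shared
  return 0;
}
```

Each clone pays only for its own `IoBinding`, which is why `Clone()` in the diff calls `InitBinding()` rather than `Init()`.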
29 changes: 25 additions & 4 deletions paddle/fluid/inference/api/onnxruntime_predictor.h
@@ -92,14 +92,36 @@ class ONNXRuntimePredictor : public PaddlePredictor {
/// \param[in] AnalysisConfig config
///
explicit ONNXRuntimePredictor(const AnalysisConfig &config)
: env_(ORT_LOGGING_LEVEL_WARNING, "onnx"), config_(config) {
: env_(std::make_shared<Ort::Env>(ORT_LOGGING_LEVEL_WARNING,
"paddle-ort")),
session_(nullptr),
binding_(nullptr),
config_(config) {
predictor_id_ = inference::GetUniqueId();
}
///
/// \brief Construct an ONNXRuntime Predictor that shares an existing
/// Ort::Env and Ort::Session (used by Clone)
///
/// \param[in] AnalysisConfig config
/// \param[in] env the shared ONNXRuntime environment
/// \param[in] session the shared ONNXRuntime session
///
explicit ONNXRuntimePredictor(const AnalysisConfig &config,
std::shared_ptr<Ort::Env> env,
std::shared_ptr<Ort::Session> session)
: env_(env), session_(session), binding_(nullptr), config_(config) {
predictor_id_ = inference::GetUniqueId();
}
///
/// \brief Destroy the ONNXRuntime Predictor object
///
~ONNXRuntimePredictor();

///
/// \brief Initialize ORT Binding
///
/// \return Whether the init function executed successfully
///
bool InitBinding();

///
/// \brief Initialize predictor
///
@@ -203,16 +225,15 @@ class ONNXRuntimePredictor : public PaddlePredictor {

private:
// ONNXRuntime
Ort::Env env_;
Ort::Session session_{nullptr};
std::shared_ptr<Ort::Env> env_;
std::shared_ptr<Ort::Session> session_{nullptr};
std::shared_ptr<Ort::IoBinding> binding_;

AnalysisConfig config_;
std::mutex clone_mutex_;
platform::Place place_;
std::vector<ONNXDesc> input_desc_;
std::vector<ONNXDesc> output_desc_;
std::map<std::string, std::shared_ptr<std::vector<int8_t>>> input_buffers_;
int predictor_id_;

// Some more detailed tests, they are made the friends of the predictor, so that
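Note: because `env_` and `session_` are now `shared_ptr`s, a clone stays valid even if the predictor it was cloned from is destroyed first; the session is released only when the last owner goes away. A small standalone illustration of that ownership rule (placeholder `Session` type, not the real `Ort::Session`):

```cpp
#include <cassert>
#include <memory>

struct Session {};  // stand-in for Ort::Session

int main() {
  // The original predictor creates the session...
  auto session = std::make_shared<Session>();
  assert(session.use_count() == 1);

  // ...and the cloning constructor stores another shared_ptr to it.
  std::shared_ptr<Session> clone_view = session;
  assert(session.use_count() == 2);

  // Destroying the original predictor's pointer does not free the session;
  // the clone keeps it alive until the last owner goes away.
  session.reset();
  assert(clone_view.use_count() == 1);
  return 0;
}
```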
3 changes: 0 additions & 3 deletions paddle/fluid/inference/api/paddle_tensor.h
@@ -191,16 +191,13 @@ class PD_INFER_DECL Tensor {
#ifdef PADDLE_WITH_ONNXRUNTIME
bool is_ort_tensor_{false};
std::vector<int64_t> shape_;
std::weak_ptr<std::vector<int8_t>> buffer_;
std::weak_ptr<Ort::IoBinding> binding_;
int idx_{-1};

void SetOrtMark(bool is_ort_tensor);

void SetOrtBinding(const std::shared_ptr<Ort::IoBinding> binding);

void SetOrtBuffer(const std::shared_ptr<std::vector<int8_t>> buffer);

template <typename T>
void ORTCopyFromCpu(const T* data);
