2 changes: 1 addition & 1 deletion cmake/external/paddle2onnx.cmake
@@ -24,7 +24,7 @@ endif()
include(ExternalProject)

set(PADDLE2ONNX_PROJECT "extern_paddle2onnx")
set(PADDLE2ONNX_VERSION "0.9.9")
set(PADDLE2ONNX_VERSION "1.0.0rc")
set(PADDLE2ONNX_PREFIX_DIR ${THIRD_PARTY_PATH}/paddle2onnx)
set(PADDLE2ONNX_SOURCE_DIR
${THIRD_PARTY_PATH}/paddle2onnx/src/${PADDLE2ONNX_PROJECT})
120 changes: 0 additions & 120 deletions paddle/fluid/inference/api/details/zero_copy_tensor.cc
@@ -179,13 +179,6 @@ PlaceType Tensor::place() const { return place_; }

template <typename T>
void Tensor::CopyFromCpu(const T *data) {
#ifdef PADDLE_WITH_ONNXRUNTIME
if (is_ort_tensor_) {
ORTCopyFromCpu<T>(data);
return;
}
#endif

EAGER_GET_TENSOR(paddle::framework::LoDTensor);
PADDLE_ENFORCE_GE(tensor->numel(),
0,
@@ -731,112 +724,6 @@ void Tensor::SetOrtBuffer(const std::shared_ptr<std::vector<int8_t>> buffer) {
buffer_ = buffer;
}

Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info,
float *data,
size_t size,
const int64_t *shape,
size_t shape_len) {
return Ort::Value::CreateTensor<float>(
memory_info, data, size, shape, shape_len);
}

Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info,
int64_t *data,
size_t size,
const int64_t *shape,
size_t shape_len) {
return Ort::Value::CreateTensor<int64_t>(
memory_info, data, size, shape, shape_len);
}

Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info,
int32_t *data,
size_t size,
const int64_t *shape,
size_t shape_len) {
return Ort::Value::CreateTensor<int32_t>(
memory_info, data, size, shape, shape_len);
}

Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info,
uint8_t *data,
size_t size,
const int64_t *shape,
size_t shape_len) {
return Ort::Value::CreateTensor<uint8_t>(
memory_info, data, size, shape, shape_len);
}

Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info,
int8_t *data,
size_t size,
const int64_t *shape,
size_t shape_len) {
return Ort::Value::CreateTensor<int8_t>(
memory_info, data, size, shape, shape_len);
}

Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info,
float16 *data,
size_t size,
const int64_t *shape,
size_t shape_len) {
return Ort::Value::CreateTensor(memory_info,
static_cast<void *>(data),
size * sizeof(float16),
shape,
shape_len,
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
}

template <typename T>
void Tensor::ORTCopyFromCpu(const T *data) {
auto binding = binding_.lock();
PADDLE_ENFORCE_NOT_NULL(binding,
paddle::platform::errors::PreconditionNotMet(
"input tensor [%s] no binding ptr", name_));
const char *device_name = place_ == PlaceType::kCPU ? "Cpu" : "Cuda";
Ort::MemoryInfo memory_info(
device_name, OrtDeviceAllocator, device_, OrtMemTypeDefault);
size_t size = std::accumulate(
begin(shape_), end(shape_), 1UL, std::multiplies<size_t>());
auto buffer = buffer_.lock();
size_t buffer_size = size * sizeof(T);
if (buffer_size > buffer->size()) {
buffer->resize(buffer_size);
}
std::memcpy(static_cast<void *>(buffer->data()), data, buffer_size);

auto onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
if (std::is_same<T, float>::value) {
onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
} else if (std::is_same<T, double>::value) {
onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
} else if (std::is_same<T, int64_t>::value) {
onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
} else if (std::is_same<T, int32_t>::value) {
onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
} else if (std::is_same<T, uint8_t>::value) {
onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
} else if (std::is_same<T, int8_t>::value) {
onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
} else if (std::is_same<T, float16>::value) {
onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"Found undefined data type for onnxruntime, only supports "
"float16/float32/float64/int8/uint8/int32/int64."));
}

auto ort_value = Ort::Value::CreateTensor(memory_info,
buffer->data(),
buffer_size,
shape_.data(),
shape_.size(),
onnx_dtype);
binding->BindInput(name_.c_str(), ort_value);
}

template <typename T>
void Tensor::ORTCopyToCpu(T *data) const {
auto binding = binding_.lock();
@@ -857,13 +744,6 @@ void Tensor::ORTCopyToCpu(T *data) const {
}
}

template void Tensor::ORTCopyFromCpu<float>(const float *data);
template void Tensor::ORTCopyFromCpu<int64_t>(const int64_t *data);
template void Tensor::ORTCopyFromCpu<int32_t>(const int32_t *data);
template void Tensor::ORTCopyFromCpu<uint8_t>(const uint8_t *data);
template void Tensor::ORTCopyFromCpu<int8_t>(const int8_t *data);
template void Tensor::ORTCopyFromCpu<float16>(const float16 *data);

template void Tensor::ORTCopyToCpu<float>(float *data) const;
template void Tensor::ORTCopyToCpu<int32_t>(int32_t *data) const;
template void Tensor::ORTCopyToCpu<uint8_t>(uint8_t *data) const;
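With the ORT-specific copy path removed above, Tensor::CopyFromCpu no longer branches on is_ort_tensor_; for the ONNX Runtime backend the data now lands in an ordinary LoDTensor held in the predictor's scope and is bound to ONNX Runtime at run time (see onnxruntime_predictor.cc below). A minimal caller-side usage sketch, assuming the public paddle_infer API; the input name "x", shape, and model paths are illustrative only:

// Usage sketch: the caller-facing API is unchanged by this PR. CopyFromCpu
// now fills the LoDTensor in the predictor's scope; binding to ONNX Runtime
// happens inside ZeroCopyRun(). Input name "x" and paths are hypothetical.
#include <vector>
#include "paddle_inference_api.h"

int main() {
  paddle_infer::Config config;
  config.SetModel("model.pdmodel", "model.pdiparams");  // illustrative paths
  config.EnableONNXRuntime();
  auto predictor = paddle_infer::CreatePredictor(config);

  auto input = predictor->GetInputHandle("x");
  std::vector<float> data(1 * 3 * 224 * 224, 0.f);
  input->Reshape({1, 3, 224, 224});
  input->CopyFromCpu(data.data());  // writes into the scope's LoDTensor

  predictor->Run();  // inputs are wrapped as Ort::Value and bound here
  return 0;
}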
68 changes: 45 additions & 23 deletions paddle/fluid/inference/api/onnxruntime_predictor.cc
@@ -24,11 +24,10 @@
#include <utility>
#include <vector>

#include "paddle/fluid//platform/device/gpu/gpu_types.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/version.h"
#include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/analysis/passes/memory_optimize_pass.h"
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/api/paddle_inference_pass.h"
@@ -97,6 +96,7 @@ bool ONNXRuntimePredictor::Init() {
} else {
place_ = paddle::platform::CPUPlace();
}
scope_.reset(new paddle::framework::Scope());

char *onnx_proto = nullptr;
int out_size;
@@ -147,6 +147,8 @@ bool ONNXRuntimePredictor::Init() {
Ort::Allocator allocator(session_, memory_info);

size_t n_inputs = session_.GetInputCount();
framework::proto::VarType::Type proto_type =
framework::proto::VarType::LOD_TENSOR;
for (size_t i = 0; i < n_inputs; ++i) {
auto input_name = session_.GetInputName(i, allocator);
auto type_info = session_.GetInputTypeInfo(i);
@@ -155,6 +157,10 @@
ONNXTensorElementDataType data_type =
type_info.GetTensorTypeAndShapeInfo().GetElementType();
input_desc_.emplace_back(ONNXDesc{input_name, shape, data_type});

auto *ptr = scope_->Var(input_name);
framework::InitializeVariable(ptr, proto_type);

allocator.Free(input_name);
}

@@ -249,13 +255,13 @@ bool ONNXRuntimePredictor::FindONNXDesc(const std::string &name,

std::unique_ptr<ZeroCopyTensor> ONNXRuntimePredictor::GetInputTensor(
const std::string &name) {
PADDLE_ENFORCE_EQ(FindONNXDesc(name, true),
true,
platform::errors::PreconditionNotMet(
"The in variable named %s is not found in the "
"ONNXPredictor.",
name));
std::unique_ptr<ZeroCopyTensor> res(new ZeroCopyTensor(nullptr, this));
PADDLE_ENFORCE_NOT_NULL(scope_->FindVar(name),
platform::errors::PreconditionNotMet(
"The in variable named %s is not found in the "
"ONNXPredictor.",
name));
std::unique_ptr<ZeroCopyTensor> res(
new ZeroCopyTensor(static_cast<void *>(scope_.get()), this));
res->input_or_output_ = true;
res->SetName(name);
if (platform::is_cpu_place(place_)) {
@@ -264,16 +270,6 @@ std::unique_ptr<ZeroCopyTensor> ONNXRuntimePredictor::GetInputTensor(
auto gpu_place = place_;
res->SetPlace(PaddlePlace::kGPU, gpu_place.GetDeviceId());
}
res->SetOrtMark(true);
res->SetOrtBinding(binding_);
auto iter = input_buffers_.find(name);
if (iter == input_buffers_.end()) {
std::vector<int8_t> i_vector;
input_buffers_[name] = std::make_shared<std::vector<int8_t>>(i_vector);
res->SetOrtBuffer(input_buffers_[name]);
} else {
res->SetOrtBuffer(iter->second);
}
return res;
}

@@ -306,6 +302,24 @@ std::unique_ptr<ZeroCopyTensor> ONNXRuntimePredictor::GetOutputTensor(
return res;
}

Ort::Value ONNXRuntimePredictor::GetOrtValue(const ONNXDesc &desc,
const char *device_name) {
Ort::MemoryInfo memory_info(
device_name, OrtDeviceAllocator, place_.GetDeviceId(), OrtMemTypeDefault);
auto *var = scope_->FindVar(desc.name);
auto *tensor = var->GetMutable<framework::LoDTensor>();
size_t size =
tensor->numel() *
framework::SizeOfType(framework::TransToProtoVarType(tensor->dtype()));
std::vector<int64_t> shape = phi::vectorize<int64_t>(tensor->dims());
return Ort::Value::CreateTensor(memory_info,
static_cast<void *>(tensor->data()),
size,
shape.data(),
shape.size(),
desc.dtype);
}

bool ONNXRuntimePredictor::Run(const std::vector<PaddleTensor> &inputs,
std::vector<PaddleTensor> *output_data,
int batch_size) {
@@ -315,7 +329,13 @@ bool ONNXRuntimePredictor::Run(const std::vector<PaddleTensor> &inputs,

bool ONNXRuntimePredictor::ZeroCopyRun() {
try {
const char *device_name = place_ == PlaceType::kCPU ? "Cpu" : "Cuda";
const char *device_name = platform::is_cpu_place(place_) ? "Cpu" : "Cuda";
std::vector<Ort::Value> inputs;
inputs.reserve(input_desc_.size());
for (auto desc : input_desc_) {
inputs.push_back(GetOrtValue(desc, device_name));
binding_->BindInput(desc.name.c_str(), inputs.back());
}
for (auto output : output_desc_) {
Ort::MemoryInfo out_memory_info(device_name,
OrtDeviceAllocator,
@@ -333,8 +353,10 @@
}

std::unique_ptr<PaddlePredictor> ONNXRuntimePredictor::Clone(void *stream) {
LOG(ERROR) << "Not support Clone(), Please create new Predictor";
return nullptr;
std::lock_guard<std::mutex> lk(clone_mutex_);
auto *x = new ONNXRuntimePredictor(config_);
x->Init();
return std::unique_ptr<PaddlePredictor>(x);
}

uint64_t ONNXRuntimePredictor::TryShrinkMemory() {
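The new GetOrtValue above is the core of the change: instead of keeping a per-tensor int8_t buffer and copying into it, the predictor wraps the LoDTensor's existing memory in an Ort::Value at bind time. A minimal sketch of that zero-copy wrapping, assuming a float host buffer; the helper name WrapBuffer and the fixed dtype are illustrative, not part of the PR:

// Sketch of the zero-copy wrapping performed by GetOrtValue: an existing
// buffer is exposed to ONNX Runtime without an extra memcpy. The untyped
// CreateTensor overload takes a byte size plus an explicit element type.
#include <cstdint>
#include <vector>
#include "onnxruntime_cxx_api.h"  // NOLINT

Ort::Value WrapBuffer(float *data, const std::vector<int64_t> &shape) {
  Ort::MemoryInfo memory_info(
      "Cpu", OrtDeviceAllocator, /*device_id=*/0, OrtMemTypeDefault);
  size_t numel = 1;
  for (auto d : shape) numel *= static_cast<size_t>(d);
  return Ort::Value::CreateTensor(memory_info,
                                  static_cast<void *>(data),
                                  numel * sizeof(float),
                                  shape.data(),
                                  shape.size(),
                                  ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
}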
20 changes: 15 additions & 5 deletions paddle/fluid/inference/api/onnxruntime_predictor.h
@@ -21,8 +21,6 @@

#include "onnxruntime_c_api.h" // NOLINT
#include "onnxruntime_cxx_api.h" // NOLINT
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/framework/op_compatible_info.h"
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/api/api_impl.h"
#include "paddle/fluid/inference/api/details/reset_tensor_array.h"
@@ -94,7 +92,7 @@ class ONNXRuntimePredictor : public PaddlePredictor {
/// \param[in] AnalysisConfig config
///
explicit ONNXRuntimePredictor(const AnalysisConfig &config)
: config_(config), env_(ORT_LOGGING_LEVEL_WARNING, "onnx") {
: env_(ORT_LOGGING_LEVEL_WARNING, "onnx"), config_(config) {
predictor_id_ = inference::GetUniqueId();
}
///
@@ -176,6 +174,8 @@ class ONNXRuntimePredictor : public PaddlePredictor {
///
std::unique_ptr<PaddlePredictor> Clone(void *stream = nullptr) override;

std::shared_ptr<framework::Scope> scope_;

protected:
const void *GetDeviceContexts() const override;

@@ -191,14 +191,24 @@
///
bool FindONNXDesc(const std::string &name, bool is_input);

private:
AnalysisConfig config_;
/// \brief Get the Ort::Value for an input tensor.
///
/// \param[in] desc ONNXDesc (name, shape, dtype)
///
/// \param[in] device_name device name, "Cpu" or "Cuda"
///
/// \return the Ort::Value wrapping the input tensor's data
///
Ort::Value GetOrtValue(const ONNXDesc &desc, const char *device_name);

private:
// ONNXRuntime
Ort::Env env_;
Ort::Session session_{nullptr};
std::shared_ptr<Ort::IoBinding> binding_;

AnalysisConfig config_;
std::mutex clone_mutex_;
platform::Place place_;
std::vector<ONNXDesc> input_desc_;
std::vector<ONNXDesc> output_desc_;
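Since Clone() now builds a real second predictor (guarded by clone_mutex_) instead of logging an error and returning nullptr, one instance per worker thread becomes possible. A small sketch of that pattern, assuming the paddle_infer::Predictor wrapper forwards Clone() to this backend; thread count and workload are illustrative:

// Sketch: per-thread predictors obtained through Clone(); construction of the
// clone is serialized internally by clone_mutex_.
#include <memory>
#include <thread>
#include <vector>
#include "paddle_inference_api.h"

void RunWorkers(std::shared_ptr<paddle_infer::Predictor> base, int n_threads) {
  std::vector<std::thread> workers;
  for (int i = 0; i < n_threads; ++i) {
    workers.emplace_back([&base]() {
      auto local = base->Clone();  // fresh ONNXRuntimePredictor per thread
      // ... fill this thread's inputs on `local`, then run:
      local->Run();
    });
  }
  for (auto &t : workers) t.join();
}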