 #include <utility>
 #include <vector>
 
-#include "paddle/fluid/platform/device/gpu/gpu_types.h"
 #include "paddle/fluid/framework/scope.h"
-#include "paddle/fluid/framework/version.h"
+#include "paddle/fluid/framework/var_type_traits.h"
+#include "paddle/fluid/framework/variable_helper.h"
 #include "paddle/fluid/inference/analysis/helper.h"
-#include "paddle/fluid/inference/analysis/passes/memory_optimize_pass.h"
 #include "paddle/fluid/inference/api/helper.h"
 #include "paddle/fluid/inference/api/paddle_inference_api.h"
 #include "paddle/fluid/inference/api/paddle_inference_pass.h"
@@ -97,6 +96,7 @@ bool ONNXRuntimePredictor::Init() {
   } else {
     place_ = paddle::platform::CPUPlace();
   }
+  scope_.reset(new paddle::framework::Scope());
 
   char *onnx_proto = nullptr;
   int out_size;
@@ -147,6 +147,8 @@ bool ONNXRuntimePredictor::Init() {
   Ort::Allocator allocator(session_, memory_info);
 
   size_t n_inputs = session_.GetInputCount();
+  framework::proto::VarType::Type proto_type =
+      framework::proto::VarType::LOD_TENSOR;
   for (size_t i = 0; i < n_inputs; ++i) {
     auto input_name = session_.GetInputName(i, allocator);
     auto type_info = session_.GetInputTypeInfo(i);
@@ -155,6 +157,10 @@ bool ONNXRuntimePredictor::Init() {
     ONNXTensorElementDataType data_type =
         type_info.GetTensorTypeAndShapeInfo().GetElementType();
     input_desc_.emplace_back(ONNXDesc{input_name, shape, data_type});
+
+    auto *ptr = scope_->Var(input_name);
+    framework::InitializeVariable(ptr, proto_type);
+
     allocator.Free(input_name);
   }
 
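With this hunk, Init() registers one Paddle variable per ONNX input. For readers unfamiliar with the scope machinery, here is a minimal standalone sketch of what the loop amounts to, assuming only the framework headers added at the top of this diff; the function name `sketch` and the input name `"x"` are illustrative, not part of the patch:

```cpp
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable_helper.h"

void sketch() {
  paddle::framework::Scope scope;
  // Create (or look up) a variable named after the ONNX input ...
  auto *var = scope.Var("x");
  // ... and make it hold an empty LoDTensor, mirroring proto_type above.
  paddle::framework::InitializeVariable(
      var, paddle::framework::proto::VarType::LOD_TENSOR);
  // The tensor has no allocation yet; GetInputTensor() fills it later.
  auto *tensor = var->GetMutable<paddle::framework::LoDTensor>();
  (void)tensor;
}
```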
@@ -249,13 +255,13 @@ bool ONNXRuntimePredictor::FindONNXDesc(const std::string &name,
 
 std::unique_ptr<ZeroCopyTensor> ONNXRuntimePredictor::GetInputTensor(
     const std::string &name) {
-  PADDLE_ENFORCE_EQ(FindONNXDesc(name, true),
-                    true,
-                    platform::errors::PreconditionNotMet(
-                        "The in variable named %s is not found in the "
-                        "ONNXPredictor.",
-                        name));
-  std::unique_ptr<ZeroCopyTensor> res(new ZeroCopyTensor(nullptr, this));
+  PADDLE_ENFORCE_NOT_NULL(scope_->FindVar(name),
+                          platform::errors::PreconditionNotMet(
+                              "The in variable named %s is not found in the "
+                              "ONNXPredictor.",
+                              name));
+  std::unique_ptr<ZeroCopyTensor> res(
+      new ZeroCopyTensor(static_cast<void *>(scope_.get()), this));
   res->input_or_output_ = true;
   res->SetName(name);
   if (platform::is_cpu_place(place_)) {
@@ -264,16 +270,6 @@ std::unique_ptr<ZeroCopyTensor> ONNXRuntimePredictor::GetInputTensor(
     auto gpu_place = place_;
     res->SetPlace(PaddlePlace::kGPU, gpu_place.GetDeviceId());
   }
-  res->SetOrtMark(true);
-  res->SetOrtBinding(binding_);
-  auto iter = input_buffers_.find(name);
-  if (iter == input_buffers_.end()) {
-    std::vector<int8_t> i_vector;
-    input_buffers_[name] = std::make_shared<std::vector<int8_t>>(i_vector);
-    res->SetOrtBuffer(input_buffers_[name]);
-  } else {
-    res->SetOrtBuffer(iter->second);
-  }
   return res;
 }
 
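After this change, GetInputTensor() hands back a ZeroCopyTensor backed by the scope rather than by the old per-name ORT byte buffers. A hypothetical caller-side flow (the predictor handle, input name, and shape are placeholders, not part of this patch):

```cpp
auto input = predictor->GetInputTensor("x");
std::vector<float> data(1 * 3 * 224 * 224, 0.f);
input->Reshape({1, 3, 224, 224});
// Writes directly into the LoDTensor owned by scope_ ...
input->copy_from_cpu(data.data());
// ... which ZeroCopyRun() will wrap as an Ort::Value without copying.
predictor->ZeroCopyRun();
```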
@@ -306,6 +302,24 @@ std::unique_ptr<ZeroCopyTensor> ONNXRuntimePredictor::GetOutputTensor(
   return res;
 }
 
+Ort::Value ONNXRuntimePredictor::GetOrtValue(const ONNXDesc &desc,
+                                             const char *device_name) {
+  Ort::MemoryInfo memory_info(
+      device_name, OrtDeviceAllocator, place_.GetDeviceId(), OrtMemTypeDefault);
+  auto *var = scope_->FindVar(desc.name);
+  auto *tensor = var->GetMutable<framework::LoDTensor>();
+  size_t size =
+      tensor->numel() *
+      framework::SizeOfType(framework::TransToProtoVarType(tensor->dtype()));
+  std::vector<int64_t> shape = phi::vectorize<int64_t>(tensor->dims());
+  return Ort::Value::CreateTensor(memory_info,
+                                  static_cast<void *>(tensor->data()),
+                                  size,
+                                  shape.data(),
+                                  shape.size(),
+                                  desc.dtype);
+}
+
 bool ONNXRuntimePredictor::Run(const std::vector<PaddleTensor> &inputs,
                                std::vector<PaddleTensor> *output_data,
                                int batch_size) {
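The new GetOrtValue() is zero-copy: Ort::Value::CreateTensor borrows the LoDTensor's existing allocation instead of duplicating it. A self-contained ONNX Runtime sketch of the same wrapping, with a plain CPU buffer standing in for the scope-owned tensor:

```cpp
#include <vector>
#include <onnxruntime_cxx_api.h>

int main() {
  float buf[6] = {0, 1, 2, 3, 4, 5};
  std::vector<int64_t> dims{2, 3};
  auto mem = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  // The Ort::Value only borrows `buf`; the caller keeps ownership, so the
  // buffer must outlive every use of the value.
  Ort::Value value = Ort::Value::CreateTensor(
      mem, static_cast<void *>(buf), sizeof(buf), dims.data(), dims.size(),
      ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
  return 0;
}
```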
@@ -315,7 +329,13 @@ bool ONNXRuntimePredictor::Run(const std::vector<PaddleTensor> &inputs,
315329
316330bool  ONNXRuntimePredictor::ZeroCopyRun () {
317331  try  {
318-     const  char  *device_name = place_ == PlaceType::kCPU  ? " Cpu" " Cuda" 
332+     const  char  *device_name = platform::is_cpu_place (place_) ? " Cpu" " Cuda" 
333+     std::vector<Ort::Value> inputs;
334+     inputs.reserve (input_desc_.size ());
335+     for  (auto  desc : input_desc_) {
336+       inputs.push_back (GetOrtValue (desc, device_name));
337+       binding_->BindInput (desc.name .c_str (), inputs.back ());
338+     }
319339    for  (auto  output : output_desc_) {
320340      Ort::MemoryInfo out_memory_info (device_name,
321341                                      OrtDeviceAllocator,
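ZeroCopyRun() now re-binds every input on each call, presumably because resizing an input can reallocate the scope-owned buffer that a previously bound Ort::Value pointed at. For reference, a sketch of the underlying IoBinding pattern (session, names, and values here are illustrative):

```cpp
Ort::IoBinding binding(session);
// Bind an Ort::Value that wraps caller-owned input memory.
binding.BindInput("x", input_value);
// For outputs, bind a MemoryInfo and let ONNX Runtime allocate on that device.
Ort::MemoryInfo out_info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
binding.BindOutput("y", out_info);
session.Run(Ort::RunOptions{nullptr}, binding);
```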
@@ -333,8 +353,10 @@ bool ONNXRuntimePredictor::ZeroCopyRun() {
 }
 
 std::unique_ptr<PaddlePredictor> ONNXRuntimePredictor::Clone(void *stream) {
-  LOG(ERROR) << "Not support Clone(), Please create new Predictor";
-  return nullptr;
+  std::lock_guard<std::mutex> lk(clone_mutex_);
+  auto *x = new ONNXRuntimePredictor(config_);
+  x->Init();
+  return std::unique_ptr<PaddlePredictor>(x);
 }
 
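Clone() now returns a fully initialized, independent predictor instead of logging an error, with clone_mutex_ serializing concurrent clones; note the `stream` argument is ignored by this implementation. Hypothetical usage:

```cpp
auto worker = predictor->Clone(nullptr);
// `worker` owns its own session, binding, and scope, so it can safely
// serve requests from a different thread than `predictor`.
worker->ZeroCopyRun();
```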
 uint64_t ONNXRuntimePredictor::TryShrinkMemory() {