1919#include  " paddle/fluid/operators/controlflow/conditional_block_op_helper.h" 
2020#include  " paddle/fluid/operators/controlflow/recurrent_op_helper.h" 
2121#include  " paddle/fluid/operators/controlflow/while_op_helper.h" 
22+ #include  " paddle/pten/core/kernel_factory.h" 
2223
2324PADDLE_DEFINE_EXPORTED_bool (
2425    new_executor_sequential_run, false ,
2526    " Enable sequential execution for standalone executor, used for debug" 
27+ DECLARE_bool (run_pten_kernel);
28+ 
2629namespace  paddle  {
2730namespace  framework  {
2831namespace  interpreter  {
@@ -338,6 +341,8 @@ void build_op_func_list(const platform::Place& place,
338341      //  op is not a operatorwithkernel, so direcly run OperatorBase::Run()
339342      deal_operator_base (place, var_scope, ops[i], &op_func_node, local_scope);
340343    } else  {
344+       auto  op_with_kernel =
345+           static_cast <const  framework::OperatorWithKernel*>(op);
341346      //  construct RuntimeContext and analysis KernelType
342347      RuntimeContext runtime_context ({}, {});
343348      runtime_context.inputs .swap (ins_map);
@@ -350,8 +355,7 @@ void build_op_func_list(const platform::Place& place,
350355        //  TODO(Aurelius84): In case of control flow ops, they are NOT
351356        //  inheritted
352357        //  from OperatorWithKernel.
353-         static_cast <const  framework::OperatorWithKernel*>(op)->InferShape (
354-             &infer_shape_ctx);
358+         op_with_kernel->InferShape (&infer_shape_ctx);
355359      }
356360
357361      auto  kernels_iter = all_op_kernels.find (op->Type ());
@@ -367,21 +371,25 @@ void build_op_func_list(const platform::Place& place,
367371          platform::DeviceContextPool::Instance ();
368372      auto * dev_ctx = pool.Get (place);
369373      Scope scope;
370-       auto  expected_kernel_key =
371-           dynamic_cast <const  framework::OperatorWithKernel*>(op)
372-               ->GetExpectedKernelType (
373-                   ExecutionContext (*op, scope, *dev_ctx, runtime_context));
374+       auto  expected_kernel_key = op_with_kernel->GetExpectedKernelType (
375+           ExecutionContext (*op, scope, *dev_ctx, runtime_context));
374376
375377      //  change device by the device_guard()
376378      apply_device_guard (op, place, &expected_kernel_key);
377379      VLOG (3 ) << " expected_kernel_key : " 
378380
379381      //  step 3. apply data transforms and insert data transfer ops
380382      VariableValueMap& ins_map_temp = runtime_context.inputs ;
383+ 
384+       //  NOTE(zhiqiu): op_func_node->operator_base_ maybe changed in
385+       //  ApplyDataTransform
381386      ApplyDataTransform (expected_kernel_key, place, &ins_map_temp, var_scope,
382387                         &op_func_node, vec_func_list, use_local_scope);
388+       op_with_kernel = static_cast <const  framework::OperatorWithKernel*>(
389+           op_func_node.operator_base_ .get ());
390+ 
383391      //  step 4. Run op kernel
384-       VLOG (3 ) << op ->Type ()
392+       VLOG (3 ) << op_with_kernel ->Type ()
385393              << "  : expected_kernel_key : " 
386394
387395      if  (platform::is_gpu_place (expected_kernel_key.place_ )) {
@@ -397,7 +405,8 @@ void build_op_func_list(const platform::Place& place,
397405      }
398406      op_func_node.dev_ctx_  = dev_ctx;
399407
400-       auto  exec_ctx = ExecutionContext (*op, scope, *dev_ctx, runtime_context);
408+       auto  exec_ctx =
409+           ExecutionContext (*op_with_kernel, scope, *dev_ctx, runtime_context);
401410
402411      auto  kernel_iter = kernels.find (expected_kernel_key);
403412      PADDLE_ENFORCE_NE (
@@ -406,8 +415,27 @@ void build_op_func_list(const platform::Place& place,
406415              " Operator (%s) does not have kernel for %s." Type (),
407416              KernelTypeToString (expected_kernel_key)));
408417
409-       op_func_node.kernel_func_  = OpKernelComputeFunc (kernel_iter->second );
410-       op_func_node.kernel_func_ (exec_ctx);
418+       auto  run_pten_kernel = false ;
419+ 
420+       if  (FLAGS_run_pten_kernel &&
421+           pten::KernelFactory::Instance ().HasCompatiblePtenKernel (
422+               op_with_kernel->Type ())) {
423+         op_with_kernel->ChoosePtenKernel (exec_ctx);
424+         run_pten_kernel = op_with_kernel->PtenKernel ()->IsValid ();
425+       }
426+ 
427+       if  (run_pten_kernel) {
428+         op_with_kernel->BuildPtenKernelContext (runtime_context, dev_ctx);
429+         op_func_node.pt_kernel_  = op_with_kernel->PtenKernel ();
430+         op_func_node.pt_kernel_context_  = op_with_kernel->PtenKernelContext ();
431+ 
432+         (*op_func_node.pt_kernel_ )(op_func_node.pt_kernel_context_ );
433+         op_with_kernel->WriteBackToOutputs (&runtime_context);
434+         op_func_node.pt_kernel_context_ ->ClearData ();
435+       } else  {
436+         op_func_node.kernel_func_  = OpKernelComputeFunc (kernel_iter->second );
437+         op_func_node.kernel_func_ (exec_ctx);
438+       }
411439
412440      //  post-process grad_op.outputs if need cast complex grad into real grad.
413441      //  NOTE(Aurelius84): insert a transfer_dtype_op inplacely to cast it.
0 commit comments