@@ -133,8 +133,8 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
133
133
for (auto & in_name : inputs) {
134
134
VLOG (3 ) << " Custom Operator: input name - " << in_name;
135
135
if (detail::IsDuplicableVar (in_name)) {
136
- // return const std::vector<const Tensor *>
137
- auto vec_x = ctx.MultiInput <Tensor >(in_name);
136
+ // return const std::vector<const phi::DenseTensor *>
137
+ auto vec_x = ctx.MultiInput <phi::DenseTensor >(in_name);
138
138
PADDLE_ENFORCE_NE (vec_x.empty (),
139
139
true ,
140
140
platform::errors::NotFound (
@@ -161,7 +161,7 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
161
161
}
162
162
kernel_ctx.EmplaceBackInputs (std::move (custom_vec_in));
163
163
} else {
164
- auto * x = ctx.Input <Tensor >(in_name);
164
+ auto * x = ctx.Input <phi::DenseTensor >(in_name);
165
165
PADDLE_ENFORCE_NOT_NULL (
166
166
x,
167
167
platform::errors::NotFound (" Input tensor (%s) is nullptr." , in_name));
@@ -222,7 +222,7 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
222
222
223
223
VLOG (3 ) << " Custom Operator: push outputs into CustomOpKernelContext." ;
224
224
// cache the target tensor pointers
225
- std::vector<Tensor *> true_out_ptrs;
225
+ std::vector<phi::DenseTensor *> true_out_ptrs;
226
226
for (size_t i = 0 ; i < outputs.size (); ++i) {
227
227
auto out_name = outputs[i];
228
228
if (detail::IsDuplicableVar (out_name)) {
@@ -231,7 +231,7 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
231
231
" If custom operator's outputs contains `paddle::Vec("
232
232
" )` type, "
233
233
" it only can hold one output." ));
234
- auto vec_out = ctx.MultiOutput <Tensor >(out_name);
234
+ auto vec_out = ctx.MultiOutput <phi::DenseTensor >(out_name);
235
235
PADDLE_ENFORCE_NE (vec_out.empty (),
236
236
true ,
237
237
platform::errors::NotFound (
@@ -253,7 +253,7 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
253
253
}
254
254
kernel_ctx.EmplaceBackOutputs (std::move (custom_vec_out));
255
255
} else {
256
- auto * out = ctx.Output <Tensor >(out_name);
256
+ auto * out = ctx.Output <phi::DenseTensor >(out_name);
257
257
PADDLE_ENFORCE_NOT_NULL (out,
258
258
platform::errors::NotFound (
259
259
" Output tensor (%s) is nullptr." , out_name));
@@ -431,7 +431,7 @@ class CustomOperator : public OperatorWithKernel {
431
431
*/
432
432
framework::OpKernelType GetKernelTypeForVar (
433
433
const std::string& var_name,
434
- const Tensor & tensor,
434
+ const phi::DenseTensor & tensor,
435
435
const OpKernelType& expected_kernel_type) const override {
436
436
return OpKernelType (expected_kernel_type.data_type_ ,
437
437
expected_kernel_type.place_ ,
@@ -511,7 +511,7 @@ class CustomOpMaker : public OpProtoAndCheckerMaker {
511
511
AddComment (R"DOC(
512
512
Custom Operator.
513
513
514
- According to the Tensor operation function implemented by the user
514
+ According to the phi::DenseTensor operation function implemented by the user
515
515
independently of the framework, it is encapsulated into a framework
516
516
operator to adapt to various execution scenarios such as dynamic graph,
517
517
mode static graph mode, and inference mode.
0 commit comments