From 56869d99260eee6517ed6f88f74d2131789430e8 Mon Sep 17 00:00:00 2001 From: feifei-111 Date: Wed, 31 Aug 2022 16:32:03 +0800 Subject: [PATCH] [phi] Migrate lookup_table_v2 and lookup_table_v2_grad XPU kernel to phi (#45590) * test=kunlun * test=kunlun --- .../fluid/operators/lookup_table_v2_op_xpu.cc | 144 ------------------ .../phi/kernels/xpu/embedding_grad_kernel.cc | 66 ++++++++ paddle/phi/kernels/xpu/embedding_kernel.cc | 73 +++++++++ 3 files changed, 139 insertions(+), 144 deletions(-) delete mode 100644 paddle/fluid/operators/lookup_table_v2_op_xpu.cc create mode 100644 paddle/phi/kernels/xpu/embedding_grad_kernel.cc create mode 100644 paddle/phi/kernels/xpu/embedding_kernel.cc diff --git a/paddle/fluid/operators/lookup_table_v2_op_xpu.cc b/paddle/fluid/operators/lookup_table_v2_op_xpu.cc deleted file mode 100644 index 1e669df798c1f..0000000000000 --- a/paddle/fluid/operators/lookup_table_v2_op_xpu.cc +++ /dev/null @@ -1,144 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include - -#include "paddle/fluid/framework/no_need_buffer_vars_inference.h" -#include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/fluid/framework/var_type_inference.h" -#include "paddle/fluid/operators/lookup_table_v2_op.h" -#include "paddle/fluid/platform/device/device_wrapper.h" -#ifdef PADDLE_WITH_XPU -namespace paddle { -namespace operators { - -template -class LookupTableV2XPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &context) const override { - auto *ids_t = context.Input("Ids"); // int - auto *output_t = context.Output("Out"); // float - auto *table_var = context.InputVar("W"); - PADDLE_ENFORCE_EQ( - (std::is_same::value), - true, - platform::errors::PreconditionNotMet("Unsupported place! only support " - "xpu place , please check your " - "place.")); - - PADDLE_ENFORCE_EQ(table_var->IsType(), - true, - platform::errors::PermissionDenied( - "Unsupported Variable Type , idx in " - "LookupTableV2XPUKernel should be LoDTensor.")); - - int64_t padding_idx = context.Attr("padding_idx"); - int64_t ids_numel = ids_t->numel(); - - auto *table_t = context.Input("W"); - auto &dev_ctx = context.template device_context(); - - auto *table = table_t->data(); - auto *output = output_t->mutable_data(context.GetPlace()); - - const int64_t *ids = ids_t->data(); - - PADDLE_ENFORCE_EQ( - ids_numel <= std::numeric_limits::max(), - true, - platform::errors::OutOfRange( - "Number of ids greater than int32_t::max , please check " - "number of ids in LookupTableV2XPUKernel.")); - - int ym = static_cast(ids_numel); - - size_t xm = table_t->dims()[0]; - size_t n = table_t->dims()[1]; - - int r = xpu::embedding(dev_ctx.x_context(), - table, - ids, - output, - xm, - n, - ym, - static_cast(padding_idx)); - - PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding"); - } -}; - -template -class LookupTableV2GradXPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &context) const override { - auto *table_var = context.InputVar("W"); - DDim table_dim; - PADDLE_ENFORCE_EQ(table_var->IsType(), - true, - platform::errors::PermissionDenied( - "Unsupported Variable Type , idx in " - "LookupTableV2GradXPUKernel should be LoDTensor.")); - table_dim = context.Input("W")->dims(); - - bool is_sparse = context.Attr("is_sparse"); - PADDLE_ENFORCE_EQ( - is_sparse, - false, - platform::errors::InvalidArgument( - "LookupTableV2GradXPUKernel dose NOT support is_sparse = True.")); - - auto ids_t = context.Input("Ids"); - auto d_output_t = context.Input(framework::GradVarName("Out")); - auto d_table_t = context.Output(framework::GradVarName("W")); - - int64_t ids_numel = ids_t->numel(); - PADDLE_ENFORCE_EQ( - ids_numel <= std::numeric_limits::max(), - true, - platform::errors::OutOfRange( - "Number of ids greater than int32_t::max , please check " - "number of ids in LookupTableV2GradXPUKernel.")); - - auto &dev_ctx = context.template device_context(); - const int64_t *ids_data = ids_t->data(); - const T *d_output_data = d_output_t->data(); - T *d_table_data = d_table_t->mutable_data(context.GetPlace()); - int xm = d_table_t->dims()[0]; - int ym = static_cast(ids_numel); - int n = d_table_t->dims()[1]; - int padding_idx = context.Attr("padding_idx"); - - int r = xpu::embedding_grad(dev_ctx.x_context(), - d_output_data, - ids_data, - d_table_data, - xm, - n, - ym, - padding_idx); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding_grad"); - } -}; -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_XPU_KERNEL( - lookup_table_v2, - ops::LookupTableV2XPUKernel); -REGISTER_OP_XPU_KERNEL( - lookup_table_v2_grad, - ops::LookupTableV2GradXPUKernel); -#endif diff --git a/paddle/phi/kernels/xpu/embedding_grad_kernel.cc b/paddle/phi/kernels/xpu/embedding_grad_kernel.cc new file mode 100644 index 0000000000000..f1b9abfe82f5a --- /dev/null +++ b/paddle/phi/kernels/xpu/embedding_grad_kernel.cc @@ -0,0 +1,66 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/embedding_grad_kernel.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void EmbeddingGradKernel(const Context& ctx, + const DenseTensor& input, + const DenseTensor& weight, + const DenseTensor& out_grad, + int64_t padding_idx, + DenseTensor* weight_grad) { + DDim table_dim; + table_dim = weight.dims(); + + auto ids_t = &input; + auto d_output_t = &out_grad; + auto d_table_t = weight_grad; + + int64_t ids_numel = ids_t->numel(); + PADDLE_ENFORCE_EQ( + ids_numel <= std::numeric_limits::max(), + true, + phi::errors::OutOfRange( + "Number of ids greater than int32_t::max , please check " + "number of ids in LookupTableV2GradXPUKernel.")); + + auto& dev_ctx = ctx; + const int64_t* ids_data = ids_t->data(); + const T* d_output_data = d_output_t->data(); + T* d_table_data = dev_ctx.template Alloc(d_table_t); + int xm = d_table_t->dims()[0]; + int ym = static_cast(ids_numel); + int n = d_table_t->dims()[1]; + + int r = xpu::embedding_grad(dev_ctx.x_context(), + d_output_data, + ids_data, + d_table_data, + xm, + n, + ym, + padding_idx); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding_grad"); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + embedding_grad, XPU, ALL_LAYOUT, phi::EmbeddingGradKernel, float) {} diff --git a/paddle/phi/kernels/xpu/embedding_kernel.cc b/paddle/phi/kernels/xpu/embedding_kernel.cc new file mode 100644 index 0000000000000..d0e531f8c1399 --- /dev/null +++ b/paddle/phi/kernels/xpu/embedding_kernel.cc @@ -0,0 +1,73 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/embedding_kernel.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void EmbeddingKernel(const Context &ctx, + const DenseTensor &inputx, + const DenseTensor &weight, + int64_t padding_idx, + DenseTensor *out) { + auto *ids_t = &inputx; // int + auto *output_t = out; // float + PADDLE_ENFORCE_EQ( + (std::is_same::value), + true, + phi::errors::PreconditionNotMet("Unsupported place! only support " + "xpu place , please check your " + "place.")); + + int64_t ids_numel = ids_t->numel(); + + auto *table_t = &weight; + auto &dev_ctx = ctx; + + auto *table = table_t->data(); + auto *output = dev_ctx.template Alloc(output_t); + + const int64_t *ids = ids_t->data(); + + PADDLE_ENFORCE_EQ( + ids_numel <= std::numeric_limits::max(), + true, + phi::errors::OutOfRange( + "Number of ids greater than int32_t::max , please check " + "number of ids in LookupTableV2XPUKernel.")); + + int ym = static_cast(ids_numel); + + size_t xm = table_t->dims()[0]; + size_t n = table_t->dims()[1]; + + int r = xpu::embedding(dev_ctx.x_context(), + table, + ids, + output, + xm, + n, + ym, + static_cast(padding_idx)); + + PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding"); +} + +} // namespace phi + +PD_REGISTER_KERNEL(embedding, XPU, ALL_LAYOUT, phi::EmbeddingKernel, float) {}