Skip to content

[PTen]Refactor scale kernel that has selected_rows input #39278

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 47 additions & 20 deletions cmake/pten.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ function(kernel_library TARGET)
set(cpu_srcs)
set(gpu_srcs)
set(xpu_srcs)
set(selected_rows_srcs)
# parse and save the deps kernel targets
set(all_srcs)
set(kernel_deps)
Expand All @@ -106,6 +107,9 @@ function(kernel_library TARGET)
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/cpu/${TARGET}.cc)
list(APPEND cpu_srcs ${CMAKE_CURRENT_SOURCE_DIR}/cpu/${TARGET}.cc)
endif()
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/selected_rows/${TARGET}.cc)
list(APPEND selected_rows_srcs ${CMAKE_CURRENT_SOURCE_DIR}/selected_rows/${TARGET}.cc)
endif()
if (WITH_GPU OR WITH_ROCM)
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/gpu/${TARGET}.cu)
list(APPEND gpu_srcs ${CMAKE_CURRENT_SOURCE_DIR}/gpu/${TARGET}.cu)
Expand Down Expand Up @@ -144,27 +148,30 @@ function(kernel_library TARGET)
list(LENGTH cpu_srcs cpu_srcs_len)
list(LENGTH gpu_srcs gpu_srcs_len)
list(LENGTH xpu_srcs xpu_srcs_len)
list(LENGTH selected_rows_srcs selected_rows_srcs_len)

# Build Target according different src organization
if((${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR
${xpu_srcs_len} GREATER 0) AND ${common_srcs_len} GREATER 0)
# If the common_srcs depends on specific device srcs, build target using this rule.
${xpu_srcs_len} GREATER 0) AND (${common_srcs_len} GREATER 0 OR
${selected_rows_srcs_len} GREATER 0))
# If the common_srcs/selected_rows_srcs depends on specific device srcs, build target using this rule.
if (WITH_GPU)
if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0)
nv_library(${TARGET}_part SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
nv_library(${TARGET} SRCS ${common_srcs} DEPS ${TARGET}_part)
nv_library(${TARGET} SRCS ${common_srcs} ${selected_rows_srcs} DEPS ${TARGET}_part)
endif()
elseif (WITH_ROCM)
if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0)
hip_library(${TARGET}_part SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
hip_library(${TARGET} SRCS ${common_srcs} DEPS ${TARGET}_part)
hip_library(${TARGET} SRCS ${common_srcs} ${selected_rows_srcs} DEPS ${TARGET}_part)
endif()
else()
if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0)
cc_library(${TARGET}_part SRCS ${cpu_srcs} ${xpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
cc_library(${TARGET} SRCS ${common_srcs} DEPS ${TARGET}_part)
cc_library(${TARGET} SRCS ${common_srcs} ${selected_rows_srcs} DEPS ${TARGET}_part)
endif()
endif()
# If there are only specific device srcs, build target using this rule.
elseif (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0)
if (WITH_GPU)
if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0)
Expand All @@ -179,25 +186,42 @@ function(kernel_library TARGET)
cc_library(${TARGET} SRCS ${cpu_srcs} ${xpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
endif()
endif()
else()
if (${common_srcs_len} EQUAL 0)
message(FATAL_ERROR "Cannot find any implementation for ${TARGET}")
# If the selected_rows_srcs depends on common_srcs, build target using this rule.
elseif (${common_srcs_len} GREATER 0 AND ${selected_rows_srcs_len} GREATER 0)
if (WITH_GPU)
nv_library(${TARGET}_part SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
nv_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${TARGET}_part)
elseif (WITH_ROCM)
hip_library(${TARGET}_part SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
hip_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${TARGET}_part)
else()
# If the kernel has a device independent public implementation,
# we will use this implementation and will not adopt the implementation
# under specific devices
if (WITH_GPU)
nv_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
elseif (WITH_ROCM)
hip_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
else()
cc_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
endif()
endif()
cc_library(${TARGET}_part SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
cc_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${TARGET}_part)
endif()
# If there are only common_srcs or selected_rows_srcs, build target using below rules.
elseif (${common_srcs_len} GREATER 0)
if (WITH_GPU)
nv_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
elseif (WITH_ROCM)
hip_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
else()
cc_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
endif()
elseif (${selected_rows_srcs_len} GREATER 0)
if (WITH_GPU)
nv_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
elseif (WITH_ROCM)
hip_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
else()
cc_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
endif()
else()
message(FATAL_ERROR "Cannot find any implementation for ${TARGET}")
endif()

if (${common_srcs_len} GREATER 0 OR ${cpu_srcs_len} GREATER 0 OR
${gpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0)
${gpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0 OR
${selected_rows_srcs_len} GREATER 0)
# append target into PTEN_KERNELS property
get_property(pten_kernels GLOBAL PROPERTY PTEN_KERNELS)
set(pten_kernels ${pten_kernels} ${TARGET})
Expand All @@ -219,6 +243,9 @@ function(kernel_library TARGET)
if (${xpu_srcs_len} GREATER 0)
kernel_declare(${xpu_srcs})
endif()
if (${selected_rows_srcs_len} GREATER 0)
kernel_declare(${selected_rows_srcs})
endif()
endfunction()

function(register_kernels)
Expand Down
28 changes: 15 additions & 13 deletions paddle/fluid/operators/scale_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,34 +43,36 @@ class ScaleKernel : public framework::OpKernel<T> {
public:
virtual void Compute(const framework::ExecutionContext& ctx) const {
auto* in_var = ctx.InputVar("X");
auto* in = framework::GetLoDTensorOrSelectedRowsValueFromVar(*in_var);

auto bias = ctx.Attr<float>("bias");
auto bias_after_scale = ctx.Attr<bool>("bias_after_scale");

auto scale = ctx.Attr<float>("scale");
auto* out_var = ctx.OutputVar("Out");

if (ctx.HasInput("ScaleTensor")) {
auto* scale_tensor = ctx.Input<framework::Tensor>("ScaleTensor");
scale = static_cast<float>(GetAttrFromTensor<T>(scale_tensor));
}

auto* out_var = ctx.OutputVar("Out");
if (in_var->IsType<pten::SelectedRows>() && in_var != out_var) {
auto& in_slr = in_var->Get<pten::SelectedRows>();
auto* out_slr = out_var->GetMutable<pten::SelectedRows>();
out_slr->set_rows(in_slr.rows());
out_slr->set_height(in_slr.height());
}
auto* in = framework::GetLoDTensorOrSelectedRowsValueFromVar(*in_var);
auto* out =
framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(out_var);
out->mutable_data<T>(in->place());
auto& dev_ctx = ctx.device_context<DeviceContext>();

// call new kernel
pten::ScaleKernel<T>(
static_cast<const typename framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
*in, scale, bias, bias_after_scale, out);
if (in_var->IsType<pten::SelectedRows>()) {
pten::ScaleSR<T>(
static_cast<const typename framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
in_var->Get<pten::SelectedRows>(), scale, bias, bias_after_scale,
out_var->GetMutable<pten::SelectedRows>());
} else {
pten::ScaleKernel<T>(
static_cast<const typename framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
*in, scale, bias, bias_after_scale, out);
}
}
};

Expand Down
6 changes: 6 additions & 0 deletions paddle/pten/core/kernel_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,19 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
std::type_index(typeid(const std::vector<DenseTensor>&))) {
args_def->AppendInput(
default_key.backend(), default_tensor_layout, default_key.dtype());
} else if (arg_type == std::type_index(typeid(const SelectedRows&))) {
args_def->AppendInput(
default_key.backend(), default_tensor_layout, default_key.dtype());
} else if (arg_type == std::type_index(typeid(DenseTensor*))) {
args_def->AppendOutput(
default_key.backend(), default_tensor_layout, default_key.dtype());
} else if (arg_type ==
std::type_index(typeid(std::vector<DenseTensor*>))) {
args_def->AppendOutput(
default_key.backend(), default_tensor_layout, default_key.dtype());
} else if (arg_type == std::type_index(typeid(SelectedRows*))) {
args_def->AppendOutput(
default_key.backend(), default_tensor_layout, default_key.dtype());
} else {
// Attribute deal with
// TODO(chenweihang): now here allow any types of attribute, maybe
Expand Down
7 changes: 3 additions & 4 deletions paddle/pten/core/kernel_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_context.h"
#include "paddle/pten/core/kernel_def.h"
#include "paddle/pten/core/selected_rows.h"
#include "paddle/pten/core/sparse_coo_tensor.h"
#include "paddle/pten/core/sparse_csr_tensor.h"

Expand Down Expand Up @@ -215,6 +216,7 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(DenseTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(DenseTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(DenseTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(SelectedRows);

PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(SparseCooTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(SparseCooTensor);
Expand All @@ -223,8 +225,6 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(SparseCsrTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(SparseCsrTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(SparseCsrTensor);
// TODO(chenweihang): adapt SelectedRows
// PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(SelectedRowsTensor);

/* Attribute Helpers */

Expand All @@ -244,14 +244,13 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {

PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(DenseTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(DenseTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(SelectedRows);

PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(SparseCooTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(SparseCooTensor);

PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(SparseCsrTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(SparseCsrTensor);
// TODO(chenweihang): adapt SelectedRows
// PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(SelectedRowsTensor);

/* End case */
template <typename T>
Expand Down
2 changes: 1 addition & 1 deletion paddle/pten/kernels/cpu/scale_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ PT_REGISTER_KERNEL(scale,
pten::ScaleKernel,
float,
double,
paddle::platform::bfloat16,
pten::dtype::bfloat16,
uint8_t,
int8_t,
int16_t,
Expand Down
2 changes: 1 addition & 1 deletion paddle/pten/kernels/gpu/scale_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ PT_REGISTER_KERNEL(scale,
pten::ScaleKernel,
float,
double,
paddle::platform::float16,
pten::dtype::float16,
uint8_t,
int8_t,
int16_t,
Expand Down
9 changes: 9 additions & 0 deletions paddle/pten/kernels/scale_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License. */

#include "paddle/pten/common/scalar.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/selected_rows.h"
#include "paddle/pten/infermeta/unary.h"
#include "paddle/pten/kernels/empty_kernel.h"
namespace pten {
Expand All @@ -28,6 +29,14 @@ void ScaleKernel(const Context& dev_ctx,
bool bias_after_scale,
DenseTensor* out);

// SelectedRows variant of the scale kernel: scales the underlying value
// tensor of `x` (bias applied before or after scaling per
// `bias_after_scale`) and propagates the sparse rows/height metadata to
// `out`. Shared implementation lives in
// paddle/pten/kernels/selected_rows/scale_kernel.cc.
template <typename T, typename Context>
void ScaleSR(const Context& dev_ctx,
             const SelectedRows& x,
             const Scalar& scale,
             float bias,
             bool bias_after_scale,
             SelectedRows* out);

template <typename T, typename Context>
DenseTensor Scale(const Context& dev_ctx,
const DenseTensor& x,
Expand Down
68 changes: 68 additions & 0 deletions paddle/pten/kernels/selected_rows/scale_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2022

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

下个pr处理


Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/pten/kernels/scale_kernel.h"

#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/core/kernel_registry.h"

// See Note [ Why still include the fluid headers? ]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这行注释可以删除了

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的,下个PR处理

#include "paddle/pten/common/bfloat16.h"
namespace pten {

template <typename T, typename Context>
void ScaleSR(const Context& dev_ctx,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个名字需要加Kernel后缀吗

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我想了想还是加上吧,顺便把中间层API也加上,下个PR补一下

const SelectedRows& x,
const Scalar& scale,
float bias,
bool bias_after_scale,
SelectedRows* out) {
if (x.value().data() != out->value().data()) {
out->set_rows(x.rows());
out->set_height(x.height());
}
pten::ScaleKernel<T>(
dev_ctx, x.value(), scale, bias, bias_after_scale, out->mutable_value());
}

} // namespace pten

// Register the SelectedRows scale kernel ("scale_sr") for CPU over the same
// numeric dtypes as the dense CPU scale kernel, including bfloat16.
PT_REGISTER_KERNEL(scale_sr,
CPU,
ALL_LAYOUT,
pten::ScaleSR,
float,
double,
pten::dtype::bfloat16,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}

// GPU registration mirrors the CPU dtype list but substitutes float16 for
// bfloat16, matching the dense GPU scale kernel registration.
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_REGISTER_KERNEL(scale_sr,
GPU,
ALL_LAYOUT,
pten::ScaleSR,
float,
double,
pten::dtype::float16,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}
#endif
2 changes: 1 addition & 1 deletion paddle/pten/tests/api/scale_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ static void ScaleCPU(DataType kernel_dtype,
break;
}
case pten::DataType::BFLOAT16: {
pten::ScaleKernel<paddle::platform::bfloat16>(
pten::ScaleKernel<pten::dtype::bfloat16>(
dev_ctx, x, pten::Scalar(scale), bias, bias_after_scale, dense_out);
break;
}
Expand Down