Skip to content

[pten] add split kernel #39060

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
7e1e79e
add split kernel
MingMingShangTian Jan 20, 2022
3b1b2b8
Merge branch 'develop' into split_kernel_latest
MingMingShangTian Jan 20, 2022
2544790
add split kernel signature
MingMingShangTian Jan 20, 2022
8e165ec
fix split bug
MingMingShangTian Jan 24, 2022
80d7c3b
modify MakePtenScalarArrayFromVarList
MingMingShangTian Jan 24, 2022
37a89eb
modify MakePtenScalarArrayFromVarList
MingMingShangTian Jan 24, 2022
998f488
Merge branch 'split_kernel_latest' of https://github.com/MingMingShan…
MingMingShangTian Jan 24, 2022
2ac9dcd
Merge branch 'develop' into split_kernel_latest
MingMingShangTian Jan 24, 2022
16e4239
Merge branch 'develop' into split_kernel_latest
MingMingShangTian Jan 25, 2022
c962479
fix split windows register error
MingMingShangTian Jan 25, 2022
1d0566c
add test case for split kernel
MingMingShangTian Jan 25, 2022
5db8cfb
merge develop branch and fix conflict
MingMingShangTian Jan 25, 2022
ba2a896
replace raw split kernel with pten kernel
MingMingShangTian Jan 25, 2022
b81a383
fix makeScalar/ScalarArray bug
MingMingShangTian Jan 26, 2022
918d978
remove debug log
MingMingShangTian Jan 26, 2022
48e09f9
merge develop branch and fix conflict
MingMingShangTian Jan 26, 2022
290aeae
remove int64_t type in buildPtcontext
MingMingShangTian Jan 27, 2022
c9091a4
merge develop branch
MingMingShangTian Jan 28, 2022
1a75972
update by code review
MingMingShangTian Jan 28, 2022
6284eb0
fix split dev test failed
MingMingShangTian Jan 28, 2022
9405d0f
merge develp branch and fix conflict
MingMingShangTian Feb 7, 2022
7473058
change DenseTensorMeta to MetaTensor
MingMingShangTian Feb 7, 2022
ed6e1a5
change split api code from auto gen to manual
MingMingShangTian Feb 8, 2022
c9fd9ad
Merge branch 'develop' into split_kernel_latest
MingMingShangTian Feb 9, 2022
b972ba2
split cuda kernel support bfloat16 type
MingMingShangTian Feb 9, 2022
43b051f
Merge branch 'develop' into split_kernel_latest
MingMingShangTian Feb 9, 2022
c4b2208
Merge branch 'develop' into split_kernel_latest
MingMingShangTian Feb 9, 2022
d7ee412
fix conflict
MingMingShangTian Feb 9, 2022
f8e4b1b
rm raw split kernel
MingMingShangTian Feb 10, 2022
f846ac2
Merge branch 'develop' into split_kernel_latest
MingMingShangTian Feb 10, 2022
a4b41d5
merge develop branch
MingMingShangTian Feb 11, 2022
6c3b8ca
merge develop branch
MingMingShangTian Feb 11, 2022
4448418
change to pten::errors
MingMingShangTian Feb 11, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions paddle/fluid/framework/custom_kernel_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@ limitations under the License. */
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/extension.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_kernel_info_helper.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/api/lib/utils/storage.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_context.h"
#include "paddle/pten/core/kernel_factory.h"
Expand Down Expand Up @@ -183,14 +186,14 @@ TEST(CustomKernel, custom_kernel_dot) {
paddle::platform::CPUPlace());
auto dense_x = std::make_shared<pten::DenseTensor>(
alloc.get(), pten::DenseTensorMeta(pten::DataType::UINT8,
paddle::framework::make_ddim({2, 3}),
pten::framework::make_ddim({2, 3}),
pten::DataLayout::NCHW));
auto* dense_x_data =
dense_x->mutable_data<uint8_t>(paddle::platform::CPUPlace());

auto dense_y = std::make_shared<pten::DenseTensor>(
alloc.get(), pten::DenseTensorMeta(pten::DataType::UINT8,
paddle::framework::make_ddim({2, 3}),
pten::framework::make_ddim({2, 3}),
pten::DataLayout::NCHW));
auto* dense_y_data =
dense_y->mutable_data<uint8_t>(paddle::platform::CPUPlace());
Expand Down Expand Up @@ -231,8 +234,7 @@ TEST(CustomKernel, custom_kernel_dot) {
pten::DataType fake_attr_dtype = pten::DataType::UINT32;
paddle::framework::LoDTensor tmp_tensor;
tmp_tensor.mutable_data<uint8_t>({1}, pten::TransToPtenPlace(backend));
pten::Scalar fake_attr_scalar =
paddle::experimental::MakePtenScalar(tmp_tensor);
pten::Scalar fake_attr_scalar{tmp_tensor};
pten::ScalarArray fake_attr_scalar_array;
std::vector<int64_t> fake_attr_int64_vec;
std::vector<int> fake_attr_int_vec;
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2099,6 +2099,10 @@ void OperatorWithKernel::BuildPtenKernelContext(
std::type_index(typeid(std::vector<int32_t>))) {
pt_kernel_context->EmplaceBackAttr(std::move(pten::ScalarArray(
BOOST_GET_CONST(std::vector<int32_t>, attr_iter->second))));
} else if (std::type_index(attr_iter->second.type()) ==
std::type_index(typeid(int32_t))) {
pt_kernel_context->EmplaceBackAttr(std::move(pten::ScalarArray(
&BOOST_GET_CONST(int32_t, attr_iter->second), 1)));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to ScalarArray when "
Expand Down
8 changes: 8 additions & 0 deletions paddle/fluid/imperative/prepared_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,14 @@ void BuildDygraphPtenKernelContext(
std::type_index(typeid(std::vector<int32_t>))) {
kernel_ctx->EmplaceBackAttr(std::move(
pten::ScalarArray(BOOST_GET_CONST(std::vector<int32_t>, attr))));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(int64_t))) {
kernel_ctx->EmplaceBackAttr(
std::move(pten::ScalarArray(&BOOST_GET_CONST(int64_t, attr), 1)));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(int32_t))) {
kernel_ctx->EmplaceBackAttr(
std::move(pten::ScalarArray(&BOOST_GET_CONST(int32_t, attr), 1)));
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<int32_t>))) {
const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/imperative/tests/test_prepare_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ TEST(test_prepare_op, test_prepare_data_cpu_mkldnn) {
} // namespace imperative
} // namespace paddle

USE_OP(split);
USE_OP_ITSELF(split);
USE_OP(relu);
#ifdef PADDLE_WITH_MKLDNN
USE_OP_DEVICE_KERNEL(relu, MKLDNN);
Expand Down
8 changes: 0 additions & 8 deletions paddle/fluid/operators/split_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,11 +172,3 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(split, ops::SplitOp, ops::SplitOpMaker,
ops::SplitGradMaker<paddle::framework::OpDesc>,
ops::SplitGradMaker<paddle::imperative::OpBase>);
namespace plat = paddle::platform;
REGISTER_OP_CPU_KERNEL(
split, ops::SplitOpKernel<plat::CPUDeviceContext, double>,
ops::SplitOpKernel<plat::CPUDeviceContext, float>,
ops::SplitOpKernel<plat::CPUDeviceContext, int64_t>,
ops::SplitOpKernel<plat::CPUDeviceContext, int>,
ops::SplitOpKernel<plat::CPUDeviceContext, bool>,
ops::SplitOpKernel<plat::CPUDeviceContext, plat::float16>);
25 changes: 0 additions & 25 deletions paddle/fluid/operators/split_op.cu.cc

This file was deleted.

54 changes: 1 addition & 53 deletions paddle/fluid/operators/split_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,8 @@ limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/strided_memcpy.h"
#include "paddle/fluid/operators/utils.h"

#include "paddle/pten/kernels/split_kernel.h"
namespace paddle {
namespace operators {
static inline std::vector<framework::DDim> UpdateOutsDims(
Expand Down Expand Up @@ -108,56 +106,6 @@ static inline std::vector<framework::DDim> UpdateOutsDims(
}
return outs_dims;
}
template <typename DeviceContext, typename T>
class SplitOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto outs = ctx.MultiOutput<framework::Tensor>("Out");
int num = ctx.Attr<int>("num");
std::vector<int> sections = ctx.Attr<std::vector<int>>("sections");
int axis = ctx.Attr<int>("axis");

auto in_dims = in->dims();
auto outs_number = outs.size();

bool need_resize_outs_dims = false;
if (ctx.HasInput("AxisTensor")) {
auto* axis_tensor = ctx.Input<framework::Tensor>("AxisTensor");
axis = GetDataFromTensor(axis_tensor)[0];
need_resize_outs_dims = true;
}
auto sections_tensor_list =
ctx.MultiInput<framework::Tensor>("SectionsTensorList");
if (sections_tensor_list.size() > 0) {
sections = GetDataFromTensorList(sections_tensor_list);
need_resize_outs_dims = true;
}

if (need_resize_outs_dims) {
std::vector<framework::DDim> outs_dims =
UpdateOutsDims(true, true, in_dims, num, sections, axis, outs_number);
for (size_t j = 0; j < outs.size(); ++j) {
outs[j]->Resize(outs_dims[j]);
}
}

std::vector<const framework::Tensor*> shape_refer;
for (size_t j = 0; j < outs.size(); ++j) {
outs[j]->mutable_data<T>(ctx.GetPlace());
shape_refer.emplace_back(outs[j]);
}

auto& dev_ctx = ctx.template device_context<DeviceContext>();
// Sometimes direct copies will be faster, this maybe need deeply analysis.
if (axis == 0 && outs.size() < 10) {
StridedMemcpyWithAxis0<T>(dev_ctx, *in, shape_refer, &outs);
} else {
math::SplitFunctor<DeviceContext, T> functor;
functor(dev_ctx, *in, shape_refer, axis, &outs);
}
}
};

template <typename T>
class SplitGradMaker : public framework::SingleGradOpMaker<T> {
Expand Down
8 changes: 8 additions & 0 deletions paddle/pten/api/include/manual_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ limitations under the License. */

#include "paddle/pten/api/include/tensor.h"
#include "paddle/pten/common/backend.h"
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/common/scalar_array.h"

/**
* This file stores some special APIs that are implemented manually
Expand All @@ -28,5 +30,11 @@ namespace experimental {
// TODO(chenweihang): Replace backend by place when place is ready
PADDLE_API Tensor copy_to(const Tensor& x, Backend backend, bool blocking);

// TODO(chentianyu03): Split API has extra logic to calculate the outputs size,
// api_gen do not support
PADDLE_API std::vector<Tensor> split(const Tensor& x,
const ScalarArray& num_or_sections,
const Scalar& axis);

} // namespace experimental
} // namespace paddle
68 changes: 68 additions & 0 deletions paddle/pten/api/lib/manual_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@ limitations under the License. */
#include "glog/logging.h"

#include "paddle/pten/api/lib/api_registry.h"
#include "paddle/pten/api/lib/api_utils.h"
#include "paddle/pten/api/lib/data_transform.h"
#include "paddle/pten/api/lib/kernel_dispatch.h"
#include "paddle/pten/api/lib/utils/storage.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/core/meta_tensor.h"
#include "paddle/pten/infermeta/unary.h"

PT_DECLARE_KERNEL(copy, CPU, ALL_LAYOUT);
Expand Down Expand Up @@ -75,6 +78,71 @@ PADDLE_API Tensor copy_to(const Tensor& x, Backend backend, bool blocking) {
return out;
}

PADDLE_API std::vector<Tensor> split(const Tensor& x,
const ScalarArray& num_or_sections,
const Scalar& axis) {
Backend kernel_backend = Backend::UNDEFINED;
DataLayout kernel_layout = DataLayout::UNDEFINED;
DataType kernel_data_type = DataType::UNDEFINED;

if (kernel_backend == Backend::UNDEFINED ||
kernel_layout == DataLayout::UNDEFINED ||
kernel_data_type == DataType::UNDEFINED) {
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
if (kernel_backend == Backend::UNDEFINED) {
kernel_backend = kernel_key.backend();
}
if (kernel_layout == DataLayout::UNDEFINED) {
kernel_layout = kernel_key.layout();
}
if (kernel_data_type == DataType::UNDEFINED) {
kernel_data_type = kernel_key.dtype();
}
}

auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
"split", {kernel_backend, kernel_layout, kernel_data_type});
VLOG(6) << "split API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
VLOG(6) << "split API kernel: " << kernel;

auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);

auto dense_x = PrepareData(x, kernel.InputAt(0), {});

// Calculate the number of out tensors
size_t out_number;
if (num_or_sections.GetData().size() == 1) {
out_number = num_or_sections.GetData()[0];
} else {
out_number = num_or_sections.GetData().size();
}

std::vector<Tensor> out;
auto dense_outs = SetKernelOutput(out_number, kernel_backend, &out);
std::vector<pten::MetaTensor> meta_outs;
for (size_t i = 0; i < out_number; ++i) {
meta_outs.push_back(dense_outs[i]);
}

pten::SplitInferMeta(
MakeMetaTensor(*dense_x), num_or_sections, axis, &meta_outs);

using kernel_signature = void (*)(const platform::DeviceContext&,
const pten::DenseTensor&,
const pten::ScalarArray&,
const pten::Scalar&,
std::vector<pten::DenseTensor*>&);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx,
*dense_x,
pten::ScalarArray(num_or_sections),
pten::Scalar(axis),
dense_outs);

return out;
}
} // namespace experimental
} // namespace paddle

Expand Down
Loading