-
Notifications
You must be signed in to change notification settings - Fork 5.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add roll and roll_grad kernels and strided_slice and strided_slice_gr…
…ad kernels, test=kunlun (#47368) * add stat tool * add roll and roll_grad kernels and strided_slice and strided_slice_grad kernels, test=kunlun
- Loading branch information
Showing
9 changed files
with
672 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/phi/kernels/roll_grad_kernel.h" | ||
|
||
#include "paddle/phi/backends/xpu/enforce_xpu.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
|
||
namespace phi { | ||
|
||
template <typename T, typename Context> | ||
void RollGradKernel(const Context& dev_ctx, | ||
const DenseTensor& x, | ||
const DenseTensor& out_grad, | ||
const IntArray& shifts, | ||
const std::vector<int64_t>& axis, | ||
DenseTensor* x_grad) { | ||
using XPUType = typename XPUTypeTrait<T>::Type; | ||
auto shifts_data = shifts.GetData(); | ||
dev_ctx.template Alloc<T>(x_grad); | ||
DDim input_dim = x.dims(); | ||
std::vector<int> xshape; | ||
size_t nums = shifts_data.size(); | ||
for (int i = 0; i < input_dim.size(); ++i) { | ||
xshape.emplace_back(input_dim[i]); | ||
} | ||
|
||
auto dims = axis; | ||
|
||
// axis = none, reshape to 1-D tensor | ||
if (dims.size() == 0) { | ||
dims.push_back(0l); | ||
input_dim = phi::Dim<1>(x.numel()); | ||
} | ||
std::vector<int> shifts_in; | ||
std::vector<int> axis_in; | ||
|
||
for (size_t i = 0; i < nums; ++i) { | ||
int a = dims[i]; | ||
if (a < 0) { | ||
a += (input_dim.size()); | ||
} | ||
axis_in.emplace_back(a); | ||
int sh = (0 - shifts_data[i]) % input_dim[a]; | ||
if (sh < 0) { | ||
sh += input_dim[a]; | ||
} | ||
shifts_in.emplace_back(sh); | ||
} | ||
|
||
int r = xpu::roll(dev_ctx.x_context(), | ||
reinterpret_cast<const XPUType*>(out_grad.data<T>()), | ||
reinterpret_cast<XPUType*>(x_grad->data<T>()), | ||
xshape, | ||
shifts_in, | ||
axis_in); | ||
PADDLE_ENFORCE_XDNN_SUCCESS(r, "roll"); | ||
} | ||
|
||
} // namespace phi | ||
|
||
PD_REGISTER_KERNEL(roll_grad, XPU, ALL_LAYOUT, phi::RollGradKernel, float) {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/phi/kernels/roll_kernel.h" | ||
|
||
#include "paddle/phi/backends/xpu/enforce_xpu.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
|
||
namespace phi { | ||
template <typename T, typename Context> | ||
void RollKernel(const Context& dev_ctx, | ||
const DenseTensor& x, | ||
const IntArray& shifts, | ||
const std::vector<int64_t>& axis, | ||
DenseTensor* out) { | ||
using XPUType = typename XPUTypeTrait<T>::Type; | ||
auto shifts_data = shifts.GetData(); | ||
dev_ctx.template Alloc<T>(out); | ||
DDim input_dim = x.dims(); | ||
std::vector<int> xshape; | ||
size_t nums = shifts_data.size(); | ||
for (int i = 0; i < input_dim.size(); ++i) { | ||
xshape.emplace_back(input_dim[i]); | ||
} | ||
|
||
auto dims = axis; | ||
|
||
// axis = none, reshape to 1-D tensor | ||
if (dims.size() == 0) { | ||
dims.push_back(0l); | ||
input_dim = phi::Dim<1>(x.numel()); | ||
} | ||
std::vector<int> shifts_in; | ||
std::vector<int> axis_in; | ||
|
||
for (size_t i = 0; i < nums; ++i) { | ||
int a = dims[i]; | ||
if (a < 0) { | ||
a += (input_dim.size()); | ||
} | ||
axis_in.emplace_back(a); | ||
int sh = shifts_data[i] % input_dim[a]; | ||
if (sh < 0) { | ||
sh += input_dim[a]; | ||
} | ||
shifts_in.emplace_back(sh); | ||
} | ||
int r = xpu::roll(dev_ctx.x_context(), | ||
reinterpret_cast<const XPUType*>(x.data<T>()), | ||
reinterpret_cast<XPUType*>(out->data<T>()), | ||
xshape, | ||
shifts_in, | ||
axis_in); | ||
PADDLE_ENFORCE_XDNN_SUCCESS(r, "roll"); | ||
} | ||
|
||
} // namespace phi | ||
|
||
PD_REGISTER_KERNEL(roll, XPU, ALL_LAYOUT, phi::RollKernel, float) {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/phi/kernels/strided_slice_grad_kernel.h" | ||
|
||
#include "paddle/phi/backends/xpu/enforce_xpu.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
|
||
namespace phi { | ||
|
||
template <typename T, typename Context> | ||
void StridedSliceRawGradKernel(const Context& dev_ctx, | ||
const DenseTensor& x, | ||
const DenseTensor& out_grad, | ||
const std::vector<int>& axes, | ||
const IntArray& starts, | ||
const IntArray& ends, | ||
const IntArray& strides, | ||
const std::vector<int>& infer_flags, | ||
const std::vector<int>& decrease_axis, | ||
DenseTensor* x_grad) { | ||
using XPUType = typename XPUTypeTrait<T>::Type; | ||
DDim in_dims = x.dims(); | ||
dev_ctx.template Alloc<T>(x_grad); | ||
|
||
auto starts_ = starts.GetData(); | ||
auto ends_ = ends.GetData(); | ||
auto strides_ = strides.GetData(); | ||
|
||
std::vector<int> xshape; | ||
std::vector<int> starts_in(in_dims.size(), 0); | ||
std::vector<int> ends_in; | ||
std::vector<int> strides_in(in_dims.size(), 1); | ||
|
||
for (int i = 0; i < in_dims.size(); ++i) { | ||
xshape.emplace_back(in_dims[i]); | ||
ends_in.emplace_back(in_dims[i]); | ||
} | ||
int num = axes.size(); | ||
|
||
for (int i = 0; i < num; ++i) { | ||
PADDLE_ENFORCE_EQ( | ||
strides_[i] > 0, | ||
true, | ||
errors::InvalidArgument("Currently, XPU strided slice kernel does not", | ||
"support reverse strided slice")); | ||
int cur_axe = axes[i]; | ||
int st = starts_[i]; | ||
if (st > xshape[cur_axe]) { | ||
st = xshape[cur_axe]; | ||
} | ||
if (st < 0) { | ||
st += xshape[cur_axe]; | ||
} | ||
starts_in[cur_axe] = st; | ||
|
||
int end = ends_[i]; | ||
if (end > xshape[cur_axe]) { | ||
end = xshape[cur_axe]; | ||
} | ||
if (end < 0) { | ||
end += xshape[cur_axe]; | ||
} | ||
|
||
ends_in[cur_axe] = end; | ||
strides_in[cur_axe] = strides_[i]; | ||
} | ||
|
||
int r = xpu::strided_slice_grad( | ||
dev_ctx.x_context(), | ||
reinterpret_cast<const XPUType*>(out_grad.data<T>()), | ||
reinterpret_cast<XPUType*>(x_grad->data<T>()), | ||
xshape, | ||
starts_in, | ||
ends_in, | ||
strides_in); | ||
PADDLE_ENFORCE_XDNN_SUCCESS(r, "strided_slice_grad"); | ||
} | ||
|
||
} // namespace phi | ||
|
||
PD_REGISTER_KERNEL(strided_slice_raw_grad, | ||
XPU, | ||
ALL_LAYOUT, | ||
phi::StridedSliceRawGradKernel, | ||
int, | ||
int16_t, | ||
float, | ||
phi::dtype::float16) {} |
Oops, something went wrong.