Add the transformop parameter in TensorReduceFunctorImpl #38135

Merged: 27 commits, merged Dec 16, 2021
Commits (changes shown are from all commits):
7d58b91
Merge pull request #1 from PaddlePaddle/develop
AnnaTrainingG Mar 25, 2021
1021e08
Merge pull request #2 from PaddlePaddle/develop
AnnaTrainingG Mar 29, 2021
43f53fe
Merge pull request #3 from PaddlePaddle/develop
AnnaTrainingG Apr 19, 2021
d25ab26
Merge pull request #4 from PaddlePaddle/develop
AnnaTrainingG May 7, 2021
8c8717f
Merge pull request #5 from PaddlePaddle/develop
AnnaTrainingG May 25, 2021
9ddf5e8
Merge pull request #6 from PaddlePaddle/develop
AnnaTrainingG May 26, 2021
b0cbcca
Merge pull request #9 from PaddlePaddle/develop
AnnaTrainingG Jun 1, 2021
cdecaf0
Merge pull request #14 from PaddlePaddle/develop
AnnaTrainingG Jun 11, 2021
0da14c9
Merge pull request #16 from PaddlePaddle/develop
AnnaTrainingG Jun 15, 2021
ca95763
Merge pull request #17 from PaddlePaddle/develop
AnnaTrainingG Jun 22, 2021
25ba21c
Merge pull request #18 from PaddlePaddle/develop
AnnaTrainingG Jul 5, 2021
3ce9983
Merge pull request #19 from PaddlePaddle/develop
AnnaTrainingG Jul 6, 2021
61842ed
Merge pull request #20 from PaddlePaddle/develop
AnnaTrainingG Jul 12, 2021
0e2c73b
Merge pull request #21 from PaddlePaddle/develop
AnnaTrainingG Jul 28, 2021
c1e59cf
Merge pull request #22 from PaddlePaddle/develop
AnnaTrainingG Aug 2, 2021
3a54149
Merge pull request #23 from PaddlePaddle/develop
AnnaTrainingG Aug 4, 2021
7addd79
Merge pull request #24 from PaddlePaddle/develop
AnnaTrainingG Aug 11, 2021
1e843d1
Merge pull request #25 from PaddlePaddle/develop
AnnaTrainingG Aug 23, 2021
e1a92d6
Merge pull request #26 from PaddlePaddle/develop
AnnaTrainingG Sep 1, 2021
05da032
Merge pull request #27 from PaddlePaddle/develop
AnnaTrainingG Sep 3, 2021
e1fe6dc
Merge pull request #28 from PaddlePaddle/develop
AnnaTrainingG Sep 6, 2021
80e0684
Add the transformop parameter in TensorReduceFunctorImpl
AnnaTrainingG Dec 14, 2021
b32d42e
Merge https://github.com/niuliling123/Paddle into ReduceRename_40913
AnnaTrainingG Dec 14, 2021
d29f91b
update elementwise_sub
AnnaTrainingG Dec 15, 2021
bf9b49f
update margin_cross_entropy_op.cu
AnnaTrainingG Dec 15, 2021
5946063
update prod
AnnaTrainingG Dec 15, 2021
cb50b5b
Update elementwise_sub_op.cu
AnnaTrainingG Dec 15, 2021
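
The file diffs below all follow one pattern: TensorReduceFunctorImpl previously took a single fused functor (SquareSum, ExpAndSum, CustomSum, and so on) that bundled the element-wise transform, the initial value, and the binary reduction; after this PR it takes the binary reduce op as a template parameter and the transform op as a separate template parameter plus a runtime argument. A minimal host-side C++ analogue of that decomposition is sketched here; AddOp, SquareOp, and TransformReduce are illustrative stand-ins, not Paddle APIs.

#include <iostream>
#include <vector>

// Binary reduce op, analogous in spirit to kps::AddFunctor.
template <typename T>
struct AddOp {
  T operator()(const T& a, const T& b) const { return a + b; }
};

// Element-wise transform op, analogous in spirit to kps::SquareFunctor.
template <typename T>
struct SquareOp {
  T operator()(const T& x) const { return x * x; }
};

// Apply `transform` to every element, then fold the results with `reduce`,
// starting from `init`. This is the (reduce op, transform op) split that the
// new TensorReduceFunctorImpl signature expresses.
template <typename T, typename ReduceOp, typename TransformOp>
T TransformReduce(const std::vector<T>& in, T init, ReduceOp reduce,
                  TransformOp transform) {
  T acc = init;
  for (const T& v : in) acc = reduce(acc, transform(v));
  return acc;
}

int main() {
  std::vector<float> x = {1.f, 2.f, 3.f};
  // Sum of squares: reduce op = AddOp, transform op = SquareOp.
  float sum_sq = TransformReduce(x, 0.f, AddOp<float>{}, SquareOp<float>{});
  std::cout << sum_sq << "\n";  // prints 14
  return 0;
}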
29 changes: 4 additions & 25 deletions paddle/fluid/operators/clip_by_norm_op.cu
@@ -18,29 +18,6 @@ limitations under the License. */
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename Tx, typename Ty = Tx>
struct SquareTransformer {
HOSTDEVICE explicit inline SquareTransformer(int n) {}

HOSTDEVICE inline Ty operator()(const Tx& x) const {
return static_cast<Ty>(x) * static_cast<Ty>(x);
}

HOSTDEVICE inline Ty operator()(const Tx* x) const {
return static_cast<Ty>(x[0]) * static_cast<Ty>(x[0]);
}
};

template <typename Tx, typename Ty = Tx>
struct SquareSum {
using Transformer = SquareTransformer<Tx, Ty>;

inline Ty initial() { return static_cast<Ty>(0.0f); }

__device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
return b + a;
}
};

template <>
class ClipByNormKernel<platform::CUDADeviceContext, platform::float16>
@@ -97,8 +74,10 @@ class ClipByNormKernel<platform::CUDADeviceContext, platform::float16>
}
Tensor tmp = context.AllocateTmpTensor<float, platform::CUDADeviceContext>(
{1}, dev_ctx);
TensorReduceFunctorImpl<platform::float16, float, SquareSum>(
*input, &tmp, reduce_dims, dev_ctx.stream());
TensorReduceFunctorImpl<platform::float16, float, kps::AddFunctor,
kps::SquareFunctor<platform::float16, float>>(
*input, &tmp, kps::SquareFunctor<platform::float16, float>(),
reduce_dims, dev_ctx.stream());
auto tmp_eigen = EigenVector<float>::Flatten(tmp);
auto x_norm = tmp_eigen.sqrt();

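
In clip_by_norm the old fused SquareSum functor is replaced by kps::AddFunctor paired with kps::SquareFunctor, and the reduced value is square-rooted to obtain the L2 norm. A host-side sketch of that logic follows; it assumes the usual clip-by-norm rule of rescaling only when the norm exceeds max_norm, and mirrors the kernel's structure, not its CUDA implementation.

#include <cmath>
#include <iostream>
#include <vector>

// Host-side sketch of clip-by-norm: compute the L2 norm via a square-then-sum
// reduction (transform: square, reduce: add), then rescale the vector if its
// norm exceeds `max_norm`.
void ClipByNorm(std::vector<float>* x, float max_norm) {
  float sum_sq = 0.f;
  for (float v : *x) sum_sq += v * v;
  float norm = std::sqrt(sum_sq);
  if (norm > max_norm) {
    float scale = max_norm / norm;
    for (float& v : *x) v *= scale;
  }
}

int main() {
  std::vector<float> x = {3.f, 4.f};         // norm = 5
  ClipByNorm(&x, 1.f);
  std::cout << x[0] << " " << x[1] << "\n";  // prints 0.6 0.8
  return 0;
}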
7 changes: 4 additions & 3 deletions paddle/fluid/operators/elementwise/elementwise_add_op.cu
@@ -15,7 +15,6 @@ limitations under the License. */
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
@@ -91,7 +90,8 @@ default_elementwise_add_grad(const framework::ExecutionContext& ctx,
}
std::vector<int> reduce_dims = GetReduceDim(x->dims(), out->dims(), axis);
gpuStream_t stream = ctx.cuda_device_context().stream();
TensorReduceFunctorImpl<T, T, CustomSum>(*dout, dx, reduce_dims, stream);
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
*dout, dx, kps::IdentityFunctor<T>(), reduce_dims, stream);
}
}
// dy
@@ -106,7 +106,8 @@ default_elementwise_add_grad(const framework::ExecutionContext& ctx,
} else {
std::vector<int> reduce_dims = GetReduceDim(y->dims(), out->dims(), axis);
gpuStream_t stream = ctx.cuda_device_context().stream();
TensorReduceFunctorImpl<T, T, CustomSum>(*dout, dy, reduce_dims, stream);
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
*dout, dy, kps::IdentityFunctor<T>(), reduce_dims, stream);
}
}
}
7 changes: 4 additions & 3 deletions paddle/fluid/operators/elementwise/elementwise_sub_op.cu
@@ -14,7 +14,6 @@ limitations under the License. */

#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
@@ -69,7 +68,8 @@ default_elementwise_sub_grad(const framework::ExecutionContext& ctx,
}
std::vector<int> reduce_dims = GetReduceDim(x->dims(), out->dims(), axis);
gpuStream_t stream = ctx.cuda_device_context().stream();
TensorReduceFunctorImpl<T, T, CustomSum>(*dout, dx, reduce_dims, stream);
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
*dout, dx, kps::IdentityFunctor<T>(), reduce_dims, stream);
}
}
// dy
@@ -90,7 +90,8 @@ default_elementwise_sub_grad(const framework::ExecutionContext& ctx,
} else {
std::vector<int> reduce_dims = GetReduceDim(y->dims(), out->dims(), axis);
gpuStream_t stream = ctx.cuda_device_context().stream();
TensorReduceFunctorImpl<T, T, CustomSub>(*dout, dy, reduce_dims, stream);
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::InverseFunctor<T>>(
*dout, dy, kps::InverseFunctor<T>(), reduce_dims, stream);
}
}
}
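
For the subtraction gradient, dx reduces dout with kps::IdentityFunctor while dy uses kps::InverseFunctor. The sketch below assumes InverseFunctor negates its argument, which is what d(x - y)/dy = -1 requires; dout has shape [2, 3] and y was broadcast along axis 0, so dy is the negated sum of dout over that axis.

#include <iostream>
#include <vector>

// Host-side sketch of the dy reduction: negate each element (transform), then
// sum over the broadcast axis (reduce).
int main() {
  std::vector<std::vector<float>> dout = {{1.f, 2.f, 3.f}, {4.f, 5.f, 6.f}};
  std::vector<float> dy(3, 0.f);
  for (const auto& row : dout)                      // reduce over axis 0
    for (int j = 0; j < 3; ++j) dy[j] += -row[j];   // transform: negate
  for (float v : dy) std::cout << v << " ";         // prints -5 -7 -9
  std::cout << "\n";
  return 0;
}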
7 changes: 3 additions & 4 deletions paddle/fluid/operators/fused/attn_gemm.h
@@ -16,11 +16,10 @@ limitations under the License. */

#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"

namespace paddle {
namespace operators {

// support gemm-nt and gemm-nn, which is used in fused_attention_op.
template <typename T>
class AttnMatMul {
@@ -165,8 +164,8 @@ class AttnMatMul {
(input_dims[2] == output_dims[0]));
if (support_case_1 || support_case_2) {
gpuStream_t stream = dev_ctx_.stream();
TensorReduceFunctorImpl<T, T, CustomSum>(*d_output, d_bias, {0, 1},
stream);
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
*d_output, d_bias, kps::IdentityFunctor<T>(), {0, 1}, stream);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Only support reduce when the input dims are [0,1,2,3,4] and "
22 changes: 6 additions & 16 deletions paddle/fluid/operators/margin_cross_entropy_op.cu
@@ -24,7 +24,6 @@ namespace cub = hipcub;
#include "paddle/fluid/operators/margin_cross_entropy_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/softmax_impl.h"
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
#include "paddle/fluid/string/string_helper.h"

@@ -128,17 +127,6 @@ __global__ void AddMarginToPositiveLogitsKernel(
}
}

template <typename Tx, typename Ty = Tx>
struct ExpAndSum {
using Transformer = kps::ExpFunctor<Tx>;

inline Ty initial() { return static_cast<Ty>(0.0f); }

__device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
return b + a;
}
};

template <typename T>
__global__ void ScaleLogitKernel(T* logits, const float scale, const int64_t N,
const int64_t D) {
@@ -309,8 +297,9 @@ class MarginCrossEntropyOpCUDAKernel : public framework::OpKernel<T> {
logits_max =
ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({N, 1}, dev_ctx);
T* logits_max_buff = logits_max.mutable_data<T>(place);
TensorReduceFunctorImpl<T, T, CustomMax>(softmax_2d, &logits_max, {1},
dev_ctx.stream());
TensorReduceFunctorImpl<T, T, kps::MaxFunctor, kps::IdentityFunctor<T>>(
softmax_2d, &logits_max, kps::IdentityFunctor<T>(), {1},
dev_ctx.stream());

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
if (nranks > 1) {
@@ -330,8 +319,9 @@ class MarginCrossEntropyOpCUDAKernel : public framework::OpKernel<T> {
sum_exp_logits =
ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({N, 1}, dev_ctx);
T* sum_exp_logits_buff = sum_exp_logits.mutable_data<T>(place);
TensorReduceFunctorImpl<T, T, ExpAndSum>(softmax_2d, &sum_exp_logits, {1},
dev_ctx.stream());
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::ExpFunctor<T>>(
softmax_2d, &sum_exp_logits, kps::ExpFunctor<T>(), {1},
dev_ctx.stream());

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
if (nranks > 1) {
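
The two reductions above compute the usual numerically stable softmax denominator: first a max-reduce with an identity transform (kps::MaxFunctor with kps::IdentityFunctor), then an add-reduce of exp(logit - row_max) (kps::AddFunctor with kps::ExpFunctor). A simplified host-side sketch for a single row, ignoring the cross-rank communication:

#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>

int main() {
  std::vector<float> logits = {1.f, 2.f, 3.f};  // one row, D = 3
  // Reduction 1: identity transform + max reduce.
  float row_max = *std::max_element(logits.begin(), logits.end());
  // Reduction 2: exp(x - row_max) transform + add reduce.
  float sum_exp = 0.f;
  for (float v : logits) sum_exp += std::exp(v - row_max);
  std::cout << "max = " << row_max << ", sum_exp = " << sum_exp << "\n";
  return 0;
}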
73 changes: 10 additions & 63 deletions paddle/fluid/operators/p_norm_op.cu
@@ -59,28 +59,17 @@ __device__ __forceinline__ double inline_pow(double base, double exponent) {
return pow(base, exponent);
}

struct IdentityFunctor {
HOSTDEVICE explicit inline IdentityFunctor() {}
HOSTDEVICE explicit inline IdentityFunctor(int n) {}
template <typename T>
HOSTDEVICE inline T operator()(const T& x) const {
return static_cast<T>(x);
}
};

template <typename T>
struct NonzeroFunctor {
HOSTDEVICE explicit inline NonzeroFunctor() {}
HOSTDEVICE explicit inline NonzeroFunctor(int n) {}
template <typename T>
HOSTDEVICE inline T operator()(const T& x) const {
return static_cast<T>(static_cast<double>(x) != 0);
}
};

template <typename T>
struct AbsFunctor {
HOSTDEVICE explicit inline AbsFunctor() {}
HOSTDEVICE explicit inline AbsFunctor(int n) {}
template <typename T>
HOSTDEVICE inline T operator()(const T& x) const {
return static_cast<T>(inline_abs(x));
}
@@ -106,48 +95,6 @@ struct PowFunctor {
float porder;
};

template <typename Tx, typename Ty = Tx>
struct AbsAndMin {
using Transformer = AbsFunctor;
using MT = typename details::MPTypeTrait<Ty>::Type;
inline Ty initial() {
return static_cast<Ty>(std::numeric_limits<MT>::infinity());
}
__device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
return (a < b) ? a : b;
}
};

template <typename Tx, typename Ty = Tx>
struct AbsAndMax {
using Transformer = AbsFunctor;
using MT = typename details::MPTypeTrait<Ty>::Type;
inline Ty initial() {
return static_cast<Ty>(-std::numeric_limits<MT>::infinity());
}
__device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
return (a > b) ? a : b;
}
};

template <typename Tx, typename Ty = Tx>
struct NonzeroAndSum {
using Transformer = NonzeroFunctor;
inline Ty initial() { return static_cast<Ty>(0.0f); }
__device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
return b + a;
}
};

template <typename Tx, typename Ty = Tx>
struct IdentityAndSum {
using Transformer = IdentityFunctor;
inline Ty initial() { return static_cast<Ty>(0.0f); }
__device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
return b + a;
}
};

template <typename DeviceContext, typename T>
class PnormCUDAKernel : public framework::OpKernel<T> {
public:
@@ -167,14 +114,14 @@ class PnormCUDAKernel : public framework::OpKernel<T> {

using MT = typename details::MPTypeTrait<T>::Type;
if (porder == 0) {
TensorReduceFunctorImpl<T, T, NonzeroAndSum>(*in_x, out_norm, reduce_axis,
stream);
TensorReduceFunctorImpl<T, T, kps::AddFunctor, NonzeroFunctor<T>>(
*in_x, out_norm, NonzeroFunctor<T>(), reduce_axis, stream);
} else if (porder == INFINITY) {
TensorReduceFunctorImpl<T, T, AbsAndMax>(*in_x, out_norm, reduce_axis,
stream);
TensorReduceFunctorImpl<T, T, kps::MaxFunctor, AbsFunctor<T>>(
*in_x, out_norm, AbsFunctor<T>(), reduce_axis, stream);
} else if (porder == -INFINITY) {
TensorReduceFunctorImpl<T, T, AbsAndMin>(*in_x, out_norm, reduce_axis,
stream);
TensorReduceFunctorImpl<T, T, kps::MinFunctor, AbsFunctor<T>>(
*in_x, out_norm, AbsFunctor<T>(), reduce_axis, stream);
} else {
framework::Tensor tmp_x;
tmp_x.mutable_data<T>(xdim, ctx.GetPlace());
@@ -189,8 +136,8 @@ class PnormCUDAKernel : public framework::OpKernel<T> {
cuda_ctx, ins, &outs, func);
framework::Tensor tmp_y;
tmp_y.mutable_data<T>(ndim, ctx.GetPlace());
TensorReduceFunctorImpl<T, T, IdentityAndSum>(tmp_x, &tmp_y, reduce_axis,
stream);
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
tmp_x, &tmp_y, kps::IdentityFunctor<T>(), reduce_axis, stream);
const framework::Tensor* tmp_norm = &tmp_y;
ins = {tmp_norm};
outs = {out_norm};
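
Each porder branch above is just a different (reduce op, transform op) pair: nonzero-count plus add, abs plus max, abs plus min, or |x|^p plus add followed by a 1/p root. A host-side sketch of the same dispatch:

#include <algorithm>
#include <cmath>
#include <iostream>
#include <limits>
#include <vector>

// Illustrative p-norm dispatch; each branch is an element-wise transform
// followed by a reduction.
float PNorm(const std::vector<float>& x, float porder) {
  if (porder == 0.f) {                        // nonzero transform + add reduce
    float n = 0.f;
    for (float v : x) n += (v != 0.f) ? 1.f : 0.f;
    return n;
  } else if (porder == INFINITY) {            // abs transform + max reduce
    float m = -std::numeric_limits<float>::infinity();
    for (float v : x) m = std::max(m, std::fabs(v));
    return m;
  } else if (porder == -INFINITY) {           // abs transform + min reduce
    float m = std::numeric_limits<float>::infinity();
    for (float v : x) m = std::min(m, std::fabs(v));
    return m;
  }
  float s = 0.f;                              // |x|^p transform + add reduce
  for (float v : x) s += std::pow(std::fabs(v), porder);
  return std::pow(s, 1.f / porder);
}

int main() {
  std::vector<float> x = {3.f, -4.f, 0.f};
  std::cout << PNorm(x, 2.f) << " " << PNorm(x, 0.f) << " "
            << PNorm(x, INFINITY) << "\n";    // prints 5 2 4
  return 0;
}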
8 changes: 4 additions & 4 deletions paddle/fluid/operators/pool_op.h
@@ -23,7 +23,6 @@ limitations under the License. */
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/pooling.h"
#if defined(__HIPCC__) || defined(__NVCC__)
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#endif

@@ -203,13 +202,14 @@ class PoolKernel : public framework::OpKernel<T> {
} else if (pooling_type == "avg") {
std::vector<int> reduce_dim;
int reduce_num = getReduceNum(*in_x, out, data_format, &reduce_dim);

if (reduce_num > 0 &&
adaptive) { // for adaptive_avg_pool2d && output_size == 1
#if defined(__HIPCC__) || defined(__NVCC__)
auto stream = dev_ctx.stream();
TensorReduceFunctorImpl<T, T, CustomMean>(*in_x, out, reduce_dim,
stream);
TensorReduceFunctorImpl<T, T, kps::AddFunctor,
kps::DivideFunctor<T>>(
*in_x, out, kps::DivideFunctor<T>(reduce_num), reduce_dim,
stream);
#else // for cpu
paddle::operators::math::Pool2dFunctor<
DeviceContext, paddle::operators::math::AvgPool<T>, T>
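
The adaptive average-pooling fast path above reduces each channel to a single value, i.e. a mean. Assuming kps::DivideFunctor<T>(reduce_num) divides each element by reduce_num, pairing it with kps::AddFunctor yields that mean in one pass; a scalar host-side sketch:

#include <iostream>
#include <vector>

int main() {
  std::vector<float> channel = {1.f, 2.f, 3.f, 4.f};  // H*W = 4, one channel
  int reduce_num = static_cast<int>(channel.size());
  float mean = 0.f;
  for (float v : channel) mean += v / reduce_num;      // divide, then add
  std::cout << mean << "\n";                           // prints 2.5
  return 0;
}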
14 changes: 3 additions & 11 deletions paddle/fluid/operators/prelu_op.cu
@@ -15,7 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/prelu.h"
#include "paddle/fluid/operators/prelu_op.h"
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"

namespace paddle {
@@ -123,13 +123,6 @@ class PreluOpGradFunctor {
}
};

struct IdentityFunctor {
template <typename T>
HOSTDEVICE inline T operator()(const T& x) const {
return x;
}
};

template <typename DeviceContext, typename T>
class CUDAPReluGradKernel : public framework::OpKernel<T> {
public:
@@ -192,9 +185,8 @@ class CUDAPReluGradKernel : public framework::OpKernel<T> {
reduce_dims.push_back(i);
}

TensorReduce<T, T, cub::Sum, IdentityFunctor>(
dalpha_tmp, dalpha, reduce_dims, static_cast<T>(0), cub::Sum(),
IdentityFunctor(), stream);
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
dalpha_tmp, dalpha, kps::IdentityFunctor<T>(), reduce_dims, stream);
}
};

3 changes: 1 addition & 2 deletions paddle/fluid/operators/reduce_ops/reduce_all_op.cu
@@ -13,8 +13,7 @@
// limitations under the License.

#include "paddle/fluid/operators/reduce_ops/reduce_all_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"

REGISTER_OP_CUDA_KERNEL(
reduce_all,
ops::ReduceCudaKernel<bool, paddle::operators::CustomLogicalAnd>);
ops::ReduceCudaKernel<bool, kps::LogicalAndFunctor, kps::IdentityFunctor>);
3 changes: 1 addition & 2 deletions paddle/fluid/operators/reduce_ops/reduce_any_op.cu
@@ -13,9 +13,8 @@
// limitations under the License.

#include "paddle/fluid/operators/reduce_ops/reduce_any_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"

REGISTER_OP_CUDA_KERNEL(
reduce_any,
ops::ReduceCudaKernel<bool, paddle::operators::CustomLogicalOr>);
ops::ReduceCudaKernel<bool, kps::LogicalOrFunctor, kps::IdentityFunctor>);
10 changes: 5 additions & 5 deletions paddle/fluid/operators/reduce_ops/reduce_max_op.cu
@@ -11,13 +11,13 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"

// reduce_max
REGISTER_OP_CUDA_KERNEL(
reduce_max, ops::ReduceCudaKernel<float, paddle::operators::CustomMax>,
ops::ReduceCudaKernel<double, paddle::operators::CustomMax>,
ops::ReduceCudaKernel<int, paddle::operators::CustomMax>,
ops::ReduceCudaKernel<int64_t, paddle::operators::CustomMax>);
reduce_max,
ops::ReduceCudaKernel<float, kps::MaxFunctor, kps::IdentityFunctor>,
ops::ReduceCudaKernel<double, kps::MaxFunctor, kps::IdentityFunctor>,
ops::ReduceCudaKernel<int, kps::MaxFunctor, kps::IdentityFunctor>,
ops::ReduceCudaKernel<int64_t, kps::MaxFunctor, kps::IdentityFunctor>);