Skip to content

Commit d611e48

Browse files
authored
Dropout optimize & clean broadcast inT and ElementwiseType (#52969)
* change judgement for DropoutGradGPUKernelDriver * add UnrollerWithoutVecSize and after this Loaddata to be refined * pass unittest * use same unroller with XPU * BroadcastWithInt64Index * BroadcastDataLoader template partial specialization * fix compile errs in ROCms * clean ElementwiseT and InT for BroadcastKernel * default axis and clean inT * remove redundant fast divmod computation * optimize drop_nd & drop_nd_grad * optimize BroadcastDataLoader bf16 fp16 * rm InT etc. after merge develop * delete constexpr for windows ci * fix conflict * fix conflic with develop * fix conflic * new clean * clean
1 parent a53ee94 commit d611e48

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+334
-461
lines changed

paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,13 @@
1919
namespace paddle {
2020
namespace operators {
2121

22-
template <ElementwiseType ET,
23-
typename InT,
24-
typename OutT,
25-
typename Functor,
26-
int NumOuts = 1>
22+
template <typename OutT, typename Functor, int NumOuts = 1>
2723
void LaunchElementwiseCudaKernel(
2824
const KPDevice &ctx,
2925
const std::vector<const phi::DenseTensor *> &ins,
3026
std::vector<phi::DenseTensor *> *outs,
31-
int axis,
32-
Functor func) {
27+
Functor func,
28+
int axis = -1) {
3329
std::vector<const phi::DenseTensor *> pt_inputs;
3430
std::vector<phi::DenseTensor *> pt_outputs;
3531
// TODO(YuanRisheng) *_tmp for cache DenseTensor, because the temporary
@@ -53,8 +49,8 @@ void LaunchElementwiseCudaKernel(
5349
for (int i = 0; i < pt_outputs_tmp.size(); i++) {
5450
pt_outputs.push_back(pt_outputs_tmp[i].get());
5551
}
56-
phi::funcs::BroadcastKernel<ET, InT, OutT, Functor, NumOuts>(
57-
ctx, pt_inputs, &pt_outputs, axis, func);
52+
phi::funcs::BroadcastKernel<OutT, Functor, NumOuts>(
53+
ctx, pt_inputs, &pt_outputs, func, axis);
5854
}
5955

6056
} // namespace operators

paddle/fluid/operators/elementwise/elementwise_op_function.h

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ void ElementwiseComputeEx(const framework::ExecutionContext &ctx,
188188
z->mutable_data<OutType>(ctx.GetPlace());
189189
const auto &dev_ctx = ctx.template device_context<DeviceContext>();
190190
phi::funcs::ElementwiseCompute<Functor, T, OutType>(
191-
dev_ctx, *x, *y, axis, func, z);
191+
dev_ctx, *x, *y, func, z, axis);
192192
}
193193

194194
// FusedElemwiseAndAct
@@ -1596,7 +1596,7 @@ static inline std::vector<int> GetReduceDim(const framework::DDim &in,
15961596

15971597
#if defined(__NVCC__) || defined(__HIPCC__)
15981598

1599-
template <ElementwiseType ET, typename T, typename Functor>
1599+
template <typename T, typename Functor>
16001600
void GetGradXAndYOut(const phi::GPUContext &dev_ctx,
16011601
const platform::Place &place,
16021602
int axis,
@@ -1605,20 +1605,19 @@ void GetGradXAndYOut(const phi::GPUContext &dev_ctx,
16051605
phi::DenseTensor *dx,
16061606
phi::DenseTensor *dy,
16071607
Functor func) {
1608-
phi::GetGradXAndYOut<ET, T, Functor>(
1608+
phi::GetGradXAndYOut<T, Functor>(
16091609
dev_ctx, place, axis, ins, *dout, dx, dy, func);
16101610
}
16111611

1612-
template <ElementwiseType ET, typename T, typename Functor>
1612+
template <typename T, typename Functor>
16131613
void GetGradXOrYOut(const phi::GPUContext &dev_ctx,
16141614
const platform::Place &place,
16151615
int axis,
16161616
std::vector<const phi::DenseTensor *> ins,
16171617
const phi::DenseTensor *dout,
16181618
phi::DenseTensor *dxy,
16191619
Functor func) {
1620-
phi::GetGradXOrYOut<ET, T, Functor>(
1621-
dev_ctx, place, axis, ins, *dout, dxy, func);
1620+
phi::GetGradXOrYOut<T, Functor>(dev_ctx, place, axis, ins, *dout, dxy, func);
16221621
}
16231622

16241623
#endif

paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@ limitations under the License. */
2323
namespace paddle {
2424
namespace operators {
2525

26-
using ElementwiseType = phi::ElementwiseType;
27-
2826
template <typename OutT, typename Functor, int NumOuts = 1>
2927
void LaunchSameDimsElementwiseCudaKernel(
3028
const KPDevice &ctx,

paddle/fluid/operators/fused/attn_gemm.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,8 @@ class AttnMatMul {
109109
// bias_out = output + bias
110110
std::vector<const phi::DenseTensor*> ins = {output, bias};
111111
std::vector<phi::DenseTensor*> outs = {bias_out};
112-
phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
113-
dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
112+
phi::funcs::BroadcastKernel<T>(
113+
dev_ctx_, ins, &outs, phi::funcs::AddFunctor<T>());
114114
}
115115
}
116116

paddle/fluid/operators/fused/attn_gemm_int8.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,8 @@ class AttnMatmulINT8 {
8585
// bias_out = output + bias
8686
std::vector<const phi::DenseTensor*> ins = {output, bias};
8787
std::vector<phi::DenseTensor*> outs = {bias_out};
88-
phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
89-
dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
88+
phi::funcs::BroadcastKernel<T>(
89+
dev_ctx_, ins, &outs, phi::funcs::AddFunctor<T>());
9090
PADDLE_ENFORCE_EQ(cudaGetLastError(),
9191
cudaSuccess,
9292
platform::errors::Fatal(
@@ -139,8 +139,8 @@ class AttnMatmulINT8 {
139139
// bias_out = output + bias
140140
std::vector<const phi::DenseTensor*> ins = {output, bias};
141141
std::vector<phi::DenseTensor*> outs = {bias_out};
142-
phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
143-
dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
142+
phi::funcs::BroadcastKernel<T>(
143+
dev_ctx_, ins, &outs, phi::funcs::AddFunctor<T>());
144144
PADDLE_ENFORCE_EQ(cudaGetLastError(),
145145
cudaSuccess,
146146
platform::errors::Fatal(

paddle/fluid/operators/fused/fmha_ref.h

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -255,12 +255,11 @@ class FMHARef {
255255
ins.emplace_back(src_mask_tensor);
256256
outs.emplace_back(src_mask_out_tensor);
257257
int elewise_add_axis = -1;
258-
phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
259-
dev_ctx_,
260-
ins,
261-
&outs,
262-
elewise_add_axis,
263-
phi::funcs::AddFunctor<T>());
258+
phi::funcs::BroadcastKernel<T>(dev_ctx_,
259+
ins,
260+
&outs,
261+
phi::funcs::AddFunctor<T>(),
262+
elewise_add_axis);
264263

265264
phi::SoftmaxForwardCUDAKernelDriver<T>(
266265
dev_ctx_, *src_mask_out_tensor, softmax_axis, softmax_out_tensor);
@@ -432,12 +431,11 @@ class FMHARef {
432431
ins.emplace_back(src_mask_tensor);
433432
outs.emplace_back(src_mask_out_tensor);
434433
int elewise_add_axis = -1;
435-
phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
436-
dev_ctx_,
437-
ins,
438-
&outs,
439-
elewise_add_axis,
440-
phi::funcs::AddFunctor<T>());
434+
phi::funcs::BroadcastKernel<T>(dev_ctx_,
435+
ins,
436+
&outs,
437+
phi::funcs::AddFunctor<T>(),
438+
elewise_add_axis);
441439

442440
phi::SoftmaxForwardCUDAKernelDriver<T>(
443441
dev_ctx_, *src_mask_out_tensor, softmax_axis, softmax_out_tensor);

paddle/fluid/operators/fused/fused_gate_attention.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -689,13 +689,13 @@ class FMHAGateRef {
689689
std::vector<const phi::DenseTensor*> ins = {
690690
qk_out, src_mask, nonbatched_bias};
691691
std::vector<phi::DenseTensor*> outs = {qk_out};
692-
phi::funcs::BroadcastKernel<phi::ElementwiseType::kTernary, T, T>(
693-
dev_ctx_, ins, &outs, -1, TernaryAddFunctor<T>());
692+
phi::funcs::BroadcastKernel<T>(
693+
dev_ctx_, ins, &outs, TernaryAddFunctor<T>());
694694
} else {
695695
std::vector<const phi::DenseTensor*> ins = {qk_out, src_mask};
696696
std::vector<phi::DenseTensor*> outs = {qk_out};
697-
phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
698-
dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
697+
phi::funcs::BroadcastKernel<T>(
698+
dev_ctx_, ins, &outs, phi::funcs::AddFunctor<T>());
699699
}
700700
phi::SoftmaxForwardCUDAKernelDriver<T>(dev_ctx_, *qk_out, -1, softmax_out);
701701
}

paddle/fluid/operators/fused_token_prune_op.cu

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,7 @@ class FusedTokenPruneOpCUDAKernel : public framework::OpKernel<T> {
141141
ins.emplace_back(attn);
142142
ins.emplace_back(mask);
143143
outs.emplace_back(&attn_tmp);
144-
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
145-
dev_ctx, ins, &outs, -1, AttnMaskFunctor<T>());
144+
LaunchElementwiseCudaKernel<T>(dev_ctx, ins, &outs, AttnMaskFunctor<T>());
146145

147146
// 2. Reduce sum
148147
const std::vector<int64_t> reduce_dims{1, 2};

paddle/fluid/operators/reduce_ops/reduce_op.h

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -834,12 +834,11 @@ class ReduceCudaGradKernel : public framework::OpKernel<T> {
834834
}
835835

836836
using MPType = typename kps::details::MPTypeTrait<T>::Type;
837-
phi::ReduceGrad<T, TransformOp<T, MPType>>(
838-
dev_ctx,
839-
pt_d_out.get(),
840-
pt_d_x.get(),
841-
pt_out_dtype,
842-
TransformOp<T, MPType>(reduce_num));
837+
phi::ReduceGrad<TransformOp<T, MPType>>(dev_ctx,
838+
pt_d_out.get(),
839+
pt_d_x.get(),
840+
pt_out_dtype,
841+
TransformOp<T, MPType>(reduce_num));
843842
}
844843
};
845844

paddle/phi/kernels/cpu/bitwise_kernel.cc

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,15 @@ limitations under the License. */
2424

2525
namespace phi {
2626

27-
#define DEFINE_BITWISE_KERNEL(op_type) \
28-
template <typename T, typename Context> \
29-
void Bitwise##op_type##Kernel(const Context& dev_ctx, \
30-
const DenseTensor& x, \
31-
const DenseTensor& y, \
32-
DenseTensor* out) { \
33-
funcs::Bitwise##op_type##Functor<T> func; \
34-
funcs::ElementwiseCompute<funcs::Bitwise##op_type##Functor<T>, T, T>( \
35-
dev_ctx, x, y, -1, func, out); \
27+
#define DEFINE_BITWISE_KERNEL(op_type) \
28+
template <typename T, typename Context> \
29+
void Bitwise##op_type##Kernel(const Context& dev_ctx, \
30+
const DenseTensor& x, \
31+
const DenseTensor& y, \
32+
DenseTensor* out) { \
33+
funcs::Bitwise##op_type##Functor<T> func; \
34+
funcs::ElementwiseCompute<funcs::Bitwise##op_type##Functor<T>, T>( \
35+
dev_ctx, x, y, func, out); \
3636
}
3737

3838
DEFINE_BITWISE_KERNEL(And)

0 commit comments

Comments
 (0)