
Commit 101c9bb

Optimization for DropoutNd on Host side (#51934)
* first commit
* fix bugs
* remove_useless sync
1 parent f8a8dd5 commit 101c9bb

File tree: 3 files changed (+55 / -48 lines)

paddle/phi/kernels/funcs/dropout_impl.cu.h

Lines changed: 40 additions & 31 deletions

@@ -41,7 +41,7 @@ namespace funcs {
 template <typename T>
 struct DstFunctor {
   using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
-  MT factor;
+
   HOSTDEVICE inline DstFunctor(const float retain_prob,
                                const bool is_upscale_in_train,
                                const int64_t num)
@@ -67,17 +67,12 @@ struct DstFunctor {
   const float retain_prob_;
   const bool is_upscale_in_train_;
   const int64_t num_;
+  MT factor;
 };
 
 template <typename T>
 struct MaskFunctor {
-  const float retain_prob_;
-  using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
-  MT factor;
-  HOSTDEVICE inline MaskFunctor(const float retain_prob)
-      : retain_prob_(retain_prob) {
-    factor = static_cast<MT>(1.0f / retain_prob_);
-  }
+  explicit MaskFunctor(const float retain_prob) : retain_prob_(retain_prob) {}
 
   HOSTDEVICE inline void operator()(T* dst, const float* rand, int num) const {
     static constexpr int kCount =
@@ -88,14 +83,14 @@ struct MaskFunctor {
       dst[i] = rand[i] < retain_prob_ ? static_cast<T>(1) : static_cast<T>(0);
     }
   }
+
+ private:
+  float retain_prob_;
 };
 
 template <typename T>
 struct DstMaskFunctor {
-  const float retain_prob_;
-  const bool is_upscale_in_train_;
   using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
-  MT factor;
   HOSTDEVICE inline DstMaskFunctor(const float retain_prob,
                                    const bool is_upscale_in_train)
       : retain_prob_(retain_prob), is_upscale_in_train_(is_upscale_in_train) {
@@ -122,6 +117,11 @@ struct DstMaskFunctor {
       }
     }
   }
+
+ private:
+  MT factor;
+  float retain_prob_;
+  bool is_upscale_in_train_;
 };
 
 template <typename T>
@@ -172,9 +172,6 @@ __global__ void VectorizedRandomGenerator(const size_t n,
         &mask_result[0], &dst_mask[kCount], Cast());
     kps::WriteData<uint8_t, kCount, 1, false>(
         mask + fix, &mask_result[0], deal_size);
-    if (fix > idx * kCount + 1) {
-      __syncthreads();
-    }
   }
   int remainder = n - fix;
   if (remainder > 0) {
@@ -190,7 +187,6 @@ __global__ void VectorizedRandomGenerator(const size_t n,
         &mask_result[0], &dst_mask[kCount], Cast());
     kps::WriteData<uint8_t, kCount, 1, true>(
         mask + fix, &mask_result[0], remainder);
-    __syncthreads();
   }
 }
 
@@ -204,11 +200,17 @@ __global__ void DropOutNdForwardKernel(
     uint64_t increment,
     size_t main_offset,
     DstFunctor<T> dst_functor,
+    MaskFunctor<T> mask_functor,
     T* y,
     int64_t N,
-    kps::details::BroadcastConfig broadcast_config) {
+    kps::details::BroadcastConfig broadcast_config,
+    const uint64_t* seed_ptr) {
   // Vectorized Generate Mask
   // kCount is 4 for curand_uniform4 is used
+  if (seed_ptr) {
+    seed = seed_ptr[0];
+  }
+
   constexpr int kCount = phi::funcs::uniform_distribution<float>::kReturnsCount;
   size_t idx = static_cast<size_t>(BLOCK_ID_X * BLOCK_NUM_X);
   size_t stride = BLOCK_NUM_X * GRID_NUM_X * kCount;
@@ -229,8 +231,6 @@ __global__ void DropOutNdForwardKernel(
   int deal_size = BLOCK_NUM_X * kCount;
 
   size_t fix = idx * kCount;
-
-  auto mask_functor = MaskFunctor<T>(1.0f - dropout_prob);
   for (; fix < main_offset; fix += stride) {
     kps::ReadData<T, kCount, 1, false>(&dst_mask[0], src + fix, deal_size);
     kps::ElementwiseRandom<SType, float, kCount, Rand>(
@@ -244,9 +244,6 @@ __global__ void DropOutNdForwardKernel(
         &mask_result[0], &dst_mask[0], Cast());
     kps::WriteData<uint8_t, kCount, 1, false>(
         mask + fix, &mask_result[0], deal_size);
-    if (fix > idx * kCount + 1) {
-      __syncthreads();
-    }
   }
   int remainder = n - fix;
   if (remainder > 0) {
@@ -261,7 +258,6 @@ __global__ void DropOutNdForwardKernel(
         &mask_result[0], &dst_mask[0], Cast());
     kps::WriteData<uint8_t, kCount, 1, true>(
         mask + fix, &mask_result[0], remainder);
-    __syncthreads();
   }
   // Broadcast mask data and do elementwise operaiton with DstFunctor
   CUDA_KERNEL_LOOP(i, N) {
@@ -347,24 +343,32 @@ void DropoutFwGPUKernelDriver(
 
     auto offset =
         ((x_numel - 1) / (grid_size * block_size * kVecSize) + 1) * kVecSize;
-    GetSeedDataAndIncrement(
-        dev_ctx, seed, is_fix_seed, seed_val, offset, &seed_data, &increment);
     size_t main_offset =
         size / (block_size * kVecSize) * (block_size * kVecSize);
 
     if (is_dropout_nd) {
       auto dst_functor =
          DstFunctor<T>(1.0f - dropout_prob, upscale_in_train, x_numel);
 
-      auto input_x_dims = x.dims();
-      auto mask_dims = mask->dims();
-      std::vector<int64_t> out_dims = phi::vectorize<int64_t>(input_x_dims);
-      std::vector<int64_t> in_dims = phi::vectorize<int64_t>(mask_dims);
-      reverse(out_dims.begin(), out_dims.end());
-      reverse(in_dims.begin(), in_dims.end());
+      std::vector<int64_t> out_dims = phi::vectorize<int64_t>(x.dims());
+      std::vector<int64_t> in_dims = phi::vectorize<int64_t>(mask->dims());
+      std::reverse(out_dims.begin(), out_dims.end());
+      std::reverse(in_dims.begin(), in_dims.end());
       kps::details::BroadcastConfig broadcast_config(
           out_dims, in_dims, x.dims().size());
 
+      auto mask_functor = MaskFunctor<T>(1.0f - dropout_prob);
+      bool copy_in_kernel = GetSeedDataAndIncrement(dev_ctx,
+                                                    seed,
+                                                    is_fix_seed,
+                                                    seed_val,
+                                                    offset,
+                                                    &seed_data,
+                                                    &increment,
+                                                    true);
+      const uint64_t* seed_ptr =
+          copy_in_kernel ? seed->data<uint64_t>() : nullptr;
+
       DropOutNdForwardKernel<T>
          <<<grid_size, block_size, 0, stream>>>(size,
                                                 seed_data,
@@ -374,10 +378,15 @@ void DropoutFwGPUKernelDriver(
                                                 increment,
                                                 main_offset,
                                                 dst_functor,
+                                                mask_functor,
                                                 y_data,
                                                 y->numel(),
-                                                broadcast_config);
+                                                broadcast_config,
+                                                seed_ptr);
     } else {
+      bool copy_in_kernel = GetSeedDataAndIncrement(
+          dev_ctx, seed, is_fix_seed, seed_val, offset, &seed_data, &increment);
+
 #define PD_DROPOUT_KERNEL_NAME VectorizedRandomGenerator<T>
       PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(!is_fix_seed,
                                          PD_DROPOUT_KERNEL_NAME,
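
Note on the pattern: after this change, DropOutNdForwardKernel receives everything it needs as kernel arguments. The MaskFunctor is built once on the host and passed by value (its factor member and HOSTDEVICE constructor disappear, and it is no longer re-created inside the kernel), and an optional device-side seed_ptr lets the kernel read a user-supplied seed on the GPU. The snippet below is a minimal, self-contained CUDA sketch of that pattern, not Paddle code; ScaleMask, MaskKernel, and the toy hash are hypothetical stand-ins.

#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// Trivially copyable functor: built once on the host and passed by value to
// the kernel (mirrors how mask_functor is handed to DropOutNdForwardKernel
// instead of being constructed by every thread).
struct ScaleMask {
  explicit ScaleMask(float retain_prob) : retain_prob_(retain_prob) {}
  __device__ float operator()(float r) const {
    return r < retain_prob_ ? 1.0f : 0.0f;
  }

 private:
  float retain_prob_;
};

// The seed arrives either as a plain value or as a device pointer; when the
// pointer is set, the kernel reads it on the GPU, so the host never has to
// synchronize just to learn the seed value.
__global__ void MaskKernel(float* mask, int n, uint64_t seed,
                           const uint64_t* seed_ptr, ScaleMask functor) {
  if (seed_ptr) {
    seed = seed_ptr[0];  // device-side read replaces a blocking host copy
  }
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // Toy stateless hash standing in for curand: maps (seed, i) to [0, 1).
    uint64_t h = (seed ^ static_cast<uint64_t>(i)) * 0x9E3779B97F4A7C15ull;
    float r = static_cast<float>(h >> 40) / static_cast<float>(1 << 24);
    mask[i] = functor(r);  // each thread writes a disjoint element
  }
}

int main() {
  const int n = 8;
  float* mask = nullptr;
  uint64_t* d_seed = nullptr;
  cudaMallocManaged(&mask, n * sizeof(float));
  cudaMallocManaged(&d_seed, sizeof(uint64_t));
  *d_seed = 42;  // the seed stays on the device side

  MaskKernel<<<1, 32>>>(mask, n, /*seed=*/0, d_seed, ScaleMask(0.5f));
  cudaDeviceSynchronize();

  for (int i = 0; i < n; ++i) printf("%.0f ", mask[i]);
  printf("\n");
  cudaFree(mask);
  cudaFree(d_seed);
  return 0;
}

The functor only has to be trivially copyable to travel as a kernel argument, which is consistent with the simplification of MaskFunctor above. The deleted __syncthreads() calls guarded only per-thread writes to disjoint global-memory slices, which matches the "remove_useless sync" item in the commit message.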

paddle/phi/kernels/funcs/dropout_impl_util.h

Lines changed: 11 additions & 5 deletions

@@ -22,27 +22,33 @@ limitations under the License. */
 namespace phi {
 namespace funcs {
 
-inline void GetSeedDataAndIncrement(const phi::GPUContext& dev_ctx,
+inline bool GetSeedDataAndIncrement(const phi::GPUContext& dev_ctx,
                                     const phi::DenseTensor* seed,
                                     const bool is_fix_seed,
                                     const int seed_val,
                                     const int offset,
                                     uint64_t* seed_data,
-                                    uint64_t* increment) {
+                                    uint64_t* increment,
+                                    bool use_copy = true) {
   auto gen_cuda = dev_ctx.GetGenerator();
 
   if (seed) {
-    phi::DenseTensor seed_cpu_tensor;
-    phi::Copy(dev_ctx, *seed, phi::CPUPlace(), true, &seed_cpu_tensor);
-    *seed_data = static_cast<uint64_t>(seed_cpu_tensor.data<int>()[0]);
+    if (use_copy) {
+      phi::DenseTensor seed_cpu_tensor;
+      phi::Copy(dev_ctx, *seed, phi::CPUPlace(), true, &seed_cpu_tensor);
+      *seed_data = static_cast<uint64_t>(seed_cpu_tensor.data<int>()[0]);
+    }
     *increment = offset;
+    return true;
   } else if (!is_fix_seed) {
     auto seed_offset = gen_cuda->IncrementOffset(offset);
     *seed_data = seed_offset.first;
     *increment = seed_offset.second;
+    return false;
  } else {
     *seed_data = seed_val;
     *increment = offset;
+    return false;
   }
 }
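
Note on the pattern: the helper now makes the device-to-host seed copy optional. phi::Copy with the blocking flag set forces the CPU to wait on the GPU before the dropout kernel can even be launched, whereas the new bool return value tells the caller that a seed tensor exists and can be read inside the kernel via seed_ptr instead. Below is a generic CUDA sketch contrasting the two approaches, not Paddle's API; UseSeed and the buffer names are made up for illustration.

#include <cstdint>
#include <cuda_runtime.h>

__global__ void UseSeed(const uint64_t* seed_ptr, uint64_t* out) {
  // Reading the seed here keeps the whole launch asynchronous from the host.
  *out = seed_ptr[0] * 2ull;
}

int main() {
  uint64_t host_seed = 1234;
  uint64_t *d_seed = nullptr, *d_out = nullptr;
  cudaMalloc(&d_seed, sizeof(uint64_t));
  cudaMalloc(&d_out, sizeof(uint64_t));
  cudaMemcpy(d_seed, &host_seed, sizeof(uint64_t), cudaMemcpyHostToDevice);

  // (a) Host-side read: a device-to-host cudaMemcpy blocks the CPU until the
  // copy (and prior work on the stream) has finished, only so the value can
  // be forwarded back to the GPU as a kernel argument.
  uint64_t seed_on_host = 0;
  cudaMemcpy(&seed_on_host, d_seed, sizeof(uint64_t), cudaMemcpyDeviceToHost);

  // (b) Device-side read: pass the pointer and let the kernel fetch the
  // value; no host/device round trip is needed before the launch.
  UseSeed<<<1, 1>>>(d_seed, d_out);
  cudaDeviceSynchronize();

  cudaFree(d_seed);
  cudaFree(d_out);
  return 0;
}

Keeping the seed on the device removes one blocking round trip per launch, which appears to be the host-side cost the commit title targets.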

paddle/phi/kernels/fusion/gpu/fused_linear_param_grad_add_kernel.cu

Lines changed: 4 additions & 12 deletions

@@ -67,18 +67,10 @@ void FusedLinearParamGradAddImpl(const Context &ctx,
     dout_copy.Resize({M, N});
     if (kIsMultiPrecision) {
       *dbias_out = phi::Sum<T, Context>(
-          ctx,
-          dout_copy,
-          {0},
-          paddle::experimental::CppTypeToDataType<MT>::Type(),
-          false);
+          ctx, dout_copy, {0}, phi::CppTypeToDataType<MT>::Type(), false);
     } else {
       *dbias_out = phi::Sum<T, Context>(
-          ctx,
-          dout_copy,
-          {0},
-          paddle::experimental::CppTypeToDataType<T>::Type(),
-          false);
+          ctx, dout_copy, {0}, phi::CppTypeToDataType<T>::Type(), false);
     }
   }
 
@@ -141,12 +133,12 @@ void FusedLinearParamGradAdd(const Context &ctx,
    if (multi_precision) {
      PADDLE_ENFORCE_EQ(
          dweight_out->dtype(),
-          paddle::experimental::CppTypeToDataType<MT>::Type(),
+          phi::CppTypeToDataType<MT>::Type(),
          phi::errors::InvalidArgument("Invaid data type error."));
    } else {
      PADDLE_ENFORCE_EQ(
          dweight_out->dtype(),
-          paddle::experimental::CppTypeToDataType<T>::Type(),
+          phi::CppTypeToDataType<T>::Type(),
          phi::errors::InvalidArgument("Invaid data type error."));
    }
  } else {
