@@ -41,7 +41,7 @@ namespace funcs {
 template <typename T>
 struct DstFunctor {
   using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
-  MT factor;
+
   HOSTDEVICE inline DstFunctor(const float retain_prob,
                                const bool is_upscale_in_train,
                                const int64_t num)
@@ -67,17 +67,12 @@ struct DstFunctor {
   const float retain_prob_;
   const bool is_upscale_in_train_;
   const int64_t num_;
+  MT factor;
 };
 
 template <typename T>
 struct MaskFunctor {
-  const float retain_prob_;
-  using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
-  MT factor;
-  HOSTDEVICE inline MaskFunctor(const float retain_prob)
-      : retain_prob_(retain_prob) {
-    factor = static_cast<MT>(1.0f / retain_prob_);
-  }
+  explicit MaskFunctor(const float retain_prob) : retain_prob_(retain_prob) {}
 
   HOSTDEVICE inline void operator()(T* dst, const float* rand, int num) const {
     static constexpr int kCount =
@@ -88,14 +83,14 @@ struct MaskFunctor {
       dst[i] = rand[i] < retain_prob_ ? static_cast<T>(1) : static_cast<T>(0);
     }
   }
+
+ private:
+  float retain_prob_;
 };
 
 template <typename T>
 struct DstMaskFunctor {
-  const float retain_prob_;
-  const bool is_upscale_in_train_;
   using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
-  MT factor;
   HOSTDEVICE inline DstMaskFunctor(const float retain_prob,
                                    const bool is_upscale_in_train)
       : retain_prob_(retain_prob), is_upscale_in_train_(is_upscale_in_train) {
@@ -122,6 +117,11 @@ struct DstMaskFunctor {
       }
     }
   }
+
+ private:
+  MT factor;
+  float retain_prob_;
+  bool is_upscale_in_train_;
 };
 
 template <typename T>
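
Taken together, the struct hunks above move the derived `factor` out of the public const members into a private section (computed once in the constructor), and strip `MaskFunctor` down to a pure thresholding functor that no longer carries a scale factor at all. Below is a minimal sketch of the resulting `MaskFunctor` shape, written as plain host C++ outside the phi/kps framework; the name and the `main` harness are illustrative, not part of the patch:

```cpp
#include <cstdio>

// Sketch of the simplified mask functor: threshold pre-generated uniform
// randoms into a 0/1 mask. The upscale factor now lives only in DstFunctor.
template <typename T>
struct SimpleMaskFunctor {
  explicit SimpleMaskFunctor(const float retain_prob)
      : retain_prob_(retain_prob) {}

  // rand[i] is expected to be uniform in [0, 1); keep with prob retain_prob_.
  inline void operator()(T* dst, const float* rand, int num) const {
    for (int i = 0; i < num; ++i) {
      dst[i] = rand[i] < retain_prob_ ? static_cast<T>(1) : static_cast<T>(0);
    }
  }

 private:
  float retain_prob_;  // stored by value; no derived state to precompute
};

int main() {
  float rand[4] = {0.1f, 0.9f, 0.4f, 0.7f};
  float mask[4];
  SimpleMaskFunctor<float>(0.5f)(mask, rand, 4);
  for (float m : mask) printf("%g ", m);  // prints: 1 0 1 0
}
```

Note that the functors stay trivially copyable after the change, which they must be, since both are passed to the `__global__` kernel by value.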
@@ -172,9 +172,6 @@ __global__ void VectorizedRandomGenerator(const size_t n,
         &mask_result[0], &dst_mask[kCount], Cast());
     kps::WriteData<uint8_t, kCount, 1, false>(
         mask + fix, &mask_result[0], deal_size);
-    if (fix > idx * kCount + 1) {
-      __syncthreads();
-    }
   }
   int remainder = n - fix;
   if (remainder > 0) {
@@ -190,7 +187,6 @@ __global__ void VectorizedRandomGenerator(const size_t n,
         &mask_result[0], &dst_mask[kCount], Cast());
     kps::WriteData<uint8_t, kCount, 1, true>(
         mask + fix, &mask_result[0], remainder);
-    __syncthreads();
   }
 }
 
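The deleted `__syncthreads()` calls look like pure overhead: every thread writes only its own slice of `dst` and `mask` through `kps::WriteData`, and nothing in the kernel reads those values back, so there is no cross-thread dependency for a barrier to order. A generic CUDA sketch of this disjoint-write pattern (not the Paddle kernel; names are illustrative) that needs no block synchronization:

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Each thread writes exactly one element it alone owns, so no barrier is
// needed: a block-wide __syncthreads() here would only stall the warps.
__global__ void DisjointWrites(float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = 2.0f * i;  // no shared-memory exchange, no read-after-write
  }
}

int main() {
  const int n = 8;
  float* d = nullptr;
  cudaMalloc(&d, n * sizeof(float));
  DisjointWrites<<<1, 32>>>(d, n);
  float h[n];
  cudaMemcpy(h, d, sizeof(h), cudaMemcpyDeviceToHost);
  for (float v : h) printf("%g ", v);  // prints: 0 2 4 6 8 10 12 14
  cudaFree(d);
}
```

Barriers only pay for themselves when threads exchange data within a block; each loop iteration here is already independent, so removing them saves a block-wide stall per iteration.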
@@ -204,11 +200,17 @@ __global__ void DropOutNdForwardKernel(
     uint64_t increment,
     size_t main_offset,
     DstFunctor<T> dst_functor,
+    MaskFunctor<T> mask_functor,
     T* y,
     int64_t N,
-    kps::details::BroadcastConfig broadcast_config) {
+    kps::details::BroadcastConfig broadcast_config,
+    const uint64_t* seed_ptr) {
   // Vectorized Generate Mask
   // kCount is 4 for curand_uniform4 is used
+  if (seed_ptr) {
+    seed = seed_ptr[0];
+  }
+
   constexpr int kCount = phi::funcs::uniform_distribution<float>::kReturnsCount;
   size_t idx = static_cast<size_t>(BLOCK_ID_X * BLOCK_NUM_X);
   size_t stride = BLOCK_NUM_X * GRID_NUM_X * kCount;
@@ -229,8 +231,6 @@ __global__ void DropOutNdForwardKernel(
   int deal_size = BLOCK_NUM_X * kCount;
 
   size_t fix = idx * kCount;
-
-  auto mask_functor = MaskFunctor<T>(1.0f - dropout_prob);
   for (; fix < main_offset; fix += stride) {
     kps::ReadData<T, kCount, 1, false>(&dst_mask[0], src + fix, deal_size);
     kps::ElementwiseRandom<SType, float, kCount, Rand>(
@@ -244,9 +244,6 @@ __global__ void DropOutNdForwardKernel(
         &mask_result[0], &dst_mask[0], Cast());
     kps::WriteData<uint8_t, kCount, 1, false>(
         mask + fix, &mask_result[0], deal_size);
-    if (fix > idx * kCount + 1) {
-      __syncthreads();
-    }
   }
   int remainder = n - fix;
   if (remainder > 0) {
@@ -261,7 +258,6 @@ __global__ void DropOutNdForwardKernel(
         &mask_result[0], &dst_mask[0], Cast());
     kps::WriteData<uint8_t, kCount, 1, true>(
         mask + fix, &mask_result[0], remainder);
-    __syncthreads();
   }
   // Broadcast mask data and do elementwise operation with DstFunctor
   CUDA_KERNEL_LOOP(i, N) {
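
`DropOutNdForwardKernel` now also receives the `MaskFunctor` by value instead of constructing it per launch, and takes an optional device pointer to the seed: when `seed_ptr` is non-null (the `copy_in_kernel` case set up in the driver below), the seed is read on the GPU rather than being copied back to the host, presumably to avoid the blocking device-to-host transfer before launch. A runnable sketch of just that in-kernel seed read, with hypothetical names:

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Sketch: a host-side fallback seed is always passed by value; if the seed
// actually lives in device memory, the kernel dereferences it directly.
__global__ void UseSeed(uint64_t seed, const uint64_t* seed_ptr) {
  if (seed_ptr) {        // non-null => seed tensor lives in device memory
    seed = seed_ptr[0];  // overrides the host-provided fallback value
  }
  if (threadIdx.x == 0) {
    printf("effective seed: %llu\n", static_cast<unsigned long long>(seed));
  }
}

int main() {
  uint64_t host_seed = 7;  // fallback, as when a fixed seed is supplied
  uint64_t* dev_seed = nullptr;
  cudaMalloc(&dev_seed, sizeof(uint64_t));
  uint64_t v = 42;
  cudaMemcpy(dev_seed, &v, sizeof(v), cudaMemcpyHostToDevice);

  UseSeed<<<1, 32>>>(host_seed, dev_seed);  // prints 42 (device seed wins)
  UseSeed<<<1, 32>>>(host_seed, nullptr);   // prints 7  (fallback path)
  cudaDeviceSynchronize();
  cudaFree(dev_seed);
}
```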
@@ -347,24 +343,32 @@ void DropoutFwGPUKernelDriver(
 
   auto offset =
       ((x_numel - 1) / (grid_size * block_size * kVecSize) + 1) * kVecSize;
-  GetSeedDataAndIncrement(
-      dev_ctx, seed, is_fix_seed, seed_val, offset, &seed_data, &increment);
   size_t main_offset =
       size / (block_size * kVecSize) * (block_size * kVecSize);
 
   if (is_dropout_nd) {
     auto dst_functor =
         DstFunctor<T>(1.0f - dropout_prob, upscale_in_train, x_numel);
 
-    auto input_x_dims = x.dims();
-    auto mask_dims = mask->dims();
-    std::vector<int64_t> out_dims = phi::vectorize<int64_t>(input_x_dims);
-    std::vector<int64_t> in_dims = phi::vectorize<int64_t>(mask_dims);
-    reverse(out_dims.begin(), out_dims.end());
-    reverse(in_dims.begin(), in_dims.end());
+    std::vector<int64_t> out_dims = phi::vectorize<int64_t>(x.dims());
+    std::vector<int64_t> in_dims = phi::vectorize<int64_t>(mask->dims());
+    std::reverse(out_dims.begin(), out_dims.end());
+    std::reverse(in_dims.begin(), in_dims.end());
     kps::details::BroadcastConfig broadcast_config(
         out_dims, in_dims, x.dims().size());
 
+    auto mask_functor = MaskFunctor<T>(1.0f - dropout_prob);
+    bool copy_in_kernel = GetSeedDataAndIncrement(dev_ctx,
+                                                  seed,
+                                                  is_fix_seed,
+                                                  seed_val,
+                                                  offset,
+                                                  &seed_data,
+                                                  &increment,
+                                                  true);
+    const uint64_t* seed_ptr =
+        copy_in_kernel ? seed->data<uint64_t>() : nullptr;
+
     DropOutNdForwardKernel<T>
         <<<grid_size, block_size, 0, stream>>>(size,
                                                seed_data,
@@ -374,10 +378,15 @@ void DropoutFwGPUKernelDriver(
                                                increment,
                                                main_offset,
                                                dst_functor,
+                                               mask_functor,
                                                y_data,
                                                y->numel(),
-                                               broadcast_config);
+                                               broadcast_config,
+                                               seed_ptr);
   } else {
+    bool copy_in_kernel = GetSeedDataAndIncrement(
+        dev_ctx, seed, is_fix_seed, seed_val, offset, &seed_data, &increment);
+
 #define PD_DROPOUT_KERNEL_NAME VectorizedRandomGenerator<T>
     PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(!is_fix_seed,
                                        PD_DROPOUT_KERNEL_NAME,