paddle/phi/kernels/funcs/dropout_impl.cu.h (71 changes: 40 additions & 31 deletions)
@@ -41,7 +41,7 @@ namespace funcs {
template <typename T>
struct DstFunctor {
using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
MT factor;

HOSTDEVICE inline DstFunctor(const float retain_prob,
const bool is_upscale_in_train,
const int64_t num)
@@ -67,17 +67,12 @@ struct DstFunctor {
const float retain_prob_;
const bool is_upscale_in_train_;
const int64_t num_;
MT factor;
};

template <typename T>
struct MaskFunctor {
const float retain_prob_;
using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
MT factor;
HOSTDEVICE inline MaskFunctor(const float retain_prob)
: retain_prob_(retain_prob) {
factor = static_cast<MT>(1.0f / retain_prob_);
}
explicit MaskFunctor(const float retain_prob) : retain_prob_(retain_prob) {}

HOSTDEVICE inline void operator()(T* dst, const float* rand, int num) const {
static constexpr int kCount =
@@ -88,14 +83,14 @@ struct MaskFunctor {
dst[i] = rand[i] < retain_prob_ ? static_cast<T>(1) : static_cast<T>(0);
}
}

private:
float retain_prob_;
};

template <typename T>
struct DstMaskFunctor {
const float retain_prob_;
const bool is_upscale_in_train_;
using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
MT factor;
HOSTDEVICE inline DstMaskFunctor(const float retain_prob,
const bool is_upscale_in_train)
: retain_prob_(retain_prob), is_upscale_in_train_(is_upscale_in_train) {
@@ -122,6 +117,11 @@ struct DstMaskFunctor {
}
}
}

private:
MT factor;
float retain_prob_;
bool is_upscale_in_train_;
};
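Among these functor changes, MaskFunctor no longer precomputes a scaling factor at all: it keeps only retain_prob_, now a private, non-const member set by the constructor, while the upscale factor stays with DstFunctor and DstMaskFunctor. Below is a minimal sketch of MaskFunctor after the change (not part of the diff; it assumes the surrounding phi/kps definitions such as HOSTDEVICE and phi::funcs::uniform_distribution are in scope):

template <typename T>
struct MaskFunctor {
  explicit MaskFunctor(const float retain_prob) : retain_prob_(retain_prob) {}

  HOSTDEVICE inline void operator()(T* dst, const float* rand, int num) const {
    static constexpr int kCount =
        phi::funcs::uniform_distribution<float>::kReturnsCount;
    // Threshold each uniform sample: keep the element (mask = 1) when the
    // sample falls below retain_prob_, drop it (mask = 0) otherwise.
#pragma unroll
    for (int i = 0; i < kCount; i++) {
      dst[i] = rand[i] < retain_prob_ ? static_cast<T>(1) : static_cast<T>(0);
    }
  }

 private:
  float retain_prob_;
};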

template <typename T>
@@ -172,9 +172,6 @@ __global__ void VectorizedRandomGenerator(const size_t n,
&mask_result[0], &dst_mask[kCount], Cast());
kps::WriteData<uint8_t, kCount, 1, false>(
mask + fix, &mask_result[0], deal_size);
if (fix > idx * kCount + 1) {
__syncthreads();
}
}
int remainder = n - fix;
if (remainder > 0) {
@@ -190,7 +187,6 @@ __global__ void VectorizedRandomGenerator(const size_t n,
&mask_result[0], &dst_mask[kCount], Cast());
kps::WriteData<uint8_t, kCount, 1, true>(
mask + fix, &mask_result[0], remainder);
__syncthreads();
}
}

@@ -204,11 +200,17 @@ __global__ void DropOutNdForwardKernel(
uint64_t increment,
size_t main_offset,
DstFunctor<T> dst_functor,
MaskFunctor<T> mask_functor,
T* y,
int64_t N,
kps::details::BroadcastConfig broadcast_config) {
kps::details::BroadcastConfig broadcast_config,
const uint64_t* seed_ptr) {
// Vectorized Generate Mask
// kCount is 4 because curand_uniform4 is used
if (seed_ptr) {
seed = seed_ptr[0];
}

constexpr int kCount = phi::funcs::uniform_distribution<float>::kReturnsCount;
size_t idx = static_cast<size_t>(BLOCK_ID_X * BLOCK_NUM_X);
size_t stride = BLOCK_NUM_X * GRID_NUM_X * kCount;
@@ -229,8 +231,6 @@ __global__ void DropOutNdForwardKernel(
int deal_size = BLOCK_NUM_X * kCount;

size_t fix = idx * kCount;

auto mask_functor = MaskFunctor<T>(1.0f - dropout_prob);
for (; fix < main_offset; fix += stride) {
kps::ReadData<T, kCount, 1, false>(&dst_mask[0], src + fix, deal_size);
kps::ElementwiseRandom<SType, float, kCount, Rand>(
@@ -244,9 +244,6 @@ __global__ void DropOutNdForwardKernel(
&mask_result[0], &dst_mask[0], Cast());
kps::WriteData<uint8_t, kCount, 1, false>(
mask + fix, &mask_result[0], deal_size);
if (fix > idx * kCount + 1) {
__syncthreads();
}
}
int remainder = n - fix;
if (remainder > 0) {
@@ -261,7 +258,6 @@ __global__ void DropOutNdForwardKernel(
&mask_result[0], &dst_mask[0], Cast());
kps::WriteData<uint8_t, kCount, 1, true>(
mask + fix, &mask_result[0], remainder);
__syncthreads();
}
// Broadcast mask data and do elementwise operation with DstFunctor
CUDA_KERNEL_LOOP(i, N) {
@@ -347,24 +343,32 @@ void DropoutFwGPUKernelDriver(

auto offset =
((x_numel - 1) / (grid_size * block_size * kVecSize) + 1) * kVecSize;
GetSeedDataAndIncrement(
dev_ctx, seed, is_fix_seed, seed_val, offset, &seed_data, &increment);
size_t main_offset =
size / (block_size * kVecSize) * (block_size * kVecSize);

if (is_dropout_nd) {
auto dst_functor =
DstFunctor<T>(1.0f - dropout_prob, upscale_in_train, x_numel);

auto input_x_dims = x.dims();
auto mask_dims = mask->dims();
std::vector<int64_t> out_dims = phi::vectorize<int64_t>(input_x_dims);
std::vector<int64_t> in_dims = phi::vectorize<int64_t>(mask_dims);
reverse(out_dims.begin(), out_dims.end());
reverse(in_dims.begin(), in_dims.end());
std::vector<int64_t> out_dims = phi::vectorize<int64_t>(x.dims());
std::vector<int64_t> in_dims = phi::vectorize<int64_t>(mask->dims());
std::reverse(out_dims.begin(), out_dims.end());
std::reverse(in_dims.begin(), in_dims.end());
kps::details::BroadcastConfig broadcast_config(
out_dims, in_dims, x.dims().size());

auto mask_functor = MaskFunctor<T>(1.0f - dropout_prob);
bool copy_in_kernel = GetSeedDataAndIncrement(dev_ctx,
seed,
is_fix_seed,
seed_val,
offset,
&seed_data,
&increment,
true);
const uint64_t* seed_ptr =
copy_in_kernel ? seed->data<uint64_t>() : nullptr;

DropOutNdForwardKernel<T>
<<<grid_size, block_size, 0, stream>>>(size,
seed_data,
@@ -374,10 +378,15 @@ void DropoutFwGPUKernelDriver(
increment,
main_offset,
dst_functor,
mask_functor,
y_data,
y->numel(),
broadcast_config);
broadcast_config,
seed_ptr);
} else {
bool copy_in_kernel = GetSeedDataAndIncrement(
dev_ctx, seed, is_fix_seed, seed_val, offset, &seed_data, &increment);

#define PD_DROPOUT_KERNEL_NAME VectorizedRandomGenerator<T>
PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(!is_fix_seed,
PD_DROPOUT_KERNEL_NAME,
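Net effect of the dropout_impl.cu.h changes for the ND path: the kernel can now pick up a user-supplied seed straight from device memory. The driver asks GetSeedDataAndIncrement whether a seed tensor was provided, forwards the tensor's device pointer as seed_ptr, and the kernel overrides its host-side seed argument with seed_ptr[0]; presumably this keeps the randomness tied to the tensor's current device value rather than whatever was read on the host at launch-preparation time (a motivation inferred here, not stated in the diff). A trimmed, hypothetical sketch of just the kernel-side handoff:

// Hypothetical sketch; the real DropOutNdForwardKernel takes the full
// argument list shown in the diff above.
__global__ void SeedHandoffSketch(uint64_t seed, const uint64_t* seed_ptr) {
  // A non-null seed_ptr means the seed lives in device memory and takes
  // precedence over the host-side value baked into the kernel arguments.
  if (seed_ptr != nullptr) {
    seed = seed_ptr[0];
  }
  // ... (seed, increment) then feed the curand state initialization ...
}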
paddle/phi/kernels/funcs/dropout_impl_util.h (16 changes: 11 additions & 5 deletions)
@@ -22,27 +22,33 @@ limitations under the License. */
namespace phi {
namespace funcs {

inline void GetSeedDataAndIncrement(const phi::GPUContext& dev_ctx,
inline bool GetSeedDataAndIncrement(const phi::GPUContext& dev_ctx,
const phi::DenseTensor* seed,
const bool is_fix_seed,
const int seed_val,
const int offset,
uint64_t* seed_data,
uint64_t* increment) {
uint64_t* increment,
bool use_copy = true) {
auto gen_cuda = dev_ctx.GetGenerator();

if (seed) {
phi::DenseTensor seed_cpu_tensor;
phi::Copy(dev_ctx, *seed, phi::CPUPlace(), true, &seed_cpu_tensor);
*seed_data = static_cast<uint64_t>(seed_cpu_tensor.data<int>()[0]);
if (use_copy) {
phi::DenseTensor seed_cpu_tensor;
phi::Copy(dev_ctx, *seed, phi::CPUPlace(), true, &seed_cpu_tensor);
*seed_data = static_cast<uint64_t>(seed_cpu_tensor.data<int>()[0]);
}
*increment = offset;
return true;
} else if (!is_fix_seed) {
auto seed_offset = gen_cuda->IncrementOffset(offset);
*seed_data = seed_offset.first;
*increment = seed_offset.second;
return false;
} else {
*seed_data = seed_val;
*increment = offset;
return false;
}
}

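The amended contract of GetSeedDataAndIncrement: the bool return value reports whether a seed tensor was supplied (true only in the if (seed) branch), and use_copy controls whether that tensor is still synchronously copied to the CPU to fill *seed_data. A short caller-side sketch (variable names here are illustrative, not taken from the file):

uint64_t seed_data = 0;
uint64_t increment = 0;
// Returns true iff a seed tensor was provided; only then is it valid to also
// hand the kernel seed->data<uint64_t>() so it can read the seed on device.
bool seed_tensor_given = GetSeedDataAndIncrement(
    dev_ctx, seed, is_fix_seed, seed_val, offset, &seed_data, &increment,
    /*use_copy=*/true);
const uint64_t* seed_ptr =
    seed_tensor_given ? seed->data<uint64_t>() : nullptr;

The remaining hunks below belong to the fused linear parameter-gradient kernels (FusedLinearParamGradAddImpl and FusedLinearParamGradAdd), which switch from paddle::experimental::CppTypeToDataType to phi::CppTypeToDataType.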
@@ -67,18 +67,10 @@ void FusedLinearParamGradAddImpl(const Context &ctx,
dout_copy.Resize({M, N});
if (kIsMultiPrecision) {
*dbias_out = phi::Sum<T, Context>(
ctx,
dout_copy,
{0},
paddle::experimental::CppTypeToDataType<MT>::Type(),
false);
ctx, dout_copy, {0}, phi::CppTypeToDataType<MT>::Type(), false);
} else {
*dbias_out = phi::Sum<T, Context>(
ctx,
dout_copy,
{0},
paddle::experimental::CppTypeToDataType<T>::Type(),
false);
ctx, dout_copy, {0}, phi::CppTypeToDataType<T>::Type(), false);
}
}

@@ -141,12 +133,12 @@ void FusedLinearParamGradAdd(const Context &ctx,
if (multi_precision) {
PADDLE_ENFORCE_EQ(
dweight_out->dtype(),
paddle::experimental::CppTypeToDataType<MT>::Type(),
phi::CppTypeToDataType<MT>::Type(),
phi::errors::InvalidArgument("Invalid data type error."));
} else {
PADDLE_ENFORCE_EQ(
dweight_out->dtype(),
paddle::experimental::CppTypeToDataType<T>::Type(),
phi::CppTypeToDataType<T>::Type(),
phi::errors::InvalidArgument("Invalid data type error."));
}
} else {
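Aside from the namespace move from paddle::experimental to phi, the collapsed calls keep the same behavior: the bias gradient is a sum of dout over axis 0, accumulated in the multi-precision type MT when multi_precision is enabled. A hypothetical condensed sketch of that dispatch (assuming the phi::Sum call shape shown in the diff):

template <typename T, typename MT, typename Context>
void ComputeDBiasSketch(const Context &ctx,
                        const phi::DenseTensor &dout_2d,  // shape [M, N]
                        bool multi_precision,
                        phi::DenseTensor *dbias_out) {
  // Reduce over the rows (axis 0) and pick the output dtype from the
  // accumulation type when multi-precision is requested.
  auto out_dtype = multi_precision ? phi::CppTypeToDataType<MT>::Type()
                                   : phi::CppTypeToDataType<T>::Type();
  *dbias_out = phi::Sum<T, Context>(ctx, dout_2d, {0}, out_dtype, false);
}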