
[Random op] Remove FLAGS_use_curand from all random OPs' CUDA implementations #41308


Merged 1 commit on Apr 7, 2022
151 changes: 40 additions & 111 deletions paddle/fluid/operators/dropout_impl.cu.h
@@ -38,43 +38,9 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/distribution_helper.h"
#include "paddle/phi/kernels/funcs/functors.h"

DECLARE_bool(use_curand);

namespace paddle {
namespace operators {

template <typename T1, typename T2 = T1, typename OutT = T1>
struct DstMaskGenerator {
const float dropout_prob_;
const bool is_upscale_in_train_;
using MT = typename details::MPTypeTrait<T1>::Type;
MT factor;
HOSTDEVICE inline DstMaskGenerator(const float dropout_prob,
const bool is_upscale_in_train)
: dropout_prob_(dropout_prob), is_upscale_in_train_(is_upscale_in_train) {
factor = static_cast<MT>(1.0f / (1.0f - dropout_prob_));
}

HOSTDEVICE inline void operator()(OutT* dst, const T1* src_val,
const T2* rand, int num) const {
static constexpr int kCount =
phi::funcs::uniform_distribution<T2>::kReturnsCount;
// 0 ~ kCount -1 is dist , kCount ~ 2 * kCount - 1 is mask
#pragma unroll
for (int i = 0; i < kCount; i++) {
if (rand[i] < dropout_prob_) {
dst[i] = static_cast<T1>(0);
dst[i + kCount] = dst[i];
} else {
dst[i] = is_upscale_in_train_
? static_cast<T1>(static_cast<MT>(src_val[i]) * factor)
: static_cast<T1>(src_val[i]);
dst[i + kCount] = static_cast<T1>(1);
}
}
}
};

template <typename T1, typename T2 = T1, typename OutT = T1>
struct DstMaskFunctor {
const float retain_prob_;
Expand Down Expand Up @@ -113,7 +79,7 @@ __global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
const T* src, MaskType* mask, T* dst,
bool is_upscale_in_train,
uint64_t increment,
size_t main_offset, bool use_curand) {
size_t main_offset) {
size_t idx = static_cast<size_t>(BLOCK_ID_X * BLOCK_NUM_X);
static constexpr int kCount =
phi::funcs::uniform_distribution<float>::kReturnsCount;
@@ -135,76 +101,41 @@ __global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
int deal_size = BLOCK_NUM_X * kCount;

size_t fix = idx * kCount;
if (use_curand) {
auto dst_functor =
DstMaskFunctor<T, float>(1.0f - dropout_prob, is_upscale_in_train);
for (; fix < main_offset; fix += stride) {
kps::ReadData<T, kCount, 1, 1, false>(&dst_mask[0], src + fix, deal_size);
kps::ElementwiseRandom<SType, float, kCount, 1, Rand>(&rands[0], Rand(),
&state);
// dst
kps::OperatorTernary<T, float, T, DstMaskFunctor<T, float>>(
&dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
kps::WriteData<T, kCount, 1, 1, false>(dst + fix, &dst_mask[0],
deal_size);
// mask
kps::ElementwiseUnary<T, MaskType, kCount, 1, 1, Cast>(
&mask_result[0], &dst_mask[kCount], Cast());
kps::WriteData<MaskType, kCount, 1, 1, false>(mask + fix, &mask_result[0],
deal_size);
if (fix > idx * kCount + 1) {
__syncthreads();
}
}
int remainder = n - fix;
if (remainder > 0) {
kps::ReadData<T, kCount, 1, 1, true>(&dst_mask[0], src + fix, remainder);
kps::ElementwiseRandom<SType, float, kCount, 1, Rand>(&rands[0], Rand(),
&state);
// dst
kps::OperatorTernary<T, float, T, DstMaskFunctor<T, float>>(
&dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
kps::WriteData<T, kCount, 1, 1, true>(dst + fix, &dst_mask[0], remainder);
// mask
kps::ElementwiseUnary<T, MaskType, kCount, 1, 1, Cast>(
&mask_result[0], &dst_mask[kCount], Cast());
kps::WriteData<MaskType, kCount, 1, 1, true>(mask + fix, &mask_result[0],
remainder);

auto dst_functor =
DstMaskFunctor<T, float>(1.0f - dropout_prob, is_upscale_in_train);
for (; fix < main_offset; fix += stride) {
kps::ReadData<T, kCount, 1, 1, false>(&dst_mask[0], src + fix, deal_size);
kps::ElementwiseRandom<SType, float, kCount, 1, Rand>(&rands[0], Rand(),
&state);
// dst
kps::OperatorTernary<T, float, T, DstMaskFunctor<T, float>>(
&dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
kps::WriteData<T, kCount, 1, 1, false>(dst + fix, &dst_mask[0], deal_size);
// mask
kps::ElementwiseUnary<T, MaskType, kCount, 1, 1, Cast>(
&mask_result[0], &dst_mask[kCount], Cast());
kps::WriteData<MaskType, kCount, 1, 1, false>(mask + fix, &mask_result[0],
deal_size);
if (fix > idx * kCount + 1) {
__syncthreads();
}
} else {
auto dst_functor =
DstMaskGenerator<T, float>(dropout_prob, is_upscale_in_train);
for (; fix < main_offset; fix += stride) {
kps::ReadData<T, kCount, 1, 1, false>(&dst_mask[0], src + fix, deal_size);
kps::ElementwiseRandom<SType, float, kCount, 1, Rand>(&rands[0], Rand(),
&state);
// dst
kps::OperatorTernary<T, float, T, DstMaskGenerator<T, float>>(
&dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
kps::WriteData<T, kCount, 1, 1, false>(dst + fix, &dst_mask[0],
deal_size);
// mask
kps::ElementwiseUnary<T, MaskType, kCount, 1, 1, Cast>(
&mask_result[0], &dst_mask[kCount], Cast());
kps::WriteData<MaskType, kCount, 1, 1, false>(mask + fix, &mask_result[0],
deal_size);
}
int remainder = n - fix;
if (remainder > 0) {
kps::ReadData<T, kCount, 1, 1, true>(&dst_mask[0], src + fix, remainder);
kps::ElementwiseRandom<SType, float, kCount, 1, Rand>(&rands[0], Rand(),
&state);
// dst
kps::OperatorTernary<T, float, T, DstMaskGenerator<T, float>>(
&dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
kps::WriteData<T, kCount, 1, 1, true>(dst + fix, &dst_mask[0], remainder);
// mask
kps::ElementwiseUnary<T, MaskType, kCount, 1, 1, Cast>(
&mask_result[0], &dst_mask[kCount], Cast());
kps::WriteData<MaskType, kCount, 1, 1, true>(mask + fix, &mask_result[0],
remainder);
}
}
int remainder = n - fix;
if (remainder > 0) {
kps::ReadData<T, kCount, 1, 1, true>(&dst_mask[0], src + fix, remainder);
kps::ElementwiseRandom<SType, float, kCount, 1, Rand>(&rands[0], Rand(),
&state);
// dst
kps::OperatorTernary<T, float, T, DstMaskFunctor<T, float>>(
&dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
kps::WriteData<T, kCount, 1, 1, true>(dst + fix, &dst_mask[0], remainder);
// mask
kps::ElementwiseUnary<T, MaskType, kCount, 1, 1, Cast>(
&mask_result[0], &dst_mask[kCount], Cast());
kps::WriteData<MaskType, kCount, 1, 1, true>(mask + fix, &mask_result[0],
remainder);
__syncthreads();
}
}

@@ -251,13 +182,11 @@ void DropoutFwGPUKernelDriver(const phi::GPUContext& dev_ctx, bool is_test,
size_t grid_size = gpu_config.GetGridSize();
size_t block_size = gpu_config.GetBlockSize();

if (FLAGS_use_curand) {
int64_t device_id = dev_ctx.GetPlace().GetDeviceId();
const auto& prop = platform::GetDeviceProperties(device_id);
size_t max_grid_size = prop.maxThreadsPerMultiProcessor *
prop.multiProcessorCount / block_size;
grid_size = std::min(grid_size, max_grid_size);
}
int64_t device_id = dev_ctx.GetPlace().GetDeviceId();
const auto& prop = platform::GetDeviceProperties(device_id);
size_t max_grid_size = prop.maxThreadsPerMultiProcessor *
prop.multiProcessorCount / block_size;
grid_size = std::min(grid_size, max_grid_size);

auto offset =
((x_numel - 1) / (grid_size * block_size * kVecSize) + 1) * kVecSize;
Expand All @@ -268,7 +197,7 @@ void DropoutFwGPUKernelDriver(const phi::GPUContext& dev_ctx, bool is_test,

VectorizedRandomGenerator<T, uint8_t><<<grid_size, block_size, 0, stream>>>(
size, seed_data, dropout_prob, x_data, mask_data, y_data,
upscale_in_train, increment, main_offset, FLAGS_use_curand);
upscale_in_train, increment, main_offset);
} else {
if (upscale_in_train) {
// todo: can y share with data with x directly?
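With FLAGS_use_curand removed, VectorizedRandomGenerator always follows the Philox/DstMaskFunctor path above, and the grid-size cap in DropoutFwGPUKernelDriver now applies unconditionally. As a reading aid, below is a standalone, simplified sketch of that dropout pattern; it is not the Paddle kernel (no kps vectorized helpers, no main_offset/remainder split), and all names are illustrative.

#include <cstddef>
#include <cstdint>
#include <curand_kernel.h>

// Simplified sketch: each thread owns an independent Philox subsequence
// keyed by (seed, global thread id, offset) and draws four uniforms per step.
__global__ void dropout_sketch(size_t n, uint64_t seed, uint64_t offset,
                               float dropout_prob, bool upscale_in_train,
                               const float* src, uint8_t* mask, float* dst) {
  size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx, offset, &state);
  const float factor = 1.0f / (1.0f - dropout_prob);
  const size_t stride = static_cast<size_t>(gridDim.x) * blockDim.x * 4;
  for (size_t i = idx * 4; i < n; i += stride) {
    float4 r = curand_uniform4(&state);  // four uniform samples per call
    float rand[4] = {r.x, r.y, r.z, r.w};
    for (int k = 0; k < 4 && i + k < n; ++k) {
      bool keep = rand[k] >= dropout_prob;  // drop when rand < dropout_prob
      mask[i + k] = keep ? 1 : 0;
      dst[i + k] = !keep ? 0.0f
                         : (upscale_in_train ? src[i + k] * factor : src[i + k]);
    }
  }
}

Capping grid_size at roughly the number of threads the device can keep resident, as the driver code above does, simply avoids launching more blocks than can execute concurrently for this grid-stride loop.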
7 changes: 0 additions & 7 deletions paddle/fluid/operators/gaussian_random_op.cu
@@ -11,21 +11,14 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/fill_constant_op.h"

#include "paddle/phi/kernels/funcs/distribution_helper.h"
#include "paddle/phi/kernels/funcs/index_impl.cu.h"

DECLARE_bool(use_curand);

namespace paddle {
namespace operators {

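This gaussian_random hunk only drops the thrust includes and the DECLARE_bool(use_curand) declaration; the kernel body is not shown. For context, a Gaussian sample is usually obtained by rescaling a standard-normal draw z as out = mean + std * z; the functor below is an illustrative sketch with made-up names, not Paddle's implementation.

#include <cuda_runtime.h>

// Illustrative only: map z ~ N(0, 1) to N(mean, std^2).
template <typename T>
struct NormalTransformSketch {
  T mean_, std_;
  __host__ __device__ NormalTransformSketch(T mean, T std)
      : mean_(mean), std_(std) {}
  __host__ __device__ T operator()(T z) const { return mean_ + std_ * z; }
};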
54 changes: 4 additions & 50 deletions paddle/fluid/operators/uniform_random_op.h
@@ -19,11 +19,7 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#if defined(__NVCC__) || defined(__HIPCC__)
DECLARE_bool(use_curand);
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/fluid/framework/generator.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"
@@ -146,39 +142,6 @@ struct UniformGenerator {
}
};

template <typename T>
struct UniformGeneratorOffset {
T min_, max_;
unsigned int seed_;
T diag_val_;
unsigned int diag_num_;
unsigned int diag_step_;
int offset_;
__host__ __device__ UniformGeneratorOffset(T min, T max, int seed,
int diag_num, int diag_step,
T diag_val, int offset)
: min_(min),
max_(max),
seed_(seed),
diag_num_(diag_num),
diag_step_(diag_step),
diag_val_(diag_val),
offset_(offset) {}

__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::uniform_real_distribution<T> dist(min_, max_);
rng.discard(n + offset_);
T out = dist(rng);
unsigned int remainder = n % (diag_step_ + 1);
if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) {
out = diag_val_;
}
return out;
}
};

template <typename T>
void UniformRandom(const framework::ExecutionContext& context,
framework::Tensor* tensor) {
@@ -205,19 +168,10 @@ void UniformRandom(const framework::ExecutionContext& context,
int device_id = context.GetPlace().GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
if (gen_cuda->GetIsInitPy() && seed_flag) {
if (FLAGS_use_curand) {
using MT = typename details::MPTypeTrait<T>::Type;
phi::funcs::uniform_distribution<MT> dist;
phi::funcs::uniform_real_transform<MT> trans(min, max);
phi::funcs::distribution_and_transform<T>(dev_cxt, tensor, dist, trans);
} else {
auto seed_offset = gen_cuda->IncrementOffset(1);
int64_t gen_offset = size * seed_offset.second;
auto func =
UniformGeneratorOffset<T>(min, max, seed_offset.first, diag_num,
diag_step, diag_val, gen_offset);
phi::IndexKernel<T, UniformGeneratorOffset<T>>(dev_cxt, tensor, func);
}
using MT = typename details::MPTypeTrait<T>::Type;
phi::funcs::uniform_distribution<MT> dist;
phi::funcs::uniform_real_transform<MT> trans(min, max);
phi::funcs::distribution_and_transform<T>(dev_cxt, tensor, dist, trans);
} else {
auto func =
UniformGenerator<T>(min, max, seed, diag_num, diag_step, diag_val);
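With the flag gone, the initialized-generator branch of UniformRandom always goes through phi::funcs::distribution_and_transform; the thrust-based UniformGeneratorOffset fallback is deleted. The transform step amounts to rescaling a raw uniform sample u in [0, 1) into [min, max); the functor below is a minimal sketch with a hypothetical name standing in for phi::funcs::uniform_real_transform.

#include <cuda_runtime.h>

// Illustrative only: rescale u ~ U[0, 1) to U[min, max).
template <typename T>
struct UniformRealTransformSketch {
  T min_, max_;
  __host__ __device__ UniformRealTransformSketch(T min, T max)
      : min_(min), max_(max) {}
  __host__ __device__ T operator()(T u) const {
    return u * (max_ - min_) + min_;
  }
};

The per-call seed/offset bookkeeping that the removed branch did by hand (IncrementOffset plus a computed gen_offset) is presumably handled inside distribution_and_transform on the new path.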
2 changes: 0 additions & 2 deletions paddle/fluid/platform/flags.cc
@@ -549,8 +549,6 @@ PADDLE_DEFINE_EXPORTED_double(
*/
PADDLE_DEFINE_EXPORTED_bool(use_mkldnn, false, "Use MKLDNN to run");

PADDLE_DEFINE_EXPORTED_bool(use_curand, false, "Random OP use CURAND");

/**
* Debug related FLAG
* Name: FLAGS_call_stack_level
1 change: 1 addition & 0 deletions paddle/phi/kernels/cpu/transpose_kernel.cc
@@ -75,6 +75,7 @@ PD_REGISTER_KERNEL(transpose,
double,
int32_t,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
59 changes: 8 additions & 51 deletions paddle/phi/kernels/gpu/bernoulli_kernel.cu
@@ -14,8 +14,6 @@

#include "paddle/phi/kernels/bernoulli_kernel.h"

#include <thrust/random.h>
#include <thrust/transform.h>
#ifdef __NVCC__
#include <curand_kernel.h>
#endif
@@ -32,35 +30,8 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"

// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/transform.h"

DECLARE_bool(use_curand);

namespace phi {

template <typename T>
struct BernoulliCudaFunctor {
unsigned int seed_;
unsigned int offset_;
__host__ __device__ BernoulliCudaFunctor(unsigned int seed,
unsigned int offset)
: seed_(seed), offset_(offset) {}

__host__ __device__ T operator()(const unsigned int n, const T p) const {
// NOTE(zhiqiu): currently, PADDLE_ENFORCE in cuda kernel may print several
// lines of error messages if, and it should be refined.
PADDLE_ENFORCE(p >= 0.0 && p <= 1.0,
"The probability should be >=0 and <= 1, but got %f",
p);
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::uniform_real_distribution<T> dist(0.0, 1.0);
rng.discard(n + offset_);
return static_cast<T>(dist(rng) < p);
}
};

// 'curand_uniform4/hiprand_uniform4' generate 4 random number each time
template <typename T>
__global__ void bernoulli_cuda_kernel(
@@ -100,30 +71,16 @@ void BernoulliKernel(const Context& ctx,

auto gen_cuda = ctx.GetGenerator();

if (FLAGS_use_curand) {
auto seed_offset = gen_cuda->IncrementOffset(12);
uint64_t seed = seed_offset.first;
uint64_t offset = seed_offset.second;
auto seed_offset = gen_cuda->IncrementOffset(12);
uint64_t seed = seed_offset.first;
uint64_t offset = seed_offset.second;

auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, 4);
size_t grid_size = gpu_config.GetGridSize();
size_t block_size = gpu_config.GetBlockSize();
auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, 4);
size_t grid_size = gpu_config.GetGridSize();
size_t block_size = gpu_config.GetBlockSize();

bernoulli_cuda_kernel<<<grid_size, block_size, 0, ctx.stream()>>>(
numel, seed, offset, x_data, out_data);
} else {
auto seed_offset = gen_cuda->IncrementOffset(1);
int64_t gen_offset = numel * seed_offset.second;
paddle::platform::Transform<phi::GPUContext> trans;
thrust::counting_iterator<int64_t> index_sequence_begin(0);
trans(ctx,
index_sequence_begin,
index_sequence_begin + numel,
x_data,
out_data,
BernoulliCudaFunctor<T>(static_cast<int64_t>(seed_offset.first),
static_cast<int64_t>(gen_offset)));
}
bernoulli_cuda_kernel<<<grid_size, block_size, 0, ctx.stream()>>>(
numel, seed, offset, x_data, out_data);
}

} // namespace phi
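The bernoulli kernel now always takes the curand route: IncrementOffset(12) reserves a seed/offset pair for the launch, GetGpuLaunchConfig1D sizes the grid with a vector width of 4, and bernoulli_cuda_kernel (shown only partially in the hunk above) thresholds uniform draws against the input probabilities. A simplified, standalone sketch of that kernel shape follows; the indexing and names are assumptions, not the Paddle code.

#include <cstddef>
#include <cstdint>
#include <curand_kernel.h>

// Simplified sketch: draw four uniforms per step and compare each against
// the corresponding probability to emit a 0/1 sample.
template <typename T>
__global__ void bernoulli_sketch(size_t n, uint64_t seed, uint64_t offset,
                                 const T* p, T* out) {
  size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx, offset, &state);
  const size_t stride = static_cast<size_t>(gridDim.x) * blockDim.x * 4;
  for (size_t i = idx * 4; i < n; i += stride) {
    float4 r = curand_uniform4(&state);
    float rand[4] = {r.x, r.y, r.z, r.w};
    for (int k = 0; k < 4 && i + k < n; ++k) {
      // Emit 1 with probability p[i + k], otherwise 0.
      out[i + k] = static_cast<T>(rand[k] < static_cast<float>(p[i + k]));
    }
  }
}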