Skip to content

Commit e7cbc43

Browse files
authored
[cherry-pick] use cuda generator in bernoulli cuda kernel (#30199) #30286
[cherry-pick] use cuda generator in bernoulli cuda kernel (#30199)
1 parent 330aea6 commit e7cbc43

File tree

2 files changed

+18
-9
lines changed

2 files changed

+18
-9
lines changed

paddle/fluid/framework/generator.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,7 @@ std::pair<uint64_t, uint64_t> Generator::IncrementOffset(
172172
PADDLE_THROW(platform::errors::PermissionDenied(
173173
"Increment Offset only support in CUDA place"));
174174
#endif
175-
return std::make_pair(static_cast<int>(this->state_.current_seed),
176-
cur_offset);
175+
return std::make_pair(this->state_.current_seed, cur_offset);
177176
}
178177

179178
void Generator::SetIsInitPy(bool is_init_py) {

paddle/fluid/operators/bernoulli_op.cu

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License. */
1616
#include <thrust/random.h>
1717
#include <thrust/transform.h>
1818

19+
#include "paddle/fluid/framework/generator.h"
1920
#include "paddle/fluid/framework/op_registry.h"
2021
#include "paddle/fluid/framework/operator.h"
2122
#include "paddle/fluid/operators/bernoulli_op.h"
@@ -27,7 +28,10 @@ namespace operators {
2728
template <typename T>
2829
struct BernoulliCudaFunctor {
2930
unsigned int seed_;
30-
__host__ __device__ BernoulliCudaFunctor(int seed) : seed_(seed) {}
31+
unsigned int offset_;
32+
__host__ __device__ BernoulliCudaFunctor(unsigned int seed,
33+
unsigned int offset)
34+
: seed_(seed), offset_(offset) {}
3135

3236
__host__ __device__ T operator()(const unsigned int n, const T p) const {
3337
// NOTE(zhiqiu): currently, PADDLE_ENFORCE in cuda kernel may print several
@@ -37,7 +41,7 @@ struct BernoulliCudaFunctor {
3741
thrust::minstd_rand rng;
3842
rng.seed(seed_);
3943
thrust::uniform_real_distribution<T> dist(0.0, 1.0);
40-
rng.discard(n);
44+
rng.discard(n + offset_);
4145
return static_cast<T>(dist(rng) < p);
4246
}
4347
};
@@ -47,20 +51,26 @@ class BernoulliOpKernel<platform::CUDADeviceContext, T>
4751
: public framework::OpKernel<T> {
4852
public:
4953
void Compute(const framework::ExecutionContext& ctx) const override {
50-
std::random_device rd;
51-
auto seed = rd();
5254
const auto x = ctx.Input<framework::Tensor>("X");
5355
auto out = ctx.Output<framework::Tensor>("Out");
5456
auto* in_data = x->data<T>();
5557
auto* out_data = out->mutable_data<T>(ctx.GetPlace());
56-
5758
int64_t size = x->numel();
58-
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
59+
60+
int device_id =
61+
BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()).GetDeviceId();
62+
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
63+
auto seed_offset = gen_cuda->IncrementOffset(1);
64+
int gen_offset = size * seed_offset.second;
5965
platform::Transform<platform::CUDADeviceContext> trans;
66+
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
6067
auto* context =
6168
static_cast<const platform::CUDADeviceContext*>(&ctx.device_context());
69+
6270
trans(*context, index_sequence_begin, index_sequence_begin + size, in_data,
63-
out_data, BernoulliCudaFunctor<T>(seed));
71+
out_data,
72+
BernoulliCudaFunctor<T>(static_cast<unsigned int>(seed_offset.first),
73+
static_cast<unsigned int>(gen_offset)));
6474
}
6575
};
6676

0 commit comments

Comments
 (0)