@@ -16,6 +16,7 @@ limitations under the License. */
1616#include < thrust/random.h>
1717#include < thrust/transform.h>
1818
19+ #include " paddle/fluid/framework/generator.h"
1920#include " paddle/fluid/framework/op_registry.h"
2021#include " paddle/fluid/framework/operator.h"
2122#include " paddle/fluid/operators/bernoulli_op.h"
@@ -27,7 +28,10 @@ namespace operators {
2728template <typename T>
2829struct BernoulliCudaFunctor {
2930 unsigned int seed_;
30- __host__ __device__ BernoulliCudaFunctor (int seed) : seed_(seed) {}
31+ unsigned int offset_;
32+ __host__ __device__ BernoulliCudaFunctor (unsigned int seed,
33+ unsigned int offset)
34+ : seed_(seed), offset_(offset) {}
3135
3236 __host__ __device__ T operator ()(const unsigned int n, const T p) const {
3337 // NOTE(zhiqiu): currently, PADDLE_ENFORCE in cuda kernel may print several
@@ -37,7 +41,7 @@ struct BernoulliCudaFunctor {
3741 thrust::minstd_rand rng;
3842 rng.seed (seed_);
3943 thrust::uniform_real_distribution<T> dist (0.0 , 1.0 );
40- rng.discard (n);
44+ rng.discard (n + offset_ );
4145 return static_cast <T>(dist (rng) < p);
4246 }
4347};
@@ -47,20 +51,26 @@ class BernoulliOpKernel<platform::CUDADeviceContext, T>
4751 : public framework::OpKernel<T> {
4852 public:
4953 void Compute (const framework::ExecutionContext& ctx) const override {
50- std::random_device rd;
51- auto seed = rd ();
5254 const auto x = ctx.Input <framework::Tensor>(" X" );
5355 auto out = ctx.Output <framework::Tensor>(" Out" );
5456 auto * in_data = x->data <T>();
5557 auto * out_data = out->mutable_data <T>(ctx.GetPlace ());
56-
5758 int64_t size = x->numel ();
58- thrust::counting_iterator<unsigned int > index_sequence_begin (0 );
59+
60+ int device_id =
61+ BOOST_GET_CONST (platform::CUDAPlace, ctx.GetPlace ()).GetDeviceId ();
62+ auto gen_cuda = framework::GetDefaultCUDAGenerator (device_id);
63+ auto seed_offset = gen_cuda->IncrementOffset (1 );
64+ int gen_offset = size * seed_offset.second ;
5965 platform::Transform<platform::CUDADeviceContext> trans;
66+ thrust::counting_iterator<unsigned int > index_sequence_begin (0 );
6067 auto * context =
6168 static_cast <const platform::CUDADeviceContext*>(&ctx.device_context ());
69+
6270 trans (*context, index_sequence_begin, index_sequence_begin + size, in_data,
63- out_data, BernoulliCudaFunctor<T>(seed));
71+ out_data,
72+ BernoulliCudaFunctor<T>(static_cast <unsigned int >(seed_offset.first ),
73+ static_cast <unsigned int >(gen_offset)));
6474 }
6575};
6676
0 commit comments