
Commit 0a9f9f9

add cuda generator (#26786) (#27014)
1 parent 09ede3b commit 0a9f9f9

File tree: 13 files changed, +523 -18 lines (7 of the changed files are shown below)


paddle/fluid/framework/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -272,7 +272,7 @@ cc_test(op_compatible_info_test SRCS op_compatible_info_test.cc DEPS op_compatib
 
 cc_library(save_load_util SRCS save_load_util DEPS tensor scope layer)
 cc_test(save_load_util_test SRCS save_load_util_test.cc DEPS save_load_util tensor scope layer)
-cc_library(generator SRCS generator.cc)
+cc_library(generator SRCS generator.cc DEPS enforce place)
 
 # Get the current working branch
 execute_process(

paddle/fluid/framework/generator.cc

Lines changed: 53 additions & 0 deletions
@@ -21,10 +21,46 @@ limitations under the License. */
 #include <unordered_map>
 #include <unordered_set>
 #include <utility>
+#include <vector>
+
+#include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/gpu_info.h"
+#include "paddle/fluid/platform/place.h"
 
 namespace paddle {
 namespace framework {
 
+const std::shared_ptr<Generator>& GetDefaultCUDAGenerator(int64_t device_id) {
+#ifdef PADDLE_WITH_CUDA
+
+  static int64_t num_cuda_devices = -1;
+  static std::once_flag num_devices_init_flag;
+  static std::deque<std::once_flag> cuda_device_flags;
+  static std::vector<std::shared_ptr<Generator>> default_cuda_generators;
+
+  std::call_once(num_devices_init_flag, []() {
+    num_cuda_devices = paddle::platform::GetCUDADeviceCount();
+    cuda_device_flags.resize(num_cuda_devices);
+    default_cuda_generators.resize(num_cuda_devices);
+  });
+  if (device_id < 0) {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "cuda device id shoule be greater than 0"));
+  }
+
+  std::call_once(cuda_device_flags[device_id], [device_id]() {
+    default_cuda_generators[device_id] =
+        std::make_shared<Generator>(GetRandomSeed(), device_id);
+    VLOG(4) << "initial seed: "
+            << default_cuda_generators[device_id]->GetCurrentSeed();
+  });
+  return default_cuda_generators[device_id];
+#else
+  PADDLE_THROW(platform::errors::PermissionDenied(
+      "getDefaultCUDAGenerator only support in CUDA place"));
+#endif
+}
+
 const std::shared_ptr<Generator>& DefaultCPUGenerator() {
   static auto default_cpu_generator =
       std::make_shared<Generator>(GetRandomSeed());
@@ -103,6 +139,7 @@ uint64_t Generator::Seed() {
 void Generator::SetCurrentSeed(uint64_t seed) {
   std::lock_guard<std::mutex> lock(this->mu_);
   this->state_.current_seed = seed;
+  this->state_.thread_offset = 0;
   std::seed_seq seq({seed});
   this->engine_->seed(seq);
 }
@@ -123,6 +160,22 @@ uint64_t Generator::Random64() {
   return (*engine)();
 }
 
+std::pair<uint64_t, uint64_t> Generator::IncrementOffset(
+    uint64_t increament_offset) {
+  uint64_t cur_offset = this->state_.thread_offset;
+#ifdef PADDLE_WITH_CUDA
+  std::lock_guard<std::mutex> lock(this->mu_);
+
+  this->state_.thread_offset += increament_offset;
+
+#else
+  PADDLE_THROW(platform::errors::PermissionDenied(
+      "Increment Offset only support in CUDA place"));
+#endif
+  return std::make_pair(static_cast<int>(this->state_.current_seed),
+                        cur_offset);
+}
+
 void Generator::SetIsInitPy(bool is_init_py) {
   this->is_init_py_ = is_init_py;
   VLOG(4) << "SetIsInitPy:" << this->is_init_py_;
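The pattern introduced here, a lazily built table of per-device default generators guarded by std::once_flag, plus an IncrementOffset that returns a (seed, previous offset) pair under a mutex, can be illustrated outside Paddle. The following is a minimal standalone sketch, not the Paddle code: ToyGenerator, GetDefaultGenerator, and the fixed device count of 4 are stand-ins for Generator, GetDefaultCUDAGenerator, and platform::GetCUDADeviceCount().

```cpp
#include <cstdint>
#include <deque>
#include <iostream>
#include <memory>
#include <mutex>
#include <random>
#include <utility>
#include <vector>

struct ToyGenerator {
  ToyGenerator(uint64_t seed, int64_t device) : seed_(seed), device_(device) {}
  // Returns (seed, offset before this call) and advances the stored offset,
  // mirroring Generator::IncrementOffset.
  std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t n) {
    std::lock_guard<std::mutex> lock(mu_);
    uint64_t cur = offset_;
    offset_ += n;
    return {seed_, cur};
  }
  uint64_t seed_;
  int64_t device_;
  uint64_t offset_ = 0;
  std::mutex mu_;
};

const std::shared_ptr<ToyGenerator>& GetDefaultGenerator(int64_t device_id) {
  static const int64_t kNumDevices = 4;  // assumption: fixed device count
  static std::deque<std::once_flag> flags(kNumDevices);
  static std::vector<std::shared_ptr<ToyGenerator>> gens(kNumDevices);
  // Build each device's generator exactly once, even under concurrent calls.
  std::call_once(flags[device_id], [device_id] {
    gens[device_id] =
        std::make_shared<ToyGenerator>(std::random_device{}(), device_id);
  });
  return gens[device_id];
}

int main() {
  auto gen = GetDefaultGenerator(0);
  auto first = gen->IncrementOffset(1);   // offset 0 on the first call
  auto second = gen->IncrementOffset(1);  // offset 1 on the second call
  std::cout << first.second << " " << second.second << "\n";  // prints: 0 1
}
```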

paddle/fluid/framework/generator.h

Lines changed: 22 additions & 0 deletions
@@ -38,6 +38,7 @@ static uint64_t GetRandomSeed() {
 struct GeneratorState {
   int64_t device = -1;
   uint64_t current_seed = 34342423252;
+  uint64_t thread_offset = 0;
   std::mt19937_64 cpu_engine;
 };
 
@@ -49,6 +50,7 @@ struct Generator {
     this->state_.cpu_engine = *engine;
     this->state_.device = -1;
     this->state_.current_seed = seed;
+    this->state_.thread_offset = 0;
     this->engine_ = engine;
     VLOG(4) << "initial seed: " << this->state_.current_seed
             << ", cpu engine: " << &this->state_.cpu_engine;
@@ -59,11 +61,25 @@ struct Generator {
     this->state_.cpu_engine = *engine;
     this->state_.device = -1;
     this->state_.current_seed = seed;
+    this->state_.thread_offset = 0;
     this->engine_ = engine;
     VLOG(4) << "initial seed: " << this->state_.current_seed
             << ", cpu engine: " << &this->state_.cpu_engine;
     this->is_init_py_ = true;  // TODO(zhiqiu): remove it in future
   }
+  Generator(uint64_t seed, uint64_t device_id) {
+    std::seed_seq seq({seed});
+    auto engine = std::make_shared<std::mt19937_64>(seq);
+    this->state_.cpu_engine = *engine;
+    this->state_.device = device_id;
+    this->state_.current_seed = seed;
+    this->state_.thread_offset = 0;
+    this->engine_ = engine;
+    VLOG(4) << "initial seed: " << this->state_.current_seed
+            << ", cpu engine: " << &this->state_.cpu_engine;
+    this->is_init_py_ = false;  // TODO(zhiqiu): remove it in future
+  }
+
   Generator(const Generator& other) = delete;
 
   // get random state
@@ -83,8 +99,11 @@ struct Generator {
 
   uint64_t Random64();
 
+  std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t increament_offset);
+
   void SetIsInitPy(bool);
   bool GetIsInitPy() const;
+  uint64_t get_device_id() { return this->state_.device; }
 
  private:
   GeneratorState state_;
@@ -105,5 +124,8 @@ std::shared_ptr<std::mt19937_64> OpDefaultCPUEngine();
 
 std::shared_ptr<std::mt19937_64> GetCPURandomEngine(uint64_t);
 
+const std::shared_ptr<Generator>& GetDefaultCUDAGenerator(
+    int64_t device_id = -1);
+
 }  // namespace framework
 }  // namespace paddle
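Taken together, these declarations give operator kernels a small contract: fetch the per-device generator, check whether Python has initialized it, and ask for a (seed, starting offset) pair. The fragment below is a usage sketch only, assuming it is compiled inside the Paddle source tree with CUDA enabled, so that this header and the framework namespace are available; the function name Example and the choice of device 0 are illustrative.

```cpp
#include <cstdint>
#include <utility>

#include "paddle/fluid/framework/generator.h"

void Example() {
  // Per-device default generator (device 0 chosen for illustration).
  auto gen = paddle::framework::GetDefaultCUDAGenerator(/*device_id=*/0);
  if (gen->GetIsInitPy()) {
    // (current seed, offset before this call); the stored offset is advanced
    // by 1 so a later caller starts from a fresh position.
    std::pair<uint64_t, uint64_t> seed_offset = gen->IncrementOffset(1);
    uint64_t seed = seed_offset.first;     // pass to the kernel as the seed
    uint64_t offset = seed_offset.second;  // pass as the curand offset
    (void)seed;
    (void)offset;
  }
}
```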

paddle/fluid/operators/bernoulli_op.cu

Lines changed: 0 additions & 1 deletion
@@ -16,7 +16,6 @@ limitations under the License. */
 #include <thrust/random.h>
 #include <thrust/transform.h>
 
-#include "paddle/fluid/framework/generator.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/operators/bernoulli_op.h"

paddle/fluid/operators/dropout_op.cu

Lines changed: 47 additions & 0 deletions
@@ -96,6 +96,42 @@ __global__ void RandomGeneratorWithSeed(const size_t n, const int* seed,
   }
 }
 
+template <typename T, typename MaskType>
+__global__ void RandomGeneratorWithGenerator(const size_t n, uint64_t seed,
+                                             const float dropout_prob,
+                                             const T* src, MaskType* mask_data,
+                                             T* dst, bool is_upscale_in_train,
+                                             uint64_t increment) {
+  curandStatePhilox4_32_10_t state;
+  int idx = blockDim.x * blockIdx.x + threadIdx.x;
+  int step_size = 0;
+
+  MaskType mask;
+  T dest;
+  for (; idx < n; idx += blockDim.x * gridDim.x) {
+    T s = src[idx];
+    if (step_size == 0) {
+      curand_init(seed, idx, increment, &state);
+      step_size = blockDim.x * gridDim.x;
+    } else {
+      curand_init(seed, idx, increment, &state);
+    }
+    if (curand_uniform(&state) < dropout_prob) {
+      mask = 0;
+      dest = 0;
+    } else {
+      mask = 1;
+      if (is_upscale_in_train) {
+        dest = s / static_cast<T>(1.0f - dropout_prob);
+      } else {
+        dest = s;
+      }
+    }
+    mask_data[idx] = mask;
+    dst[idx] = dest;
+  }
+}
+
 // It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT.
 // Use std::random and thrust::random(thrust is a std library in CUDA) to
 // implement uniform random.
@@ -150,6 +186,17 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
           context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : rnd();
     }
 
+    int device_id = BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace())
+                        .GetDeviceId();
+    auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
+    if (gen_cuda->GetIsInitPy() && (!context.Attr<bool>("fix_seed"))) {
+      auto seed_offset = gen_cuda->IncrementOffset(1);
+      RandomGeneratorWithGenerator<T, uint8_t><<<grid, threads, 0, stream>>>(
+          size, seed_offset.first, dropout_prob, x_data, mask_data, y_data,
+          upscale_in_train, seed_offset.second);
+      return;
+    }
+
     RandomGenerator<T, uint8_t><<<grid, threads, 0, stream>>>(
         size, seed_data, dropout_prob, x_data, mask_data, y_data,
         upscale_in_train);
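The new dropout kernel relies on curand's Philox counter-based state: each thread seeds with the shared seed, uses its global thread index as the subsequence, and applies the per-launch offset obtained from IncrementOffset, so repeated launches under the same seed do not replay the same numbers. The standalone CUDA C++ sketch below (compiled with nvcc) demonstrates only that seeding scheme; it is not the Paddle kernel, and the kernel name philox_uniform, the seed 42, and the two hard-coded offsets are illustrative.

```cpp
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>
#include <curand_kernel.h>

__global__ void philox_uniform(uint64_t seed, uint64_t offset, float* out,
                               int n) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= n) return;
  curandStatePhilox4_32_10_t state;
  // (seed, subsequence = thread index, offset): a larger offset skips ahead
  // in the same per-thread sequence, so the draws differ between launches.
  curand_init(seed, idx, offset, &state);
  out[idx] = curand_uniform(&state);
}

int main() {
  const int n = 8;
  float* d_out = nullptr;
  float host[n];
  cudaMalloc(&d_out, n * sizeof(float));
  const unsigned long long offsets[2] = {0, 1};  // emulate two launches
  for (int i = 0; i < 2; ++i) {
    philox_uniform<<<1, n>>>(/*seed=*/42, offsets[i], d_out, n);
    cudaMemcpy(host, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);
    std::printf("offset %llu: first draw %f\n", offsets[i], host[0]);
  }
  cudaFree(d_out);
  return 0;
}
```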

paddle/fluid/operators/gaussian_random_op.cu

Lines changed: 53 additions & 7 deletions
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include <thrust/random.h>
 #include <thrust/transform.h>
+#include "paddle/fluid/framework/generator.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/operators/fill_constant_op.h"
@@ -24,15 +25,20 @@ template <typename T>
 struct GaussianGenerator {
   T mean_, std_;
   unsigned int seed_;
+  unsigned int offset_ = 0;
 
   __host__ __device__ GaussianGenerator(T mean, T std, int seed)
       : mean_(mean), std_(std), seed_(seed) {}
 
+  __host__ __device__ GaussianGenerator(T mean, T std, int seed, int offset)
+      : mean_(mean), std_(std), seed_(seed), offset_(offset) {}
+
   __host__ __device__ T operator()(const unsigned int n) const {
     thrust::minstd_rand rng;
     rng.seed(seed_);
     thrust::normal_distribution<T> dist(mean_, std_);
-    rng.discard(n);
+    unsigned int new_n = n + offset_;
+    rng.discard(new_n);
     return dist(rng);
   }
 };
@@ -43,9 +49,11 @@ class GPUGaussianRandomKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& context) const override {
     auto* tensor = context.Output<framework::Tensor>("Out");
     unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
+    bool seed_flag = false;
     if (seed == 0) {
       std::random_device rd;
       seed = rd();
+      seed_flag = true;
     }
     T mean = static_cast<T>(context.Attr<float>("mean"));
     T std = static_cast<T>(context.Attr<float>("std"));
@@ -56,9 +64,27 @@ class GPUGaussianRandomKernel : public framework::OpKernel<T> {
     T* data = tensor->mutable_data<T>(context.GetPlace());
 
     int64_t size = tensor->numel();
-    thrust::transform(index_sequence_begin, index_sequence_begin + size,
-                      thrust::device_ptr<T>(data),
-                      GaussianGenerator<T>(mean, std, seed));
+
+    int device_id =
+        BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()).GetDeviceId();
+    auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
+
+    if (gen_cuda->GetIsInitPy() && seed_flag) {
+      auto seed_offset = gen_cuda->IncrementOffset(1);
+      int offset_step = 100;
+      // NOTE(xuefeng): Currently, we let offset step fixed to avoid
+      // unexpected results which may cause ut fail.
+      // we will fix this in future.
+      int gen_offset = offset_step * seed_offset.second;
+      thrust::transform(
+          index_sequence_begin, index_sequence_begin + size,
+          thrust::device_ptr<T>(data),
+          GaussianGenerator<T>(mean, std, seed_offset.first, gen_offset));
+    } else {
+      thrust::transform(index_sequence_begin, index_sequence_begin + size,
+                        thrust::device_ptr<T>(data),
+                        GaussianGenerator<T>(mean, std, seed));
+    }
   }
 };
 
@@ -69,17 +95,37 @@ class GPUGaussianRandomBatchSizeLikeKernel : public framework::OpKernel<T> {
     auto* tensor = context.Output<framework::Tensor>("Out");
     T* data = tensor->mutable_data<T>(context.GetPlace());
     unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
+    bool seed_flag = false;
     if (seed == 0) {
       std::random_device rd;
       seed = rd();
+      seed_flag = true;
     }
     T mean = static_cast<T>(context.Attr<float>("mean"));
     T std = static_cast<T>(context.Attr<float>("std"));
     thrust::counting_iterator<unsigned int> index_sequence_begin(0);
     int64_t size = tensor->numel();
-    thrust::transform(index_sequence_begin, index_sequence_begin + size,
-                      thrust::device_ptr<T>(data),
-                      GaussianGenerator<T>(mean, std, seed));
+
+    int device_id =
+        BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()).GetDeviceId();
+    auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
+
+    if (gen_cuda->GetIsInitPy() && seed_flag) {
+      auto seed_offset = gen_cuda->IncrementOffset(1);
+      int offset_step = 100;
+      // NOTE(xuefeng): Currently, we let offset step fixed to avoid
+      // unexpected results which may cause ut fail.
+      // we will fix this in future.
+      int gen_offset = offset_step * seed_offset.second;
+      thrust::transform(index_sequence_begin, index_sequence_begin + size,
+                        thrust::device_ptr<T>(data),
+                        GaussianGenerator<T>(mean, std, seed_offset.first,
+                                             seed_offset.second));
+    } else {
+      thrust::transform(index_sequence_begin, index_sequence_begin + size,
+                        thrust::device_ptr<T>(data),
+                        GaussianGenerator<T>(mean, std, seed));
+    }
   }
 };
 }  // namespace operators
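Both Gaussian kernels keep the existing functor but fold the generator offset into the engine's discard count, so the same output index draws a different element of the seeded sequence when the offset changes. The plain C++ sketch below mirrors that rng.discard(n + offset) idea with std::minstd_rand standing in for thrust::minstd_rand; nth_normal and the constants are illustrative, not Paddle APIs.

```cpp
#include <iostream>
#include <random>

// For output index n: re-seed a cheap engine, skip n + offset draws, then
// take one sample, the same scheme as GaussianGenerator::operator().
float nth_normal(unsigned int seed, unsigned int n, unsigned int offset) {
  std::minstd_rand rng(seed);
  std::normal_distribution<float> dist(0.0f, 1.0f);
  rng.discard(n + offset);  // mirrors rng.discard(new_n) in the functor
  return dist(rng);
}

int main() {
  // Index 0 with offset 2 reaches the same engine state as index 2 with
  // offset 0, so these two calls print the same value; a nonzero offset
  // therefore shifts which part of the sequence a launch consumes.
  std::cout << nth_normal(42, 0, 2) << " vs " << nth_normal(42, 2, 0) << "\n";
  return 0;
}
```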

paddle/fluid/operators/randint_op.cu

Lines changed: 10 additions & 1 deletion
@@ -13,6 +13,7 @@
 // limitations under the License.
 #include <thrust/random.h>
 #include <thrust/transform.h>
+#include "paddle/fluid/framework/generator.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/uniform_random_op.h"
 
@@ -49,15 +50,23 @@ class GPURandintKernel : public framework::OpKernel<T> {
 
     int64_t size = out->numel();
     unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
+
+    /*
     std::minstd_rand engine;
     if (seed == 0) {
       std::random_device rd;
       seed = rd();
     }
     engine.seed(seed);
+    */
+
     std::uniform_int_distribution<> dist(context.Attr<int>("low"),
                                          context.Attr<int>("high") - 1);
-    for (int64_t i = 0; i < size; ++i) data[i] = dist(engine);
+    auto engine = framework::GetCPURandomEngine(seed);
+
+    for (int64_t i = 0; i < size; ++i) {
+      data[i] = dist(*engine);
+    }
 
     if (platform::is_gpu_place(context.GetPlace())) {
       // Copy tensor to out
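After this change the randint kernel no longer seeds its own std::minstd_rand; it draws from the engine returned by framework::GetCPURandomEngine(seed) through a std::uniform_int_distribution. The standalone sketch below mirrors that dist(*engine) loop; GetEngine, its seed-zero convention, and the low/high constants are stand-ins for the Paddle helper and the operator's attributes, not the actual implementation.

```cpp
#include <cstdint>
#include <iostream>
#include <memory>
#include <random>
#include <vector>

std::shared_ptr<std::mt19937_64> GetEngine(uint64_t seed) {
  // Assumption for this sketch: seed == 0 means "use a process-wide engine";
  // any other value means a fresh engine seeded deterministically.
  static auto global =
      std::make_shared<std::mt19937_64>(std::random_device{}());
  if (seed == 0) return global;
  return std::make_shared<std::mt19937_64>(seed);
}

int main() {
  const int low = 0, high = 10;  // like the operator's "low"/"high" attributes
  std::uniform_int_distribution<> dist(low, high - 1);
  auto engine = GetEngine(/*seed=*/0);
  std::vector<int> data(8);
  for (auto& v : data) v = dist(*engine);  // mirrors data[i] = dist(*engine)
  for (int v : data) std::cout << v << " ";
  std::cout << "\n";
  return 0;
}
```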
