
Commit c5bf09b

cxxly and Feiyu Chan authored
add dirichlet random sample op in cpu and gpu kernel (PaddlePaddle#38244)
* add dirichlet sample op and cpu backend kernel
* add Dirichlet op cuda kernel (#6)
* add dirichlet op hip kernel

Co-authored-by: Feiyu Chan <chenfeiyu@baidu.com>
1 parent cc83c95 commit c5bf09b
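Both kernels in this commit follow the standard two-step recipe for Dirichlet sampling: draw independent Gamma variates, then normalize them onto the simplex along the last axis. In symbols:

```latex
G_i \sim \operatorname{Gamma}(\alpha_i, 1)\ \text{independent}, \qquad
X_i = \frac{G_i}{\sum_{j=1}^{K} G_j}
\quad\Longrightarrow\quad
(X_1, \ldots, X_K) \sim \operatorname{Dirichlet}(\alpha_1, \ldots, \alpha_K)
```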

File tree

4 files changed (+429, -0 lines changed)
paddle/fluid/operators/dirichlet_op.cc
Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/dirichlet_op.h"

#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"

namespace paddle {
namespace operators {
template <typename T, typename UniformSamplerT, typename NormalSamplerT>
struct GammaCPUFunctor {
  GammaCPUFunctor(const T* alpha, T* gamma,
                  BaseSampler<T, UniformSamplerT> uniform,
                  BaseSampler<T, NormalSamplerT> normal)
      : alpha_(alpha), gamma_(gamma), uniform_(uniform), normal_(normal) {}

  HOST void operator()(int64_t index) {
    auto sample = sample_gamma<T, T, UniformSamplerT, NormalSamplerT>(
        alpha_[index], uniform_, normal_);
    gamma_[index] = std::max(std::numeric_limits<T>::min(), sample);
  }

  const T* alpha_;
  T* gamma_;
  BaseSampler<T, UniformSamplerT> uniform_;
  BaseSampler<T, NormalSamplerT> normal_;
};

template <typename T>
struct DirichletSampler<platform::CPUDeviceContext, T> {
  void operator()(const framework::ExecutionContext& ctx, const Tensor* alpha,
                  Tensor* out) {
    auto& dev_ctx = ctx.device_context<platform::CPUDeviceContext>();

    auto p_gen = framework::DefaultCPUGenerator();
    auto generator = p_gen->GetCPUEngine();

    auto uniform = [&generator]() -> T {
      std::uniform_real_distribution<T> u(0.0, 1.0);
      return u(*generator);
    };
    BaseSampler<T, decltype(uniform)> standard_uniform(uniform);

    auto normal = [&generator]() {
      std::normal_distribution<T> n(0.0, 1.0);
      return n(*generator);
    };
    BaseSampler<T, decltype(normal)> standard_normal(normal);

    // sample from K gamma distributions, where K = alpha.numel()
    framework::Tensor gamma_samples;
    gamma_samples.mutable_data<T>(alpha->dims(), dev_ctx.GetPlace());
    GammaCPUFunctor<T, decltype(uniform), decltype(normal)> gamma_functor(
        alpha->data<T>(), gamma_samples.data<T>(), standard_uniform,
        standard_normal);
    platform::ForRange<platform::CPUDeviceContext> for_range(dev_ctx,
                                                             alpha->numel());
    for_range(gamma_functor);

    // normalize them into a simplex, along the last axis
    framework::Tensor gamma_sum;
    auto new_shape = gamma_samples.dims();
    new_shape[new_shape.size() - 1] = 1;
    gamma_sum.mutable_data<T>(new_shape, dev_ctx.GetPlace());

    ReduceKernelFunctor<platform::CPUDeviceContext, T, SumFunctor>(
        &gamma_samples, &gamma_sum, {new_shape.size() - 1}, true, false, ctx)
        .template apply<T>();
    ElementwiseComputeEx<DivFunctor<T>, platform::CPUDeviceContext, T, T>(
        ctx, &gamma_samples, &gamma_sum, -1, DivFunctor<T>(), out);
  }
};

class DirichletOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Alpha", "(Tensor), The Dirichlet Alpha parameter.");
    AddOutput("Out", "(Tensor), The output tensor of samples.");
    AddComment(R"DOC(Sample random data from the Dirichlet distribution.)DOC");
  }
};

class DirichletOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("Alpha"), "Input", "Alpha", "dirichlet");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "dirichlet");
    const auto alpha_dim = ctx->GetInputDim("Alpha");
    PADDLE_ENFORCE_GE(alpha_dim.size(), 1,
                      platform::errors::InvalidArgument(
                          "ShapeError: The number of dimensions of 'Alpha' "
                          "must be greater than or equal to 1. "
                          "But received Alpha's dimensions = %d.",
                          alpha_dim.size()));
    ctx->ShareDim("Alpha", /*->*/ "Out");
  }
};

}  // namespace operators
}  // namespace paddle

REGISTER_OP_WITHOUT_GRADIENT(dirichlet, paddle::operators::DirichletOp,
                             paddle::operators::DirichletOpMaker);
REGISTER_OP_CPU_KERNEL(
    dirichlet,
    paddle::operators::DirichletKernel<paddle::platform::CPUDeviceContext,
                                       float>,
    paddle::operators::DirichletKernel<paddle::platform::CPUDeviceContext,
                                       double>);
paddle/fluid/operators/dirichlet_op.cu
Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/operators/dirichlet_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
#include "paddle/fluid/platform/for_range.h"

#ifdef PADDLE_WITH_CUDA
#include <curand_kernel.h>
#endif
#ifdef PADDLE_WITH_HIP
#include <hiprand_kernel.h>
#endif

#if defined(PADDLE_WITH_CUDA)
using COMPAT_RANDSTATEPHILOX4_32_10_T = curandStatePhilox4_32_10_t;
#define COMPAT_RAND_INIT curand_init
#define COMPAT_RAND_UNIFORM curand_uniform
#define COMPAT_RAND_NORMAL curand_normal
#elif defined(PADDLE_WITH_HIP)
using COMPAT_RANDSTATEPHILOX4_32_10_T = hiprandStatePhilox4_32_10_t;
#define COMPAT_RAND_INIT hiprand_init
#define COMPAT_RAND_UNIFORM hiprand_uniform
#define COMPAT_RAND_NORMAL hiprand_normal
#endif

namespace paddle {
namespace operators {
template <typename T>
struct GammaCUDAFunctor {
  GammaCUDAFunctor(const T* alpha, T* gamma, uint64_t seed, uint64_t offset)
      : alpha_(alpha), gamma_(gamma), seed_(seed), offset_(offset) {}

  DEVICE void operator()(int64_t index) {
    // curand initialization
    COMPAT_RANDSTATEPHILOX4_32_10_T state;
    COMPAT_RAND_INIT(/*seed=*/seed_, /*subsequence=*/index, /*offset=*/offset_,
                     &state);

    // sample
    auto uniform_lambda = [&state]() { return COMPAT_RAND_UNIFORM(&state); };
    BaseSampler<T, decltype(uniform_lambda)> standard_uniform(uniform_lambda);
    auto normal_lambda = [&state]() { return COMPAT_RAND_NORMAL(&state); };
    BaseSampler<T, decltype(normal_lambda)> standard_normal(normal_lambda);

    auto sample =
        sample_gamma<T, T, decltype(uniform_lambda), decltype(normal_lambda)>(
            alpha_[index], standard_uniform, standard_normal);
    gamma_[index] = std::max(std::numeric_limits<T>::min(), sample);
  }

  const T* alpha_;
  T* gamma_;
  const uint64_t seed_;
  const uint64_t offset_;
};

template <typename T>
struct DirichletSampler<platform::CUDADeviceContext, T> {
  void operator()(const framework::ExecutionContext& ctx,
                  const framework::Tensor* alpha, framework::Tensor* out) {
    auto& dev_ctx = ctx.device_context<platform::CUDADeviceContext>();

    // init state, seed & offset for all threads
    int device_id =
        BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()).GetDeviceId();
    auto p_gen = framework::GetDefaultCUDAGenerator(device_id);
    auto seed_and_offset = p_gen->IncrementOffset(10);  // hard-coded offset
    auto seed = seed_and_offset.first;
    auto offset = seed_and_offset.second;

    // sample from K gamma distributions, where K = alpha.numel()
    framework::Tensor gamma_samples;
    gamma_samples.mutable_data<T>(alpha->dims(), dev_ctx.GetPlace());
    GammaCUDAFunctor<T> gamma_functor(alpha->data<T>(), gamma_samples.data<T>(),
                                      seed, offset);
    platform::ForRange<platform::CUDADeviceContext> for_range(dev_ctx,
                                                              out->numel());
    for_range(gamma_functor);

    // normalize them into a simplex, along the last axis
    framework::Tensor gamma_sum;
    auto new_shape = gamma_samples.dims();
    new_shape[new_shape.size() - 1] = 1;
    gamma_sum.mutable_data<T>(new_shape, dev_ctx.GetPlace());

    ReduceKernelFunctor<platform::CUDADeviceContext, T, SumFunctor>(
        &gamma_samples, &gamma_sum, {new_shape.size() - 1}, true, false, ctx)
        .template apply<T>();
    ElementwiseComputeEx<DivFunctor<T>, platform::CUDADeviceContext, T, T>(
        ctx, &gamma_samples, &gamma_sum, -1, DivFunctor<T>(), out);
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_CUDA_KERNEL(
    dirichlet, ops::DirichletKernel<paddle::platform::CUDADeviceContext, float>,
    ops::DirichletKernel<paddle::platform::CUDADeviceContext, double>);
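The per-thread randomness above relies on the counter-based Philox generator: every thread is seeded identically but selects its own subsequence (its element index), which yields statistically independent streams with no shared state, while the offset returned by `IncrementOffset(10)` skips past values already reserved by earlier kernels sharing the generator. A minimal sketch of that initialization pattern; the kernel name and output buffer are illustrative, not from the commit:

```cpp
#include <cstdint>
#include <curand_kernel.h>

__global__ void philox_uniform(float* out, uint64_t seed, uint64_t offset,
                               int64_t n) {
  int64_t idx = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
  if (idx >= n) return;
  curandStatePhilox4_32_10_t state;
  // Same seed everywhere; the subsequence (= element index) separates the
  // per-thread streams, and the offset skips numbers already consumed by
  // previous kernels that share the generator.
  curand_init(seed, /*subsequence=*/idx, /*offset=*/offset, &state);
  out[idx] = curand_uniform(&state);  // one value in (0, 1]
}
```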
paddle/fluid/operators/dirichlet_op.h
Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <cmath>
#include <random>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/for_range.h"

// ROCM hcc doesn't work well with using std:: in kernel functions
#if defined(PADDLE_WITH_CUDA)
#define COMPAT_EXP exp
#define COMPAT_CEIL ceil
#define COMPAT_FLOOR floor
#define COMPAT_LOG log
#define COMPAT_POW pow
#define COMPAT_SQRT sqrt
#define COMPAT_TAN tan
#define COMPAT_ABS abs
#define COMPAT_LOG1P log1p
#else
#define COMPAT_EXP std::exp
#define COMPAT_CEIL std::ceil
#define COMPAT_FLOOR std::floor
#define COMPAT_LOG std::log
#define COMPAT_POW std::pow
#define COMPAT_SQRT std::sqrt
#define COMPAT_TAN std::tan
#define COMPAT_ABS std::abs
#define COMPAT_LOG1P std::log1p
#endif

namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
struct DirichletSampler;

template <typename ScalarT, typename SamplerT>
struct BaseSampler {
  SamplerT sampler_;
  HOSTDEVICE BaseSampler(const SamplerT& sampler) : sampler_(sampler) {}
  HOSTDEVICE ScalarT sample() { return sampler_(); }
};

// `sample_gamma` is derived from Numpy's distributions.c, adapted to
// Paddle data types and code style.
// Source MIT licensed:
/* Copyright 2005 Robert Kern (robert.kern@gmail.com)
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the
 * "Software"), to deal in the Software without restriction, including
 * without limitation the rights to use, copy, modify, merge, publish,
 * distribute, sublicense, and/or sell copies of the Software, and to
 * permit persons to whom the Software is furnished to do so, subject to
 * the following conditions:
 *
 * The above copyright notice and this permission notice shall be included
 * in all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 */

template <typename ScalarT, typename AccscalarT, typename UniformSamplerT,
          typename NormalSamplerT>
HOSTDEVICE ScalarT sample_gamma(
    ScalarT alpha, BaseSampler<AccscalarT, UniformSamplerT> standard_uniform,
    BaseSampler<AccscalarT, NormalSamplerT> standard_normal) {
  AccscalarT scale = 1.0f;

  // Boost alpha for higher acceptance probability.
  if (alpha < 1.0f) {
    if (alpha == 0.f) return 0.f;
    scale *= COMPAT_POW(1 - standard_uniform.sample(), 1.0f / alpha);
    alpha += 1.0f;
  }

  // This implements the acceptance-rejection method of Marsaglia and Tsang
  // (2000), doi:10.1145/358407.358414
  const AccscalarT d = alpha - 1.0f / 3.0f;
  const AccscalarT c = 1.0f / COMPAT_SQRT(9.0f * d);
  for (;;) {
    AccscalarT x, y;
    do {
      x = standard_normal.sample();
      y = 1.0f + c * x;
    } while (y <= 0);
    const AccscalarT v = y * y * y;
    const AccscalarT u = 1 - standard_uniform.sample();
    const AccscalarT xx = x * x;
    if (u < 1.0f - 0.0331f * xx * xx)
      return static_cast<ScalarT>(scale * d * v);
    if (COMPAT_LOG(u) < 0.5f * xx + d * (1.0f - v + COMPAT_LOG(v)))
      return static_cast<ScalarT>(scale * d * v);
  }
}

template <typename DeviceContext, typename T>
class DirichletKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto* alpha = ctx.Input<framework::Tensor>("Alpha");
    auto* out = ctx.Output<framework::Tensor>("Out");
    out->mutable_data<T>(ctx.GetPlace());

    DirichletSampler<DeviceContext, T> sampler;
    sampler(ctx, alpha, out);
  }
};
}  // namespace operators
}  // namespace paddle
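Because `BaseSampler` wraps any nullary callable, `sample_gamma` can be exercised host-side with standard-library RNGs. A minimal sketch, assuming this header is on the include path and a CPU-only build in which `HOSTDEVICE` and `HOST` expand to nothing:

```cpp
#include <random>

#include "paddle/fluid/operators/dirichlet_op.h"

int main() {
  std::mt19937 rng(0);
  auto uniform = [&rng]() -> float {
    return std::uniform_real_distribution<float>(0.0f, 1.0f)(rng);
  };
  auto normal = [&rng]() -> float {
    return std::normal_distribution<float>(0.0f, 1.0f)(rng);
  };
  paddle::operators::BaseSampler<float, decltype(uniform)> u(uniform);
  paddle::operators::BaseSampler<float, decltype(normal)> n(normal);
  // One Gamma(2.5, 1) variate from the Marsaglia-Tsang sampler.
  float g = paddle::operators::sample_gamma<float, float, decltype(uniform),
                                            decltype(normal)>(2.5f, u, n);
  return g > 0.0f ? 0 : 1;  // samples are strictly positive for alpha > 0
}
```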
