|
| 1 | +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | + |
| 15 | +#include "paddle/fluid/operators/dirichlet_op.h" |
| 16 | + |
| 17 | +#include "paddle/fluid/framework/generator.h" |
| 18 | +#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" |
| 19 | +#include "paddle/fluid/operators/reduce_ops/reduce_op.h" |
| 20 | +#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h" |
| 21 | + |
| 22 | +namespace paddle { |
| 23 | +namespace operators { |
| 24 | +template <typename T, typename UniformSamplerT, typename NormalSamplerT> |
| 25 | +struct GammaCPUFunctor { |
| 26 | + GammaCPUFunctor(const T* alpha, T* gamma, |
| 27 | + BaseSampler<T, UniformSamplerT> uniform, |
| 28 | + BaseSampler<T, NormalSamplerT> normal) |
| 29 | + : alpha_(alpha), gamma_(gamma), uniform_(uniform), normal_(normal) {} |
| 30 | + |
| 31 | + HOST void operator()(int64_t index) { |
| 32 | + auto sample = sample_gamma<T, T, UniformSamplerT, NormalSamplerT>( |
| 33 | + alpha_[index], uniform_, normal_); |
| 34 | + gamma_[index] = std::max(std::numeric_limits<T>::min(), sample); |
| 35 | + } |
| 36 | + |
| 37 | + const T* alpha_; |
| 38 | + T* gamma_; |
| 39 | + BaseSampler<T, UniformSamplerT> uniform_; |
| 40 | + BaseSampler<T, NormalSamplerT> normal_; |
| 41 | +}; |
| 42 | + |
| 43 | +template <typename T> |
| 44 | +struct DirichletSampler<platform::CPUDeviceContext, T> { |
| 45 | + void operator()(const framework::ExecutionContext& ctx, const Tensor* alpha, |
| 46 | + Tensor* out) { |
| 47 | + auto& dev_ctx = ctx.device_context<platform::CPUDeviceContext>(); |
| 48 | + |
| 49 | + auto p_gen = framework::DefaultCPUGenerator(); |
| 50 | + auto generator = p_gen->GetCPUEngine(); |
| 51 | + |
| 52 | + auto uniform = [&generator]() -> T { |
| 53 | + std::uniform_real_distribution<T> u(0.0, 1.0); |
| 54 | + return u(*generator); |
| 55 | + }; |
| 56 | + BaseSampler<T, decltype(uniform)> standard_uniform(uniform); |
| 57 | + |
| 58 | + auto normal = [&generator]() { |
| 59 | + std::normal_distribution<T> n(0.0, 1.0); |
| 60 | + return n(*generator); |
| 61 | + }; |
| 62 | + BaseSampler<T, decltype(normal)> standard_normal(normal); |
| 63 | + |
| 64 | + // sample from K gamma distributions, where K=alpha.numel() |
| 65 | + framework::Tensor gamma_samples; |
| 66 | + gamma_samples.mutable_data<T>(alpha->dims(), dev_ctx.GetPlace()); |
| 67 | + GammaCPUFunctor<T, decltype(uniform), decltype(normal)> gamma_functor( |
| 68 | + alpha->data<T>(), gamma_samples.data<T>(), standard_uniform, |
| 69 | + standard_normal); |
| 70 | + platform::ForRange<platform::CPUDeviceContext> for_range(dev_ctx, |
| 71 | + alpha->numel()); |
| 72 | + for_range(gamma_functor); |
| 73 | + |
| 74 | + // normalize them into a simplex, along the last axis |
| 75 | + framework::Tensor gamma_sum; |
| 76 | + auto new_shape = gamma_samples.dims(); |
| 77 | + new_shape[new_shape.size() - 1] = 1; |
| 78 | + gamma_sum.mutable_data<T>(new_shape, dev_ctx.GetPlace()); |
| 79 | + |
| 80 | + ReduceKernelFunctor<platform::CPUDeviceContext, T, SumFunctor>( |
| 81 | + &gamma_samples, &gamma_sum, {new_shape.size() - 1}, true, false, ctx) |
| 82 | + .template apply<T>(); |
| 83 | + ElementwiseComputeEx<DivFunctor<T>, platform::CPUDeviceContext, T, T>( |
| 84 | + ctx, &gamma_samples, &gamma_sum, -1, DivFunctor<T>(), out); |
| 85 | + } |
| 86 | +}; |
| 87 | + |
| 88 | +class DirichletOpMaker : public framework::OpProtoAndCheckerMaker { |
| 89 | + public: |
| 90 | + void Make() override { |
| 91 | + AddInput("Alpha", "(Tensor), The dirichlet Alpha parameter"); |
| 92 | + AddOutput("Out", "(Tensor), The output tensor of sample"); |
| 93 | + AddComment(R"DOC(Sample random data from dirichlet distribution.)DOC"); |
| 94 | + } |
| 95 | +}; |
| 96 | + |
| 97 | +class DirichletOp : public framework::OperatorWithKernel { |
| 98 | + public: |
| 99 | + using framework::OperatorWithKernel::OperatorWithKernel; |
| 100 | + |
| 101 | + void InferShape(framework::InferShapeContext* ctx) const override { |
| 102 | + OP_INOUT_CHECK(ctx->HasInput("Alpha"), "Input", "Alpha", "dirichlet"); |
| 103 | + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "dirichlet"); |
| 104 | + const auto alpha_dim = ctx->GetInputDim("Alpha"); |
| 105 | + PADDLE_ENFORCE_GE(alpha_dim.size(), 1, |
| 106 | + platform::errors::InvalidArgument( |
| 107 | + "ShapeError: The number of dimensions of 'Alpha' " |
| 108 | + "must be greater than or euqal to 1. " |
| 109 | + "But received Alpha's dimensions = %d,", |
| 110 | + alpha_dim.size())); |
| 111 | + ctx->ShareDim("Alpha", /*->*/ "Out"); |
| 112 | + } |
| 113 | +}; |
| 114 | + |
| 115 | +} // namespace operators |
| 116 | +} // namespace paddle |
| 117 | + |
| 118 | +REGISTER_OP_WITHOUT_GRADIENT(dirichlet, paddle::operators::DirichletOp, |
| 119 | + paddle::operators::DirichletOpMaker); |
| 120 | +REGISTER_OP_CPU_KERNEL( |
| 121 | + dirichlet, |
| 122 | + paddle::operators::DirichletKernel<paddle::platform::CPUDeviceContext, |
| 123 | + float>, |
| 124 | + paddle::operators::DirichletKernel<paddle::platform::CPUDeviceContext, |
| 125 | + double>); |
0 commit comments