|
| 1 | +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +
|
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | +
|
| 7 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | +Unless required by applicable law or agreed to in writing, software |
| 10 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +See the License for the specific language governing permissions and |
| 13 | +limitations under the License. */ |
| 14 | + |
| 15 | +#include "paddle/fluid/framework/eigen.h" |
| 16 | +#include "paddle/fluid/operators/bincount_op.h" |
| 17 | +#include "paddle/fluid/platform/cuda_primitives.h" |
| 18 | +#include "paddle/fluid/platform/gpu_launch_config.h" |
| 19 | +#include "paddle/fluid/platform/hostdevice.h" |
| 20 | + |
| 21 | +namespace paddle { |
| 22 | +namespace operators { |
| 23 | + |
| 24 | +using Tensor = framework::Tensor; |
| 25 | +using platform::PADDLE_CUDA_NUM_THREADS; |
| 26 | + |
| 27 | +inline int GET_BLOCKS(const int N) { |
| 28 | + return (N + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS; |
| 29 | +} |
| 30 | + |
| 31 | +template <typename T, typename InputT, typename OutT> |
| 32 | +__global__ void KernelBincount(const InputT* input, const int total_elements, |
| 33 | + const bool has_weights, const T* weights, |
| 34 | + OutT* output) { |
| 35 | + if (!has_weights) { |
| 36 | + for (int i = threadIdx.x; i < total_elements; i += blockDim.x) { |
| 37 | + paddle::platform::CudaAtomicAdd(&output[input[i]], 1L); |
| 38 | + } |
| 39 | + } else { |
| 40 | + for (int i = threadIdx.x; i < total_elements; i += blockDim.x) { |
| 41 | + paddle::platform::CudaAtomicAdd(&output[input[i]], |
| 42 | + static_cast<OutT>(weights[i])); |
| 43 | + } |
| 44 | + } |
| 45 | +} |
| 46 | + |
| 47 | +template <typename DeviceContext, typename T, typename InputT> |
| 48 | +void BincountCUDAInner(const framework::ExecutionContext& context) { |
| 49 | + const Tensor* input = context.Input<framework::Tensor>("X"); |
| 50 | + const Tensor* weights = context.Input<framework::Tensor>("Weights"); |
| 51 | + Tensor* output = context.Output<framework::Tensor>("Out"); |
| 52 | + auto& minlength = context.Attr<int>("minlength"); |
| 53 | + |
| 54 | + const InputT* input_data = input->data<InputT>(); |
| 55 | + |
| 56 | + const int input_numel = input->numel(); |
| 57 | + |
| 58 | + if (input_data == nullptr) { |
| 59 | + framework::DDim out_dim{0}; |
| 60 | + output->Resize(out_dim); |
| 61 | + output->mutable_data<T>(context.GetPlace()); |
| 62 | + return; |
| 63 | + } |
| 64 | + auto input_x = framework::EigenVector<InputT>::Flatten(*input); |
| 65 | + |
| 66 | + framework::Tensor input_min_t, input_max_t; |
| 67 | + auto* input_max_data = |
| 68 | + input_max_t.mutable_data<InputT>({1}, context.GetPlace()); |
| 69 | + auto* input_min_data = |
| 70 | + input_min_t.mutable_data<InputT>({1}, context.GetPlace()); |
| 71 | + |
| 72 | + auto input_max_scala = framework::EigenScalar<InputT>::From(input_max_t); |
| 73 | + auto input_min_scala = framework::EigenScalar<InputT>::From(input_min_t); |
| 74 | + |
| 75 | + auto* place = context.template device_context<DeviceContext>().eigen_device(); |
| 76 | + input_max_scala.device(*place) = input_x.maximum(); |
| 77 | + input_min_scala.device(*place) = input_x.minimum(); |
| 78 | + |
| 79 | + Tensor input_min_cpu, input_max_cpu; |
| 80 | + TensorCopySync(input_max_t, platform::CPUPlace(), &input_max_cpu); |
| 81 | + TensorCopySync(input_min_t, platform::CPUPlace(), &input_min_cpu); |
| 82 | + |
| 83 | + InputT input_min = input_min_cpu.data<InputT>()[0]; |
| 84 | + |
| 85 | + PADDLE_ENFORCE_GE( |
| 86 | + input_min, static_cast<InputT>(0), |
| 87 | + platform::errors::InvalidArgument( |
| 88 | + "The elements in input tensor must be non-negative ints")); |
| 89 | + |
| 90 | + int64_t output_size = |
| 91 | + static_cast<int64_t>(input_max_cpu.data<InputT>()[0]) + 1L; |
| 92 | + |
| 93 | + output_size = std::max(output_size, static_cast<int64_t>(minlength)); |
| 94 | + framework::DDim out_dim{output_size}; |
| 95 | + output->Resize(out_dim); |
| 96 | + |
| 97 | + bool has_weights = (weights != nullptr); |
| 98 | + |
| 99 | + const T* weights_data = has_weights ? weights->data<T>() : nullptr; |
| 100 | + |
| 101 | + auto stream = |
| 102 | + context.template device_context<platform::CUDADeviceContext>().stream(); |
| 103 | + |
| 104 | + if (!has_weights) { |
| 105 | + int64_t* output_data = output->mutable_data<int64_t>(context.GetPlace()); |
| 106 | + math::SetConstant<DeviceContext, int64_t>()( |
| 107 | + context.template device_context<DeviceContext>(), output, 0L); |
| 108 | + |
| 109 | + KernelBincount<T, InputT, int64_t><<<GET_BLOCKS(input_numel), |
| 110 | + PADDLE_CUDA_NUM_THREADS, 0, stream>>>( |
| 111 | + input_data, input_numel, has_weights, weights_data, output_data); |
| 112 | + } else { |
| 113 | + const auto& weights_type = weights->type(); |
| 114 | + |
| 115 | + if (weights_type == framework::proto::VarType::FP32) { |
| 116 | + float* output_data = output->mutable_data<float>(context.GetPlace()); |
| 117 | + math::SetConstant<DeviceContext, float>()( |
| 118 | + context.template device_context<DeviceContext>(), output, |
| 119 | + static_cast<float>(0)); |
| 120 | + |
| 121 | + KernelBincount<T, InputT, float><<<GET_BLOCKS(input_numel), |
| 122 | + PADDLE_CUDA_NUM_THREADS, 0, stream>>>( |
| 123 | + input_data, input_numel, has_weights, weights_data, output_data); |
| 124 | + } else { |
| 125 | + double* output_data = output->mutable_data<double>(context.GetPlace()); |
| 126 | + math::SetConstant<DeviceContext, double>()( |
| 127 | + context.template device_context<DeviceContext>(), output, |
| 128 | + static_cast<double>(0)); |
| 129 | + |
| 130 | + KernelBincount<T, InputT, double><<<GET_BLOCKS(input_numel), |
| 131 | + PADDLE_CUDA_NUM_THREADS, 0, stream>>>( |
| 132 | + input_data, input_numel, has_weights, weights_data, output_data); |
| 133 | + } |
| 134 | + } |
| 135 | +} |
| 136 | + |
| 137 | +template <typename DeviceContext, typename T> |
| 138 | +class BincountCUDAKernel : public framework::OpKernel<T> { |
| 139 | + public: |
| 140 | + void Compute(const framework::ExecutionContext& context) const override { |
| 141 | + const Tensor* input = context.Input<framework::Tensor>("X"); |
| 142 | + const auto& input_type = input->type(); |
| 143 | + |
| 144 | + if (input_type == framework::proto::VarType::INT32) { |
| 145 | + BincountCUDAInner<DeviceContext, T, int>(context); |
| 146 | + } else if (input_type == framework::proto::VarType::INT64) { |
| 147 | + BincountCUDAInner<DeviceContext, T, int64_t>(context); |
| 148 | + } |
| 149 | + } |
| 150 | +}; |
| 151 | + |
| 152 | +} // namespace operators |
| 153 | +} // namespace paddle |
| 154 | + |
| 155 | +namespace ops = paddle::operators; |
| 156 | +REGISTER_OP_CUDA_KERNEL( |
| 157 | + bincount, ops::BincountCUDAKernel<paddle::platform::CUDADeviceContext, int>, |
| 158 | + ops::BincountCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>, |
| 159 | + ops::BincountCUDAKernel<paddle::platform::CUDADeviceContext, float>, |
| 160 | + ops::BincountCUDAKernel<paddle::platform::CUDADeviceContext, double>); |
0 commit comments