Skip to content

Commit ceb9c52

Browse files
committed
Add bincount op (PaddlePaddle#36317)
* Add bincount op * upload cpu version * fix unitest * fix unittest * fix unittest * fix en doc * add more test * fix en doc * add more test case * fix test * fix input vailidation * fix input check * fix unittest * fix test * fix en doc cherry-pick
1 parent 59615ff commit ceb9c52

File tree

8 files changed

+651
-0
lines changed

8 files changed

+651
-0
lines changed

paddle/fluid/operators/bincount_op.cc

+116
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/bincount_op.h"
16+
17+
#include <string>
18+
#include <unordered_map>
19+
#include <vector>
20+
21+
namespace paddle {
22+
namespace operators {
23+
24+
using framework::OpKernelType;
25+
using framework::Tensor;
26+
27+
class BincountOp : public framework::OperatorWithKernel {
28+
public:
29+
using framework::OperatorWithKernel::OperatorWithKernel;
30+
31+
void InferShape(framework::InferShapeContext *ctx) const override {
32+
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
33+
platform::errors::InvalidArgument(
34+
"Input(X) of BincountOp should not be null."));
35+
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
36+
platform::errors::InvalidArgument(
37+
"Output(Out) of BincountOp should not be null."));
38+
39+
auto input_dim = ctx->GetInputDim("X");
40+
auto minlength = ctx->Attrs().Get<int>("minlength");
41+
42+
PADDLE_ENFORCE_GE(minlength, 0,
43+
platform::errors::InvalidArgument(
44+
"The minlength should be greater than or equal to 0."
45+
"But received minlength is %d",
46+
minlength));
47+
48+
PADDLE_ENFORCE_EQ(input_dim.size(), 1,
49+
platform::errors::InvalidArgument(
50+
"The 'shape' of Input(X) must be 1-D tensor."
51+
"But the dimension of Input(X) is [%d]",
52+
input_dim.size()));
53+
54+
if (ctx->HasInput("Weights")) {
55+
auto weights_dim = ctx->GetInputDim("Weights");
56+
PADDLE_ENFORCE_EQ(weights_dim.size(), 1,
57+
platform::errors::InvalidArgument(
58+
"The 'shape' of Input(Weights) must be 1-D tensor."
59+
"But the dimension of Input(Weights) is [%d]",
60+
weights_dim.size()));
61+
62+
PADDLE_ENFORCE_EQ(
63+
weights_dim[0], input_dim[0],
64+
platform::errors::InvalidArgument(
65+
"The 'shape' of Input(Weights) must be equal to the 'shape' of "
66+
"Input(X)."
67+
"But received: the 'shape' of Input(Weights) is [%s],"
68+
"the 'shape' of Input(X) is [%s]",
69+
weights_dim, input_dim));
70+
}
71+
72+
ctx->SetOutputDim("Out", framework::make_ddim({-1}));
73+
ctx->ShareLoD("X", /*->*/ "Out");
74+
}
75+
76+
framework::OpKernelType GetExpectedKernelType(
77+
const framework::ExecutionContext &ctx) const {
78+
auto data_type =
79+
ctx.HasInput("Weights")
80+
? OperatorWithKernel::IndicateVarDataType(ctx, "Weights")
81+
: OperatorWithKernel::IndicateVarDataType(ctx, "X");
82+
return framework::OpKernelType(data_type, ctx.device_context());
83+
}
84+
};
85+
86+
// Declares the proto of the `bincount` operator: a required input X, an
// optional (dispensable) input Weights, the output Out and the integer
// attribute `minlength`.
class BincountOpMaker : public framework::OpProtoAndCheckerMaker {
87+
public:
88+
// Registers inputs, outputs, attributes and the operator comment.
void Make() override {
89+
AddInput("X", "(Tensor) The input tensor of Bincount op,");
90+
AddInput("Weights", "(Tensor) The weights tensor of Bincount op,")
91+
.AsDispensable();
92+
AddOutput("Out", "(Tensor) The output tensor of Bincount op,");
93+
// `minlength` defaults to 0 and is checked to be >= 0.
AddAttr<int>("minlength", "(int) The minimal numbers of bins")
94+
.SetDefault(0)
95+
.EqualGreaterThan(0);
96+
AddComment(R"DOC(
97+
Bincount Operator.
98+
Computes frequency of each value in the input tensor.
99+
Elements of input tensor should be non-negative ints.
100+
)DOC");
101+
}
102+
};
103+
104+
} // namespace operators
105+
} // namespace paddle
106+
107+
// Registers the `bincount` operator (BincountOp + BincountOpMaker) with
// EmptyGradOpMaker for both the OpDesc and the imperative OpBase paths.
namespace ops = paddle::operators;
108+
REGISTER_OPERATOR(
109+
bincount, ops::BincountOp, ops::BincountOpMaker,
110+
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
111+
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
112+
// CPU kernels are instantiated for float, double, int and int64_t.
REGISTER_OP_CPU_KERNEL(
113+
bincount, ops::BincountKernel<paddle::platform::CPUDeviceContext, float>,
114+
ops::BincountKernel<paddle::platform::CPUDeviceContext, double>,
115+
ops::BincountKernel<paddle::platform::CPUDeviceContext, int>,
116+
ops::BincountKernel<paddle::platform::CPUDeviceContext, int64_t>);

paddle/fluid/operators/bincount_op.cu

+160
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/framework/eigen.h"
16+
#include "paddle/fluid/operators/bincount_op.h"
17+
#include "paddle/fluid/platform/cuda_primitives.h"
18+
#include "paddle/fluid/platform/gpu_launch_config.h"
19+
#include "paddle/fluid/platform/hostdevice.h"
20+
21+
namespace paddle {
22+
namespace operators {
23+
24+
using Tensor = framework::Tensor;
25+
using platform::PADDLE_CUDA_NUM_THREADS;
26+
27+
inline int GET_BLOCKS(const int N) {
28+
return (N + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS;
29+
}
30+
31+
template <typename T, typename InputT, typename OutT>
32+
__global__ void KernelBincount(const InputT* input, const int total_elements,
33+
const bool has_weights, const T* weights,
34+
OutT* output) {
35+
if (!has_weights) {
36+
for (int i = threadIdx.x; i < total_elements; i += blockDim.x) {
37+
paddle::platform::CudaAtomicAdd(&output[input[i]], 1L);
38+
}
39+
} else {
40+
for (int i = threadIdx.x; i < total_elements; i += blockDim.x) {
41+
paddle::platform::CudaAtomicAdd(&output[input[i]],
42+
static_cast<OutT>(weights[i]));
43+
}
44+
}
45+
}
46+
47+
template <typename DeviceContext, typename T, typename InputT>
48+
void BincountCUDAInner(const framework::ExecutionContext& context) {
49+
const Tensor* input = context.Input<framework::Tensor>("X");
50+
const Tensor* weights = context.Input<framework::Tensor>("Weights");
51+
Tensor* output = context.Output<framework::Tensor>("Out");
52+
auto& minlength = context.Attr<int>("minlength");
53+
54+
const InputT* input_data = input->data<InputT>();
55+
56+
const int input_numel = input->numel();
57+
58+
if (input_data == nullptr) {
59+
framework::DDim out_dim{0};
60+
output->Resize(out_dim);
61+
output->mutable_data<T>(context.GetPlace());
62+
return;
63+
}
64+
auto input_x = framework::EigenVector<InputT>::Flatten(*input);
65+
66+
framework::Tensor input_min_t, input_max_t;
67+
auto* input_max_data =
68+
input_max_t.mutable_data<InputT>({1}, context.GetPlace());
69+
auto* input_min_data =
70+
input_min_t.mutable_data<InputT>({1}, context.GetPlace());
71+
72+
auto input_max_scala = framework::EigenScalar<InputT>::From(input_max_t);
73+
auto input_min_scala = framework::EigenScalar<InputT>::From(input_min_t);
74+
75+
auto* place = context.template device_context<DeviceContext>().eigen_device();
76+
input_max_scala.device(*place) = input_x.maximum();
77+
input_min_scala.device(*place) = input_x.minimum();
78+
79+
Tensor input_min_cpu, input_max_cpu;
80+
TensorCopySync(input_max_t, platform::CPUPlace(), &input_max_cpu);
81+
TensorCopySync(input_min_t, platform::CPUPlace(), &input_min_cpu);
82+
83+
InputT input_min = input_min_cpu.data<InputT>()[0];
84+
85+
PADDLE_ENFORCE_GE(
86+
input_min, static_cast<InputT>(0),
87+
platform::errors::InvalidArgument(
88+
"The elements in input tensor must be non-negative ints"));
89+
90+
int64_t output_size =
91+
static_cast<int64_t>(input_max_cpu.data<InputT>()[0]) + 1L;
92+
93+
output_size = std::max(output_size, static_cast<int64_t>(minlength));
94+
framework::DDim out_dim{output_size};
95+
output->Resize(out_dim);
96+
97+
bool has_weights = (weights != nullptr);
98+
99+
const T* weights_data = has_weights ? weights->data<T>() : nullptr;
100+
101+
auto stream =
102+
context.template device_context<platform::CUDADeviceContext>().stream();
103+
104+
if (!has_weights) {
105+
int64_t* output_data = output->mutable_data<int64_t>(context.GetPlace());
106+
math::SetConstant<DeviceContext, int64_t>()(
107+
context.template device_context<DeviceContext>(), output, 0L);
108+
109+
KernelBincount<T, InputT, int64_t><<<GET_BLOCKS(input_numel),
110+
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
111+
input_data, input_numel, has_weights, weights_data, output_data);
112+
} else {
113+
const auto& weights_type = weights->type();
114+
115+
if (weights_type == framework::proto::VarType::FP32) {
116+
float* output_data = output->mutable_data<float>(context.GetPlace());
117+
math::SetConstant<DeviceContext, float>()(
118+
context.template device_context<DeviceContext>(), output,
119+
static_cast<float>(0));
120+
121+
KernelBincount<T, InputT, float><<<GET_BLOCKS(input_numel),
122+
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
123+
input_data, input_numel, has_weights, weights_data, output_data);
124+
} else {
125+
double* output_data = output->mutable_data<double>(context.GetPlace());
126+
math::SetConstant<DeviceContext, double>()(
127+
context.template device_context<DeviceContext>(), output,
128+
static_cast<double>(0));
129+
130+
KernelBincount<T, InputT, double><<<GET_BLOCKS(input_numel),
131+
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
132+
input_data, input_numel, has_weights, weights_data, output_data);
133+
}
134+
}
135+
}
136+
137+
template <typename DeviceContext, typename T>
138+
class BincountCUDAKernel : public framework::OpKernel<T> {
139+
public:
140+
void Compute(const framework::ExecutionContext& context) const override {
141+
const Tensor* input = context.Input<framework::Tensor>("X");
142+
const auto& input_type = input->type();
143+
144+
if (input_type == framework::proto::VarType::INT32) {
145+
BincountCUDAInner<DeviceContext, T, int>(context);
146+
} else if (input_type == framework::proto::VarType::INT64) {
147+
BincountCUDAInner<DeviceContext, T, int64_t>(context);
148+
}
149+
}
150+
};
151+
152+
} // namespace operators
153+
} // namespace paddle
154+
155+
// Registers the CUDA kernels of `bincount`, instantiated for int, int64_t,
// float and double.
namespace ops = paddle::operators;
156+
REGISTER_OP_CUDA_KERNEL(
157+
bincount, ops::BincountCUDAKernel<paddle::platform::CUDADeviceContext, int>,
158+
ops::BincountCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>,
159+
ops::BincountCUDAKernel<paddle::platform::CUDADeviceContext, float>,
160+
ops::BincountCUDAKernel<paddle::platform::CUDADeviceContext, double>);

paddle/fluid/operators/bincount_op.h

+109
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include <algorithm>
18+
19+
#include "paddle/fluid/framework/op_registry.h"
20+
#include "paddle/fluid/framework/operator.h"
21+
#include "paddle/fluid/operators/math/math_function.h"
22+
23+
namespace paddle {
24+
namespace operators {
25+
26+
using Tensor = framework::Tensor;
27+
28+
template <typename DeviceContext, typename T, typename InputT>
29+
void BincountInner(const framework::ExecutionContext& context) {
30+
const Tensor* input = context.Input<framework::Tensor>("X");
31+
const Tensor* weights = context.Input<framework::Tensor>("Weights");
32+
Tensor* output = context.Output<framework::Tensor>("Out");
33+
auto& minlength = context.Attr<int>("minlength");
34+
35+
const InputT* input_data = input->data<InputT>();
36+
37+
auto input_numel = input->numel();
38+
39+
if (input_data == nullptr) {
40+
framework::DDim out_dim{0};
41+
output->Resize(out_dim);
42+
output->mutable_data<InputT>(context.GetPlace());
43+
return;
44+
}
45+
46+
PADDLE_ENFORCE_GE(
47+
*std::min_element(input_data, input_data + input_numel),
48+
static_cast<InputT>(0),
49+
platform::errors::InvalidArgument(
50+
"The elements in input tensor must be non-negative ints"));
51+
52+
int64_t output_size = static_cast<int64_t>(*std::max_element(
53+
input_data, input_data + input_numel)) +
54+
1L;
55+
output_size = std::max(output_size, static_cast<int64_t>(minlength));
56+
57+
framework::DDim out_dim{output_size};
58+
output->Resize(out_dim);
59+
60+
bool has_weights = (weights != nullptr);
61+
62+
if (has_weights) {
63+
const T* weights_data = weights->data<T>();
64+
const auto& weights_type = weights->type();
65+
if (weights_type == framework::proto::VarType::FP32) {
66+
float* output_data = output->mutable_data<float>(context.GetPlace());
67+
math::SetConstant<DeviceContext, float>()(
68+
context.template device_context<DeviceContext>(), output,
69+
static_cast<float>(0));
70+
for (int64_t i = 0; i < input_numel; i++) {
71+
output_data[input_data[i]] += static_cast<float>(weights_data[i]);
72+
}
73+
} else {
74+
double* output_data = output->mutable_data<double>(context.GetPlace());
75+
math::SetConstant<DeviceContext, double>()(
76+
context.template device_context<DeviceContext>(), output,
77+
static_cast<double>(0));
78+
for (int64_t i = 0; i < input_numel; i++) {
79+
output_data[input_data[i]] += static_cast<double>(weights_data[i]);
80+
}
81+
}
82+
83+
} else {
84+
int64_t* output_data = output->mutable_data<int64_t>(context.GetPlace());
85+
math::SetConstant<DeviceContext, int64_t>()(
86+
context.template device_context<DeviceContext>(), output, 0L);
87+
for (int64_t i = 0; i < input_numel; i++) {
88+
output_data[input_data[i]] += 1L;
89+
}
90+
}
91+
}
92+
93+
template <typename DeviceContext, typename T>
94+
class BincountKernel : public framework::OpKernel<T> {
95+
public:
96+
void Compute(const framework::ExecutionContext& context) const override {
97+
const Tensor* input = context.Input<framework::Tensor>("X");
98+
const auto& input_type = input->type();
99+
100+
if (input_type == framework::proto::VarType::INT32) {
101+
BincountInner<DeviceContext, T, int>(context);
102+
} else if (input_type == framework::proto::VarType::INT64) {
103+
BincountInner<DeviceContext, T, int64_t>(context);
104+
}
105+
}
106+
};
107+
108+
} // namespace operators
109+
} // namespace paddle

paddle/fluid/pybind/op_function_generator.cc

+4
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@
4040
// need to manually specify them in this map.
4141
std::map<std::string, std::set<std::string>> op_ins_map = {
4242
{"layer_norm", {"X", "Scale", "Bias"}},
43+
{"bincount", {"X", "Weights"}},
44+
{"fused_attention",
45+
{"X", "LnScale", "LnBias", "QKVW", "QKVBias", "SrcMask", "OutLinearW",
46+
"OutLinearBias", "Ln2Scale", "Ln2Bias"}},
4347
{"instance_norm", {"X", "Scale", "Bias"}},
4448
{"gru_unit", {"Input", "HiddenPrev", "Weight", "Bias"}},
4549
{"label_smooth", {"X", "PriorDist"}},

0 commit comments

Comments
 (0)