Skip to content

Commit b7d31df

Browse files
sefiraqingqing01
authored and committed
FasterRCNN Anchor Generator Op (PaddlePaddle#11218)
* Add anchor generator operator for Faster-RCNN. * Add unittest testing. * Add Python API.
1 parent 1e96642 commit b7d31df

File tree

7 files changed

+618
-0
lines changed

7 files changed

+618
-0
lines changed

paddle/fluid/operators/detection/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ iou_similarity_op.cu)
2222
detection_library(mine_hard_examples_op SRCS mine_hard_examples_op.cc)
2323
detection_library(multiclass_nms_op SRCS multiclass_nms_op.cc)
2424
detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op.cu)
25+
detection_library(anchor_generator_op SRCS anchor_generator_op.cc
26+
anchor_generator_op.cu)
2527
detection_library(target_assign_op SRCS target_assign_op.cc
2628
target_assign_op.cu)
2729
detection_library(polygon_box_transform_op SRCS polygon_box_transform_op.cc
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/detection/anchor_generator_op.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
20+
class AnchorGeneratorOp : public framework::OperatorWithKernel {
21+
public:
22+
using framework::OperatorWithKernel::OperatorWithKernel;
23+
24+
void InferShape(framework::InferShapeContext* ctx) const override {
25+
PADDLE_ENFORCE(ctx->HasInput("Input"),
26+
"Input(Input) of AnchorGeneratorOp should not be null.");
27+
PADDLE_ENFORCE(ctx->HasOutput("Anchors"),
28+
"Output(Anchors) of AnchorGeneratorOp should not be null.");
29+
PADDLE_ENFORCE(
30+
ctx->HasOutput("Variances"),
31+
"Output(Variances) of AnchorGeneratorOp should not be null.");
32+
33+
auto input_dims = ctx->GetInputDim("Input");
34+
PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");
35+
36+
auto anchor_sizes = ctx->Attrs().Get<std::vector<float>>("anchor_sizes");
37+
auto aspect_ratios = ctx->Attrs().Get<std::vector<float>>("aspect_ratios");
38+
auto stride = ctx->Attrs().Get<std::vector<float>>("stride");
39+
auto variances = ctx->Attrs().Get<std::vector<float>>("variances");
40+
41+
size_t num_anchors = aspect_ratios.size() * anchor_sizes.size();
42+
43+
std::vector<int64_t> dim_vec(4);
44+
dim_vec[0] = input_dims[2];
45+
dim_vec[1] = input_dims[3];
46+
dim_vec[2] = num_anchors;
47+
dim_vec[3] = 4;
48+
ctx->SetOutputDim("Anchors", framework::make_ddim(dim_vec));
49+
ctx->SetOutputDim("Variances", framework::make_ddim(dim_vec));
50+
}
51+
52+
protected:
53+
framework::OpKernelType GetExpectedKernelType(
54+
const framework::ExecutionContext& ctx) const override {
55+
return framework::OpKernelType(
56+
framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()),
57+
ctx.device_context());
58+
}
59+
};
60+
61+
// Declares the inputs, outputs, attributes and user-facing documentation of
// the anchor_generator operator.
class AnchorGeneratorOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Input",
             "(Tensor, default Tensor<float>), "
             "the input feature is a tensor with a rank of 4. "
             "The layout is NCHW.");
    AddOutput("Anchors",
              "(Tensor, default Tensor<float>), the output is a "
              "tensor with a rank of 4. The layout is [H, W, num_anchors, 4]. "
              "H is the height of input, W is the width of input, num_anchors "
              "is the box count of each position. "
              "Each anchor is in (xmin, ymin, xmax, ymax) format");
    AddOutput("Variances",
              "(Tensor, default Tensor<float>), the expanded variances for "
              "normalizing bbox regression targets. The layout is [H, W, "
              "num_anchors, 4]. "
              "H is the height of input, W is the width of input, num_anchors "
              "is the box count of each position. "
              "Each variance is in (xcenter, ycenter, w, h) format");

    AddAttr<std::vector<float>>(
        "anchor_sizes",
        "(vector<float>) List of Region Proposal Network(RPN) anchor sizes "
        " given in absolute pixels e.g. (64, 128, 256, 512)."
        " For instance, the anchor size of 64 means the area of this anchor "
        "equals to 64**2.")
        .AddCustomChecker([](const std::vector<float>& anchor_sizes) {
          PADDLE_ENFORCE_GT(anchor_sizes.size(), 0,
                            "Size of anchor_sizes must be at least 1.");
          for (size_t i = 0; i < anchor_sizes.size(); ++i) {
            PADDLE_ENFORCE_GT(anchor_sizes[i], 0.0,
                              "anchor_sizes[%d] must be positive.", i);
          }
        });
    AddAttr<std::vector<float>>(
        "aspect_ratios",
        "(vector<float>) List of Region Proposal Network(RPN) anchor aspect "
        "ratios, e.g. (0.5, 1, 2). "
        "For instance, the aspect ratio of 0.5 means the height / width of "
        "this anchor equals 0.5.")
        // The kernels divide by each aspect ratio, so an empty list or a
        // non-positive entry cannot yield meaningful anchors; reject them up
        // front, mirroring the anchor_sizes checker.
        .AddCustomChecker([](const std::vector<float>& aspect_ratios) {
          PADDLE_ENFORCE_GT(aspect_ratios.size(), 0,
                            "Size of aspect_ratios must be at least 1.");
          for (size_t i = 0; i < aspect_ratios.size(); ++i) {
            PADDLE_ENFORCE_GT(aspect_ratios[i], 0.0,
                              "aspect_ratios[%d] must be positive.", i);
          }
        });

    AddAttr<std::vector<float>>("variances",
                                "(vector<float>) List of variances to be used "
                                "in box regression deltas")
        .AddCustomChecker([](const std::vector<float>& variances) {
          PADDLE_ENFORCE_EQ(variances.size(), 4,
                            "Must and only provide 4 variance.");
          for (size_t i = 0; i < variances.size(); ++i) {
            PADDLE_ENFORCE_GT(variances[i], 0.0,
                              "variance[%d] must be greater than 0.", i);
          }
        });

    AddAttr<std::vector<float>>("stride",
                                "Anchors stride across width and height, "
                                "with a default of (16, 16)")
        .SetDefault(std::vector<float>(2, 16.0))
        .AddCustomChecker([](const std::vector<float>& stride) {
          PADDLE_ENFORCE_EQ(
              stride.size(), 2,
              "Must and only provide 2 stride for width and height.");
          for (size_t i = 0; i < stride.size(); ++i) {
            PADDLE_ENFORCE_GT(stride[i], 0.0,
                              "stride[%d] should be larger than 0.", i);
          }
        });

    AddAttr<float>("offset",
                   "(float) "
                   "Anchor center offset, with a default of 0.5")
        .SetDefault(0.5);
    AddComment(R"DOC(
AnchorGenerator operator
Generates anchors for Faster RCNN, FPN etc. algorithm.
Each position of the input produce N anchors, N =
size(anchor_sizes) * size(aspect_ratios).

Please get more information from the following papers:
https://arxiv.org/abs/1506.01497.
)DOC");
  }
};
144+
145+
} // namespace operators
146+
} // namespace paddle
147+
148+
namespace ops = paddle::operators;
149+
REGISTER_OPERATOR(anchor_generator, ops::AnchorGeneratorOp,
150+
ops::AnchorGeneratorOpMaker,
151+
paddle::framework::EmptyGradOpMaker);
152+
153+
REGISTER_OP_CPU_KERNEL(anchor_generator, ops::AnchorGeneratorOpKernel<float>,
154+
ops::AnchorGeneratorOpKernel<double>);
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/detection/anchor_generator_op.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
20+
template <typename T>
21+
__global__ void GenAnchors(T* out, const T* aspect_ratios, const int ar_num,
22+
const T* anchor_sizes, const int as_num,
23+
const T* stride, const int sd_num, const int height,
24+
const int width, const T offset) {
25+
int num_anchors = as_num * ar_num;
26+
int box_num = height * width * num_anchors;
27+
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < box_num;
28+
i += blockDim.x * gridDim.x) {
29+
int h_idx = i / (num_anchors * width);
30+
int w_idx = (i / num_anchors) % width;
31+
T stride_width = stride[0];
32+
T stride_height = stride[1];
33+
T x_ctr = (w_idx * stride_width) + offset * (stride_width - 1);
34+
T y_ctr = (h_idx * stride_height) + offset * (stride_height - 1);
35+
T area, area_ratios;
36+
T base_w, base_h;
37+
T scale_w, scale_h;
38+
T anchor_width, anchor_height;
39+
int anch_idx = i % num_anchors;
40+
int ar_idx = anch_idx / as_num;
41+
int as_idx = anch_idx % as_num;
42+
T aspect_ratio = aspect_ratios[ar_idx];
43+
T anchor_size = anchor_sizes[as_idx];
44+
area = stride_width * stride_height;
45+
area_ratios = area / aspect_ratio;
46+
base_w = round(sqrt(area_ratios));
47+
base_h = round(base_w * aspect_ratio);
48+
scale_w = anchor_size / stride_width;
49+
scale_h = anchor_size / stride_height;
50+
anchor_width = scale_w * base_w;
51+
anchor_height = scale_h * base_h;
52+
53+
T xmin = (x_ctr - 0.5 * (anchor_width - 1));
54+
T ymin = (y_ctr - 0.5 * (anchor_height - 1));
55+
T xmax = (x_ctr + 0.5 * (anchor_width - 1));
56+
T ymax = (y_ctr + 0.5 * (anchor_height - 1));
57+
out[i * 4] = xmin;
58+
out[i * 4 + 1] = ymin;
59+
out[i * 4 + 2] = xmax;
60+
out[i * 4 + 3] = ymax;
61+
}
62+
}
63+
64+
template <typename T>
65+
__global__ void SetVariance(T* out, const T* var, const int vnum,
66+
const int num) {
67+
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
68+
i += blockDim.x * gridDim.x) {
69+
out[i] = var[i % vnum];
70+
}
71+
}
72+
73+
template <typename T>
74+
class AnchorGeneratorOpCUDAKernel : public framework::OpKernel<T> {
75+
public:
76+
void Compute(const framework::ExecutionContext& ctx) const override {
77+
auto* input = ctx.Input<paddle::framework::Tensor>("Input");
78+
auto* anchors = ctx.Output<paddle::framework::Tensor>("Anchors");
79+
auto* vars = ctx.Output<paddle::framework::Tensor>("Variances");
80+
81+
auto anchor_sizes = ctx.Attr<std::vector<float>>("anchor_sizes");
82+
auto aspect_ratios = ctx.Attr<std::vector<float>>("aspect_ratios");
83+
auto stride = ctx.Attr<std::vector<float>>("stride");
84+
auto variances = ctx.Attr<std::vector<float>>("variances");
85+
86+
T offset = static_cast<T>(ctx.Attr<float>("offset"));
87+
88+
auto width = input->dims()[3];
89+
auto height = input->dims()[2];
90+
91+
int num_anchors = aspect_ratios.size() * anchor_sizes.size();
92+
93+
int box_num = width * height * num_anchors;
94+
95+
int block = 512;
96+
int grid = (box_num + block - 1) / block;
97+
98+
auto stream =
99+
ctx.template device_context<platform::CUDADeviceContext>().stream();
100+
101+
anchors->mutable_data<T>(ctx.GetPlace());
102+
vars->mutable_data<T>(ctx.GetPlace());
103+
104+
framework::Tensor ar;
105+
framework::TensorFromVector(aspect_ratios, ctx.device_context(), &ar);
106+
107+
framework::Tensor as;
108+
framework::TensorFromVector(anchor_sizes, ctx.device_context(), &as);
109+
110+
framework::Tensor sd;
111+
framework::TensorFromVector(stride, ctx.device_context(), &sd);
112+
113+
GenAnchors<T><<<grid, block, 0, stream>>>(
114+
anchors->data<T>(), ar.data<T>(), aspect_ratios.size(), as.data<T>(),
115+
anchor_sizes.size(), sd.data<T>(), stride.size(), height, width,
116+
offset);
117+
118+
framework::Tensor v;
119+
framework::TensorFromVector(variances, ctx.device_context(), &v);
120+
grid = (box_num * 4 + block - 1) / block;
121+
SetVariance<T><<<grid, block, 0, stream>>>(vars->data<T>(), v.data<T>(),
122+
variances.size(), box_num * 4);
123+
}
124+
}; // namespace operators
125+
126+
} // namespace operators
127+
} // namespace paddle
128+
129+
namespace ops = paddle::operators;
130+
REGISTER_OP_CUDA_KERNEL(anchor_generator,
131+
ops::AnchorGeneratorOpCUDAKernel<float>,
132+
ops::AnchorGeneratorOpCUDAKernel<double>);

0 commit comments

Comments
 (0)