-
Notifications
You must be signed in to change notification settings - Fork 5.8k
FasterRCNN Anchor Generator Op #11218
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
19181e1
79d286d
456bb2c
3ae6208
c1b831d
c831787
1ddb1d7
8dedcec
be9a82e
6394330
4b04afd
8ac7b0a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/fluid/operators/detection/anchor_generator_op.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
class AnchorGeneratorOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
void InferShape(framework::InferShapeContext* ctx) const override { | ||
PADDLE_ENFORCE(ctx->HasInput("Input"), | ||
"Input(Input) of AnchorGeneratorOp should not be null."); | ||
PADDLE_ENFORCE(ctx->HasOutput("Anchors"), | ||
"Output(Anchors) of AnchorGeneratorOp should not be null."); | ||
PADDLE_ENFORCE( | ||
ctx->HasOutput("Variances"), | ||
"Output(Variances) of AnchorGeneratorOp should not be null."); | ||
|
||
auto input_dims = ctx->GetInputDim("Input"); | ||
PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW."); | ||
|
||
auto anchor_sizes = ctx->Attrs().Get<std::vector<float>>("anchor_sizes"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. auto --> auto& |
||
auto aspect_ratios = ctx->Attrs().Get<std::vector<float>>("aspect_ratios"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. auto --> auto& |
||
auto stride = ctx->Attrs().Get<std::vector<float>>("stride"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. auto --> auto& |
||
auto variances = ctx->Attrs().Get<std::vector<float>>("variances"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. auto --> auto& |
||
|
||
size_t num_anchors = aspect_ratios.size() * anchor_sizes.size(); | ||
|
||
std::vector<int64_t> dim_vec(4); | ||
dim_vec[0] = input_dims[2]; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Check |
||
dim_vec[1] = input_dims[3]; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Check |
||
dim_vec[2] = num_anchors; | ||
dim_vec[3] = 4; | ||
ctx->SetOutputDim("Anchors", framework::make_ddim(dim_vec)); | ||
ctx->SetOutputDim("Variances", framework::make_ddim(dim_vec)); | ||
} | ||
|
||
protected: | ||
framework::OpKernelType GetExpectedKernelType( | ||
const framework::ExecutionContext& ctx) const override { | ||
return framework::OpKernelType( | ||
framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()), | ||
ctx.device_context()); | ||
} | ||
}; | ||
|
||
class AnchorGeneratorOpMaker : public framework::OpProtoAndCheckerMaker { | ||
public: | ||
void Make() override { | ||
AddInput("Input", | ||
"(Tensor, default Tensor<float>), " | ||
"the input feature is a tensor with a rank of 4. " | ||
"The layout is NCHW."); | ||
AddOutput("Anchors", | ||
"(Tensor, default Tensor<float>), the output is a " | ||
"tensor with a rank of 4. The layout is [H, W, num_anchors, 4]. " | ||
"H is the height of input, W is the width of input, num_anchors " | ||
"is the box count of each position. " | ||
"Each anchor is in (xmin, ymin, xmax, ymax) format"); | ||
AddOutput("Variances", | ||
"(Tensor, default Tensor<float>), the expanded variances for " | ||
"normalizing bbox regression targets. The layout is [H, W, " | ||
"num_anchors, 4]. " | ||
"H is the height of input, W is the width of input, num_anchors " | ||
"is the box count of each position. " | ||
"Each variance is in (xcenter, ycenter, w, h) format"); | ||
|
||
AddAttr<std::vector<float>>( | ||
"anchor_sizes", | ||
"(vector<float>) List of Region Proposal Network(RPN) anchor sizes " | ||
" given in absolute pixels e.g. (64, 128, 256, 512)." | ||
" For instance, the anchor size of 64 means the area of this anchor " | ||
"equals to 64**2.") | ||
.AddCustomChecker([](const std::vector<float>& anchor_sizes) { | ||
PADDLE_ENFORCE_GT(anchor_sizes.size(), 0, | ||
"Size of anchor_sizes must be at least 1."); | ||
for (size_t i = 0; i < anchor_sizes.size(); ++i) { | ||
PADDLE_ENFORCE_GT(anchor_sizes[i], 0.0, | ||
"anchor_sizes[%d] must be positive.", i); | ||
} | ||
}); | ||
AddAttr<std::vector<float>>( | ||
"aspect_ratios", | ||
"(vector<float>) List of Region Proposal Network(RPN) anchor aspect " | ||
"ratios, e.g. (0.5, 1, 2)." | ||
"For instacne, the aspect ratio of 0.5 means the height / width of " | ||
"this anchor equals 0.5."); | ||
|
||
AddAttr<std::vector<float>>("variances", | ||
"(vector<float>) List of variances to be used " | ||
"in box regression deltas") | ||
.AddCustomChecker([](const std::vector<float>& variances) { | ||
PADDLE_ENFORCE_EQ(variances.size(), 4, | ||
"Must and only provide 4 variance."); | ||
for (size_t i = 0; i < variances.size(); ++i) { | ||
PADDLE_ENFORCE_GT(variances[i], 0.0, | ||
"variance[%d] must be greater than 0.", i); | ||
} | ||
}); | ||
|
||
AddAttr<std::vector<float>>("stride", | ||
"Anchors stride across width and height, " | ||
"with a default of (16, 16)") | ||
.SetDefault(std::vector<float>(2, 16.0)) | ||
.AddCustomChecker([](const std::vector<float>& stride) { | ||
PADDLE_ENFORCE_EQ( | ||
stride.size(), 2, | ||
"Must and only provide 2 stride for width and height."); | ||
for (size_t i = 0; i < stride.size(); ++i) { | ||
PADDLE_ENFORCE_GT(stride[i], 0.0, | ||
"stride[%d] should be larger than 0.", i); | ||
} | ||
}); | ||
|
||
AddAttr<float>("offset", | ||
"(float) " | ||
"Anchor center offset, with a default of 0.5") | ||
.SetDefault(0.5); | ||
AddComment(R"DOC( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 上面是否有默认值和Python端保持一致? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 0.5 is the default value both in python and C++ |
||
AnchorGenerator operator | ||
Generates anchors for Faster RCNN, FPN etc. algorithm. | ||
Each position of the input produce N anchors, N = | ||
size(anchor_sizes) * size(aspect_ratios). | ||
|
||
Please get more information from the following papers: | ||
https://arxiv.org/abs/1506.01497. | ||
)DOC"); | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
REGISTER_OPERATOR(anchor_generator, ops::AnchorGeneratorOp, | ||
ops::AnchorGeneratorOpMaker, | ||
paddle::framework::EmptyGradOpMaker); | ||
|
||
REGISTER_OP_CPU_KERNEL(anchor_generator, ops::AnchorGeneratorOpKernel<float>, | ||
ops::AnchorGeneratorOpKernel<double>); |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/fluid/operators/detection/anchor_generator_op.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
template <typename T> | ||
__global__ void GenAnchors(T* out, const T* aspect_ratios, const int ar_num, | ||
const T* anchor_sizes, const int as_num, | ||
const T* stride, const int sd_num, const int height, | ||
const int width, const T offset) { | ||
int num_anchors = as_num * ar_num; | ||
int box_num = height * width * num_anchors; | ||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < box_num; | ||
i += blockDim.x * gridDim.x) { | ||
int h_idx = i / (num_anchors * width); | ||
int w_idx = (i / num_anchors) % width; | ||
T stride_width = stride[0]; | ||
T stride_height = stride[1]; | ||
T x_ctr = (w_idx * stride_width) + offset * (stride_width - 1); | ||
T y_ctr = (h_idx * stride_height) + offset * (stride_height - 1); | ||
T area, area_ratios; | ||
T base_w, base_h; | ||
T scale_w, scale_h; | ||
T anchor_width, anchor_height; | ||
int anch_idx = i % num_anchors; | ||
int ar_idx = anch_idx / as_num; | ||
int as_idx = anch_idx % as_num; | ||
T aspect_ratio = aspect_ratios[ar_idx]; | ||
T anchor_size = anchor_sizes[as_idx]; | ||
area = stride_width * stride_height; | ||
area_ratios = area / aspect_ratio; | ||
base_w = round(sqrt(area_ratios)); | ||
base_h = round(base_w * aspect_ratio); | ||
scale_w = anchor_size / stride_width; | ||
scale_h = anchor_size / stride_height; | ||
anchor_width = scale_w * base_w; | ||
anchor_height = scale_h * base_h; | ||
|
||
T xmin = (x_ctr - 0.5 * (anchor_width - 1)); | ||
T ymin = (y_ctr - 0.5 * (anchor_height - 1)); | ||
T xmax = (x_ctr + 0.5 * (anchor_width - 1)); | ||
T ymax = (y_ctr + 0.5 * (anchor_height - 1)); | ||
out[i * 4] = xmin; | ||
out[i * 4 + 1] = ymin; | ||
out[i * 4 + 2] = xmax; | ||
out[i * 4 + 3] = ymax; | ||
} | ||
} | ||
|
||
template <typename T> | ||
__global__ void SetVariance(T* out, const T* var, const int vnum, | ||
const int num) { | ||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; | ||
i += blockDim.x * gridDim.x) { | ||
out[i] = var[i % vnum]; | ||
} | ||
} | ||
|
||
template <typename T> | ||
class AnchorGeneratorOpCUDAKernel : public framework::OpKernel<T> { | ||
public: | ||
void Compute(const framework::ExecutionContext& ctx) const override { | ||
auto* input = ctx.Input<paddle::framework::Tensor>("Input"); | ||
auto* anchors = ctx.Output<paddle::framework::Tensor>("Anchors"); | ||
auto* vars = ctx.Output<paddle::framework::Tensor>("Variances"); | ||
|
||
auto anchor_sizes = ctx.Attr<std::vector<float>>("anchor_sizes"); | ||
auto aspect_ratios = ctx.Attr<std::vector<float>>("aspect_ratios"); | ||
auto stride = ctx.Attr<std::vector<float>>("stride"); | ||
auto variances = ctx.Attr<std::vector<float>>("variances"); | ||
|
||
T offset = static_cast<T>(ctx.Attr<float>("offset")); | ||
|
||
auto width = input->dims()[3]; | ||
auto height = input->dims()[2]; | ||
|
||
int num_anchors = aspect_ratios.size() * anchor_sizes.size(); | ||
|
||
int box_num = width * height * num_anchors; | ||
|
||
int block = 512; | ||
int grid = (box_num + block - 1) / block; | ||
|
||
auto stream = | ||
ctx.template device_context<platform::CUDADeviceContext>().stream(); | ||
|
||
anchors->mutable_data<T>(ctx.GetPlace()); | ||
vars->mutable_data<T>(ctx.GetPlace()); | ||
|
||
framework::Tensor ar; | ||
framework::TensorFromVector(aspect_ratios, ctx.device_context(), &ar); | ||
|
||
framework::Tensor as; | ||
framework::TensorFromVector(anchor_sizes, ctx.device_context(), &as); | ||
|
||
framework::Tensor sd; | ||
framework::TensorFromVector(stride, ctx.device_context(), &sd); | ||
|
||
GenAnchors<T><<<grid, block, 0, stream>>>( | ||
anchors->data<T>(), ar.data<T>(), aspect_ratios.size(), as.data<T>(), | ||
anchor_sizes.size(), sd.data<T>(), stride.size(), height, width, | ||
offset); | ||
|
||
framework::Tensor v; | ||
framework::TensorFromVector(variances, ctx.device_context(), &v); | ||
grid = (box_num * 4 + block - 1) / block; | ||
SetVariance<T><<<grid, block, 0, stream>>>(vars->data<T>(), v.data<T>(), | ||
variances.size(), box_num * 4); | ||
} | ||
}; // namespace operators | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
REGISTER_OP_CUDA_KERNEL(anchor_generator, | ||
ops::AnchorGeneratorOpCUDAKernel<float>, | ||
ops::AnchorGeneratorOpCUDAKernel<double>); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also need to check outputs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
revised