Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions paddle/fluid/operators/detection/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,11 @@ function(detection_library TARGET_NAME)
endfunction()

if(WITH_ASCEND_CL)
detection_library(box_coder_op SRCS box_coder_op.cc box_coder_op.cu
box_coder_op_npu.cc)
detection_library(box_coder_op SRCS box_coder_op.cc box_coder_op_npu.cc)
detection_library(density_prior_box_op SRCS density_prior_box_op.cc
density_prior_box_op.cu density_prior_box_op_npu.cc)
else()
detection_library(box_coder_op SRCS box_coder_op.cc box_coder_op.cu)
detection_library(box_coder_op SRCS box_coder_op.cc)
detection_library(density_prior_box_op SRCS density_prior_box_op.cc
density_prior_box_op.cu)
endif()
Expand Down
137 changes: 12 additions & 125 deletions paddle/fluid/operators/detection/box_coder_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,135 +9,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/detection/box_coder_op.h"

#include <vector>

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/ternary.h"

namespace paddle {
namespace operators {

class BoxCoderOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("PriorBox"),
true,
platform::errors::NotFound(
"Input(PriorBox) of BoxCoder operator is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("TargetBox"),
true,
platform::errors::NotFound(
"Input(TargetBox) of BoxCoder operator is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("OutputBox"),
true,
platform::errors::NotFound(
"Output(OutputBox) of BoxCoder operator is not found."));

auto prior_box_dims = ctx->GetInputDim("PriorBox");
auto target_box_dims = ctx->GetInputDim("TargetBox");

if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(prior_box_dims.size(),
2,
platform::errors::InvalidArgument(
"The rank of Input PriorBox in BoxCoder operator "
"must be 2. But received rank = %d",
prior_box_dims.size()));
PADDLE_ENFORCE_EQ(prior_box_dims[1],
4,
platform::errors::InvalidArgument(
"The second dimension of PriorBox in BoxCoder "
"operator must be 4. But received dimension = %d",
prior_box_dims[1]));
if (ctx->HasInput("PriorBoxVar")) {
auto prior_box_var_dims = ctx->GetInputDim("PriorBoxVar");
PADDLE_ENFORCE_EQ(
prior_box_var_dims.size(),
2,
platform::errors::InvalidArgument(
"The rank of Input(PriorBoxVar) in BoxCoder operator"
" should be 2. But received rank = %d",
prior_box_var_dims.size()));
PADDLE_ENFORCE_EQ(
prior_box_dims,
prior_box_var_dims,
platform::errors::InvalidArgument(
"The dimension of Input(PriorBoxVar) should be equal to"
"the dimension of Input(PriorBox) in BoxCoder operator "
"when the rank is 2."));
}
}

auto code_type = GetBoxCodeType(ctx->Attrs().Get<std::string>("code_type"));
int axis = ctx->Attrs().Get<int>("axis");
if (code_type == BoxCodeType::kEncodeCenterSize) {
PADDLE_ENFORCE_EQ(target_box_dims.size(),
2,
platform::errors::InvalidArgument(
"The rank of Input TargetBox in BoxCoder operator "
"must be 2. But received rank is %d",
target_box_dims.size()));
PADDLE_ENFORCE_EQ(target_box_dims[1],
4,
platform::errors::InvalidArgument(
"The second dimension of TargetBox in BoxCoder "
"operator is 4. But received dimension is %d",
target_box_dims[1]));
ctx->SetOutputDim(
"OutputBox",
phi::make_ddim({target_box_dims[0], prior_box_dims[0], 4}));
} else if (code_type == BoxCodeType::kDecodeCenterSize) {
PADDLE_ENFORCE_EQ(target_box_dims.size(),
3,
platform::errors::InvalidArgument(
"The rank of Input TargetBox in BoxCoder "
"operator must be 3. But received rank is %d",
target_box_dims.size()));
PADDLE_ENFORCE_EQ(axis == 0 || axis == 1,
true,
platform::errors::InvalidArgument(
"axis in BoxCoder operator must be 0 or 1."
"But received axis = %d",
axis));
if (ctx->IsRuntime()) {
if (axis == 0) {
PADDLE_ENFORCE_EQ(
target_box_dims[1],
prior_box_dims[0],
platform::errors::InvalidArgument(
"When axis is 0, The second "
"dimension of TargetBox in BoxCoder "
"should be equal to the first dimension of PriorBox."));
} else if (axis == 1) {
PADDLE_ENFORCE_EQ(
target_box_dims[0],
prior_box_dims[0],
platform::errors::InvalidArgument(
"When axis is 1, The first "
"dimension of TargetBox in BoxCoder "
"should be equal to the first dimension of PriorBox."));
}
PADDLE_ENFORCE_EQ(target_box_dims[2],
prior_box_dims[1],
platform::errors::InvalidArgument(
"The third dimension of TargetBox"
" in BoxCoder should be equal to the "
"second dimension of PriorBox."));
}
ctx->ShareDim("TargetBox", /*->*/ "OutputBox");
}

if (code_type == BoxCodeType::kDecodeCenterSize && axis == 1) {
ctx->ShareLoD("PriorBox", /*->*/ "OutputBox");
} else {
ctx->ShareLoD("TargetBox", /*->*/ "OutputBox");
}
}
};

class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker {
Expand Down Expand Up @@ -245,12 +129,15 @@ box will broadcast to target box along the assigned axis.
} // namespace paddle

namespace ops = paddle::operators;

DECLARE_INFER_SHAPE_FUNCTOR(box_coder,
BoxCoderInferShapeFunctor,
PD_INFER_META(phi::BoxCoderInferMeta));

REGISTER_OPERATOR(
box_coder,
ops::BoxCoderOp,
ops::BoxCoderOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(box_coder,
ops::BoxCoderKernel<phi::CPUContext, float>,
ops::BoxCoderKernel<phi::CPUContext, double>);
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
BoxCoderInferShapeFunctor);
Loading