Skip to content

Commit 40c576e

Browse files
committed
phi_multiclass_nms3
1 parent 72b65d6 commit 40c576e

File tree

10 files changed

+910
-82
lines changed

10 files changed

+910
-82
lines changed

paddle/fluid/operators/detection/multiclass_nms_op.cc

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@ limitations under the License. */
1313

1414
#include <glog/logging.h>
1515

16+
#include "paddle/fluid/framework/infershape_utils.h"
1617
#include "paddle/fluid/framework/op_registry.h"
1718
#include "paddle/fluid/operators/detection/nms_util.h"
19+
#include "paddle/phi/infermeta/ternary.h"
1820

1921
namespace paddle {
2022
namespace operators {
@@ -609,12 +611,6 @@ class MultiClassNMS3Op : public MultiClassNMS2Op {
609611
const framework::VariableNameMap& outputs,
610612
const framework::AttributeMap& attrs)
611613
: MultiClassNMS2Op(type, inputs, outputs, attrs) {}
612-
613-
void InferShape(framework::InferShapeContext* ctx) const override {
614-
MultiClassNMS2Op::InferShape(ctx);
615-
616-
ctx->SetOutputDim("NmsRoisNum", {-1});
617-
}
618614
};
619615

620616
class MultiClassNMS3OpMaker : public MultiClassNMS2OpMaker {
@@ -633,6 +629,10 @@ class MultiClassNMS3OpMaker : public MultiClassNMS2OpMaker {
633629
} // namespace operators
634630
} // namespace paddle
635631

632+
DECLARE_INFER_SHAPE_FUNCTOR(multiclass_nms3,
633+
MultiClassNMSShapeFunctor,
634+
PD_INFER_META(phi::MultiClassNMSInferMeta));
635+
636636
namespace ops = paddle::operators;
637637
REGISTER_OPERATOR(
638638
multiclass_nms,
@@ -658,7 +658,5 @@ REGISTER_OPERATOR(
658658
ops::MultiClassNMS3Op,
659659
ops::MultiClassNMS3OpMaker,
660660
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
661-
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
662-
REGISTER_OP_CPU_KERNEL(multiclass_nms3,
663-
ops::MultiClassNMSKernel<float>,
664-
ops::MultiClassNMSKernel<double>);
661+
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
662+
MultiClassNMSShapeFunctor);

paddle/phi/api/yaml/legacy_api.yaml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1592,6 +1592,15 @@
15921592
func : multi_dot
15931593
backward : multi_dot_grad
15941594

1595+
- api : multiclass_nms3
1596+
args : (Tensor bboxes, Tensor scores, Tensor rois_num, float score_threshold, int nms_top_k, int keep_top_k, float nms_threshold=0.3, bool normalized=true, float nms_eta=1.0, int background_label=0)
1597+
output : Tensor(out), Tensor(index), Tensor(nms_rois_num)
1598+
infer_meta :
1599+
func : MultiClassNMSInferMeta
1600+
kernel :
1601+
func : multiclass_nms3
1602+
optional : rois_num
1603+
15951604
# multinomial
15961605
- api : multinomial
15971606
args : (Tensor x, int num_samples, bool replacement)

paddle/phi/infermeta/ternary.cc

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -743,6 +743,99 @@ void LinspaceInferMeta(const MetaTensor& start,
743743
LinspaceRawInferMeta(start, stop, number, out);
744744
}
745745

746+
void MultiClassNMSInferMeta(const MetaTensor& bboxes,
747+
const MetaTensor& scores,
748+
const MetaTensor& rois_num,
749+
float score_threshold,
750+
int nms_top_k,
751+
int keep_top_k,
752+
float nms_threshold,
753+
bool normalized,
754+
float nms_eta,
755+
int background_label,
756+
MetaTensor* out,
757+
MetaTensor* index,
758+
MetaTensor* nms_rois_num,
759+
MetaConfig config) {
760+
auto box_dims = bboxes.dims();
761+
auto score_dims = scores.dims();
762+
auto score_size = score_dims.size();
763+
764+
if (config.is_runtime) {
765+
PADDLE_ENFORCE_EQ(
766+
score_size == 2 || score_size == 3,
767+
true,
768+
errors::InvalidArgument("The rank of Input(Scores) must be 2 or 3"
769+
". But received rank = %d",
770+
score_size));
771+
PADDLE_ENFORCE_EQ(
772+
box_dims.size(),
773+
3,
774+
errors::InvalidArgument("The rank of Input(BBoxes) must be 3"
775+
". But received rank = %d",
776+
box_dims.size()));
777+
if (score_size == 3) {
778+
PADDLE_ENFORCE_EQ(box_dims[2] == 4 || box_dims[2] == 8 ||
779+
box_dims[2] == 16 || box_dims[2] == 24 ||
780+
box_dims[2] == 32,
781+
true,
782+
errors::InvalidArgument(
783+
"The last dimension of Input"
784+
"(BBoxes) must be 4 or 8, "
785+
"represents the layout of coordinate "
786+
"[xmin, ymin, xmax, ymax] or "
787+
"4 points: [x1, y1, x2, y2, x3, y3, x4, y4] or "
788+
"8 points: [xi, yi] i= 1,2,...,8 or "
789+
"12 points: [xi, yi] i= 1,2,...,12 or "
790+
"16 points: [xi, yi] i= 1,2,...,16"));
791+
PADDLE_ENFORCE_EQ(
792+
box_dims[1],
793+
score_dims[2],
794+
errors::InvalidArgument(
795+
"The 2nd dimension of Input(BBoxes) must be equal to "
796+
"last dimension of Input(Scores), which represents the "
797+
"predicted bboxes."
798+
"But received box_dims[1](%s) != socre_dims[2](%s)",
799+
box_dims[1],
800+
score_dims[2]));
801+
} else {
802+
PADDLE_ENFORCE_EQ(box_dims[2],
803+
4,
804+
errors::InvalidArgument(
805+
"The last dimension of Input"
806+
"(BBoxes) must be 4. But received dimension = %d",
807+
box_dims[2]));
808+
PADDLE_ENFORCE_EQ(
809+
box_dims[1],
810+
score_dims[1],
811+
errors::InvalidArgument(
812+
"The 2nd dimension of Input"
813+
"(BBoxes) must be equal to the 2nd dimension of Input(Scores). "
814+
"But received box dimension = %d, score dimension = %d",
815+
box_dims[1],
816+
score_dims[1]));
817+
}
818+
}
819+
PADDLE_ENFORCE_NE(out,
820+
nullptr,
821+
errors::InvalidArgument(
822+
"The out in MultiClassNMSInferMeta can't be nullptr."));
823+
PADDLE_ENFORCE_NE(
824+
index,
825+
nullptr,
826+
errors::InvalidArgument(
827+
"The index in MultiClassNMSInferMeta can't be nullptr."));
828+
// Here the box_dims[0] is not the real dimension of output.
829+
// It will be rewritten in the computing kernel.
830+
831+
out->set_dims(phi::make_ddim({-1, box_dims[2] + 2}));
832+
out->set_dtype(bboxes.dtype());
833+
index->set_dims(phi::make_ddim({-1, box_dims[2] + 2}));
834+
index->set_dtype(DataType::INT32);
835+
nms_rois_num->set_dims(phi::make_ddim({-1}));
836+
nms_rois_num->set_dtype(DataType::INT32);
837+
}
838+
746839
void NllLossRawInferMeta(const MetaTensor& input,
747840
const MetaTensor& label,
748841
const MetaTensor& weight,

paddle/phi/infermeta/ternary.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,21 @@ void LinspaceInferMeta(const MetaTensor& start,
123123
DataType dtype,
124124
MetaTensor* out);
125125

126+
// Infer-meta function for the multiclass_nms3 op: fills in the dims and dtype
// of the three output tensors `out`, `index` and `nms_rois_num` from the
// `bboxes` / `scores` inputs. `out` and `index` must be non-null.
// `config` is optional and defaults to a default-constructed MetaConfig; the
// definition only performs its input shape checks when `config.is_runtime`
// is set.
void MultiClassNMSInferMeta(const MetaTensor& bboxes,
127+
const MetaTensor& scores,
128+
const MetaTensor& rois_num,
129+
float score_threshold,
130+
int nms_top_k,
131+
int keep_top_k,
132+
float nms_threshold,
133+
bool normalized,
134+
float nms_eta,
135+
int background_label,
136+
MetaTensor* out,
137+
MetaTensor* index,
138+
MetaTensor* nms_rois_num,
139+
MetaConfig config = MetaConfig());
140+
126141
void NllLossRawInferMeta(const MetaTensor& input,
127142
const MetaTensor& label,
128143
const MetaTensor& weight,

paddle/phi/kernels/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ set(COMMON_KERNEL_DEPS
8080
lod_utils
8181
custom_kernel
8282
string_infermeta
83-
utf8proc)
83+
utf8proc
84+
gpc)
8485

8586
copy_if_different(${kernel_declare_file} ${kernel_declare_file_final})
8687

0 commit comments

Comments
 (0)