@@ -743,6 +743,99 @@ void LinspaceInferMeta(const MetaTensor& start,
743743  LinspaceRawInferMeta (start, stop, number, out);
744744}
745745
746+ void  MultiClassNMSInferMeta (const  MetaTensor& bboxes,
747+                             const  MetaTensor& scores,
748+                             const  MetaTensor& rois_num,
749+                             float  score_threshold,
750+                             int  nms_top_k,
751+                             int  keep_top_k,
752+                             float  nms_threshold,
753+                             bool  normalized,
754+                             float  nms_eta,
755+                             int  background_label,
756+                             MetaTensor* out,
757+                             MetaTensor* index,
758+                             MetaTensor* nms_rois_num,
759+                             MetaConfig config) {
760+   auto  box_dims = bboxes.dims ();
761+   auto  score_dims = scores.dims ();
762+   auto  score_size = score_dims.size ();
763+ 
764+   if  (config.is_runtime ) {
765+     PADDLE_ENFORCE_EQ (
766+         score_size == 2  || score_size == 3 ,
767+         true ,
768+         errors::InvalidArgument (" The rank of Input(Scores) must be 2 or 3" 
769+                                 " . But received rank = %d" 
770+                                 score_size));
771+     PADDLE_ENFORCE_EQ (
772+         box_dims.size (),
773+         3 ,
774+         errors::InvalidArgument (" The rank of Input(BBoxes) must be 3" 
775+                                 " . But received rank = %d" 
776+                                 box_dims.size ()));
777+     if  (score_size == 3 ) {
778+       PADDLE_ENFORCE_EQ (box_dims[2 ] == 4  || box_dims[2 ] == 8  ||
779+                             box_dims[2 ] == 16  || box_dims[2 ] == 24  ||
780+                             box_dims[2 ] == 32 ,
781+                         true ,
782+                         errors::InvalidArgument (
783+                             " The last dimension of Input" 
784+                             " (BBoxes) must be 4 or 8, " 
785+                             " represents the layout of coordinate " 
786+                             " [xmin, ymin, xmax, ymax] or " 
787+                             " 4 points: [x1, y1, x2, y2, x3, y3, x4, y4] or " 
788+                             " 8 points: [xi, yi] i= 1,2,...,8 or " 
789+                             " 12 points: [xi, yi] i= 1,2,...,12 or " 
790+                             " 16 points: [xi, yi] i= 1,2,...,16" 
791+       PADDLE_ENFORCE_EQ (
792+           box_dims[1 ],
793+           score_dims[2 ],
794+           errors::InvalidArgument (
795+               " The 2nd dimension of Input(BBoxes) must be equal to " 
796+               " last dimension of Input(Scores), which represents the " 
797+               " predicted bboxes." 
798+               " But received box_dims[1](%s) != socre_dims[2](%s)" 
799+               box_dims[1 ],
800+               score_dims[2 ]));
801+     } else  {
802+       PADDLE_ENFORCE_EQ (box_dims[2 ],
803+                         4 ,
804+                         errors::InvalidArgument (
805+                             " The last dimension of Input" 
806+                             " (BBoxes) must be 4. But received dimension = %d" 
807+                             box_dims[2 ]));
808+       PADDLE_ENFORCE_EQ (
809+           box_dims[1 ],
810+           score_dims[1 ],
811+           errors::InvalidArgument (
812+               " The 2nd dimension of Input" 
813+               " (BBoxes) must be equal to the 2nd dimension of Input(Scores). " 
814+               " But received box dimension = %d, score dimension = %d" 
815+               box_dims[1 ],
816+               score_dims[1 ]));
817+     }
818+   }
819+   PADDLE_ENFORCE_NE (out,
820+                     nullptr ,
821+                     errors::InvalidArgument (
822+                         " The out in MultiClassNMSInferMeta can't be nullptr." 
823+   PADDLE_ENFORCE_NE (
824+       index,
825+       nullptr ,
826+       errors::InvalidArgument (
827+           " The index in MultiClassNMSInferMeta can't be nullptr." 
828+   //  Here the box_dims[0] is not the real dimension of output.
829+   //  It will be rewritten in the computing kernel.
830+ 
831+   out->set_dims (phi::make_ddim ({-1 , box_dims[2 ] + 2 }));
832+   out->set_dtype (bboxes.dtype ());
833+   index->set_dims (phi::make_ddim ({-1 , box_dims[2 ] + 2 }));
834+   index->set_dtype (DataType::INT32);
835+   nms_rois_num->set_dims (phi::make_ddim ({-1 }));
836+   nms_rois_num->set_dtype (DataType::INT32);
837+ }
838+ 
746839void  NllLossRawInferMeta (const  MetaTensor& input,
747840                         const  MetaTensor& label,
748841                         const  MetaTensor& weight,
0 commit comments