@@ -1364,6 +1364,136 @@ PDNode *patterns::TransposeFlattenConcat::operator()(
13641364  return  concat_out;
13651365}
13661366
1367+ PDNode *patterns::AnakinDetectionPattern::operator ()(
1368+     std::vector<PDNode *> conv_in, int  times) {
1369+   //  The times represents the repeat times of the
1370+   //  {prior_box, prior_box_loc_out, flatten, prior_box_var_out, reshape}
1371+   const  int  kNumFields  = 7 ;
1372+   const  int  kPriorBoxLocOffset  = 1 ;
1373+   const  int  kReshape1Offset  = 2 ;
1374+   const  int  kReshape1OutOffset  = 3 ;
1375+   const  int  kPriorBoxVarOffset  = 4 ;
1376+   const  int  kReshape2Offset  = 5 ;
1377+   const  int  kReshape2OutOffset  = 6 ;
1378+ 
1379+   const  int  kBoxCoderThirdInputOffset  = times;
1380+   const  int  kMultiClassSecondInputNmsOffset  = times + 1 ;
1381+ 
1382+   std::vector<PDNode *> nodes;
1383+ 
1384+   for  (int  i = 0 ; i < times; i++) {
1385+     nodes.push_back (
1386+         pattern->NewNode (GetNodeName (" prior_box" std::to_string (i)))
1387+             ->assert_is_op (" density_prior_box" 
1388+     nodes.push_back (pattern->NewNode (GetNodeName (" box_out" std::to_string (i)))
1389+                         ->assert_is_op_output (" density_prior_box" " Boxes" 
1390+                         ->assert_is_op_input (" reshape2" " X" 
1391+                         ->AsIntermediate ());
1392+     nodes.push_back (
1393+         pattern->NewNode (GetNodeName (" reshape1" std::to_string (i)))
1394+             ->assert_is_op (" reshape2" 
1395+ 
1396+     nodes.push_back (
1397+         pattern->NewNode (GetNodeName (" reshape1_out" std::to_string (i)))
1398+             ->assert_is_op_output (" reshape2" 
1399+             ->assert_is_op_nth_input (" concat" " X" 
1400+             ->AsIntermediate ());
1401+ 
1402+     nodes.push_back (
1403+         pattern->NewNode (GetNodeName (" box_var_out" std::to_string (i)))
1404+             ->assert_is_op_output (" density_prior_box" " Variances" 
1405+             ->assert_is_op_input (" reshape2" " X" 
1406+             ->AsIntermediate ());
1407+     nodes.push_back (
1408+         pattern->NewNode (GetNodeName (" reshape2" std::to_string (i)))
1409+             ->assert_is_op (" reshape2" 
1410+ 
1411+     nodes.push_back (
1412+         pattern->NewNode (GetNodeName (" reshape2_out" std::to_string (i)))
1413+             ->assert_is_op_output (" reshape2" 
1414+             ->assert_is_op_nth_input (" concat" " X" 
1415+             ->AsIntermediate ());
1416+   }
1417+ 
1418+   auto  concat_op1 = pattern->NewNode (GetNodeName (" concat1" 
1419+                         ->assert_is_op (" concat" 
1420+                         ->assert_op_has_n_inputs (" concat" 
1421+   auto  concat_out1 = pattern->NewNode (GetNodeName (" concat1_out" 
1422+                          ->assert_is_op_output (" concat" 
1423+                          ->AsIntermediate ();
1424+ 
1425+   auto  concat_op2 = pattern->NewNode (GetNodeName (" concat2" 
1426+                         ->assert_is_op (" concat" 
1427+                         ->assert_op_has_n_inputs (" concat" 
1428+   auto  concat_out2 = pattern->NewNode (GetNodeName (" concat2_out" 
1429+                          ->assert_is_op_output (" concat" 
1430+                          ->AsIntermediate ();
1431+ 
1432+   auto  box_coder_op = pattern->NewNode (GetNodeName (" box_coder" 
1433+                           ->assert_is_op (" box_coder" 
1434+                           ->assert_op_has_n_inputs (" box_coder" 3 );
1435+ 
1436+   auto  box_coder_out = pattern->NewNode (GetNodeName (" box_coder_out" 
1437+                            ->assert_is_op_output (" box_coder" 
1438+                            ->AsIntermediate ();
1439+ 
1440+   auto  multiclass_nms_op = pattern->NewNode (GetNodeName (" multiclass_nms" 
1441+                                ->assert_is_op (" multiclass_nms" 
1442+                                ->assert_op_has_n_inputs (" multiclass_nms" 2 );
1443+ 
1444+   auto  multiclass_nms_out = pattern->NewNode (GetNodeName (" multiclass_nms_out" 
1445+                                 ->assert_is_op_output (" multiclass_nms" 
1446+                                 ->AsOutput ();
1447+ 
1448+   std::vector<PDNode *> reshape1_outs;
1449+   std::vector<PDNode *> reshape2_outs;
1450+ 
1451+   for  (int  i = 0 ; i < times; i++) {
1452+     conv_in[i]->AsInput ();
1453+     //  prior_box
1454+     nodes[i * kNumFields ]->LinksFrom ({conv_in[i]});
1455+     //  prior_box box out
1456+     nodes[i * kNumFields  + kPriorBoxLocOffset ]->LinksFrom (
1457+         {nodes[i * kNumFields ]});
1458+     //  reshape
1459+     nodes[i * kNumFields  + kReshape1Offset ]->LinksFrom (
1460+         {nodes[i * kNumFields  + kPriorBoxLocOffset ]});
1461+     //  reshape_out
1462+     nodes[i * kNumFields  + kReshape1OutOffset ]->LinksFrom (
1463+         {nodes[i * kNumFields  + kReshape1Offset ]});
1464+ 
1465+     nodes[i * kNumFields  + kPriorBoxVarOffset ]->LinksFrom (
1466+         {nodes[i * kNumFields ]});
1467+     //  reshape
1468+     nodes[i * kNumFields  + kReshape2Offset ]->LinksFrom (
1469+         {nodes[i * kNumFields  + kPriorBoxVarOffset ]});
1470+     //  reshape_out
1471+     nodes[i * kNumFields  + kReshape2OutOffset ]->LinksFrom (
1472+         {nodes[i * kNumFields  + kReshape2Offset ]});
1473+ 
1474+     reshape1_outs.push_back (nodes[i * kNumFields  + kReshape1OutOffset ]);
1475+     reshape2_outs.push_back (nodes[i * kNumFields  + kReshape2OutOffset ]);
1476+   }
1477+ 
1478+   concat_op1->LinksFrom (reshape1_outs);
1479+   concat_op2->LinksFrom (reshape2_outs);
1480+   concat_out1->LinksFrom ({concat_op1});
1481+   concat_out2->LinksFrom ({concat_op2});
1482+ 
1483+   conv_in[kBoxCoderThirdInputOffset ]->AsInput ();
1484+   conv_in[kMultiClassSecondInputNmsOffset ]->AsInput ();
1485+ 
1486+   box_coder_op->LinksFrom (
1487+       {concat_out1, concat_out2, conv_in[kBoxCoderThirdInputOffset ]});
1488+   box_coder_out->LinksFrom ({box_coder_op});
1489+ 
1490+   multiclass_nms_op
1491+       ->LinksFrom ({box_coder_out, conv_in[kMultiClassSecondInputNmsOffset ]})
1492+       .LinksTo ({multiclass_nms_out});
1493+ 
1494+   return  multiclass_nms_out;
1495+ }
1496+ 
13671497}  //  namespace ir
13681498}  //  namespace framework
13691499}  //  namespace paddle
0 commit comments