Skip to content

Commit e77dce4

Browse files
authored
Merge pull request #16111 from NHZlX/anakin_support_facebox
Anakin support facebox
2 parents f8dcbd5 + 5dc3f81 commit e77dce4

26 files changed

+770
-27
lines changed

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ pass_library(conv_elementwise_add_fuse_pass inference)
6666
pass_library(conv_affine_channel_fuse_pass inference)
6767
pass_library(transpose_flatten_concat_fuse_pass inference)
6868
pass_library(identity_scale_op_clean_pass base)
69+
pass_library(simplify_anakin_detection_pattern_pass inference)
6970

7071
# There may be many transpose-flatten structures in a model, and the output of
7172
# these structures will be used as inputs to the concat Op. This pattern will
@@ -76,6 +77,10 @@ foreach (index RANGE 3 6)
7677
file(APPEND ${pass_file} "USE_PASS(transpose_flatten${index}_concat_fuse_pass);\n")
7778
endforeach()
7879

80+
foreach (index RANGE 3 6)
81+
file(APPEND ${pass_file} "USE_PASS(simplify_anakin_detection_pattern_pass${index});\n")
82+
endforeach()
83+
7984
if(WITH_MKLDNN)
8085
pass_library(mkldnn_placement_pass base mkldnn)
8186
pass_library(depthwise_conv_mkldnn_pass base mkldnn)

paddle/fluid/framework/ir/graph_pattern_detector.cc

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1364,6 +1364,136 @@ PDNode *patterns::TransposeFlattenConcat::operator()(
13641364
return concat_out;
13651365
}
13661366

1367+
PDNode *patterns::AnakinDetectionPattern::operator()(
1368+
std::vector<PDNode *> conv_in, int times) {
1369+
// The times represents the repeat times of the
1370+
// {prior_box, prior_box_loc_out, flatten, prior_box_var_out, reshape}
1371+
const int kNumFields = 7;
1372+
const int kPriorBoxLocOffset = 1;
1373+
const int kReshape1Offset = 2;
1374+
const int kReshape1OutOffset = 3;
1375+
const int kPriorBoxVarOffset = 4;
1376+
const int kReshape2Offset = 5;
1377+
const int kReshape2OutOffset = 6;
1378+
1379+
const int kBoxCoderThirdInputOffset = times;
1380+
const int kMultiClassSecondInputNmsOffset = times + 1;
1381+
1382+
std::vector<PDNode *> nodes;
1383+
1384+
for (int i = 0; i < times; i++) {
1385+
nodes.push_back(
1386+
pattern->NewNode(GetNodeName("prior_box" + std::to_string(i)))
1387+
->assert_is_op("density_prior_box"));
1388+
nodes.push_back(pattern->NewNode(GetNodeName("box_out" + std::to_string(i)))
1389+
->assert_is_op_output("density_prior_box", "Boxes")
1390+
->assert_is_op_input("reshape2", "X")
1391+
->AsIntermediate());
1392+
nodes.push_back(
1393+
pattern->NewNode(GetNodeName("reshape1" + std::to_string(i)))
1394+
->assert_is_op("reshape2"));
1395+
1396+
nodes.push_back(
1397+
pattern->NewNode(GetNodeName("reshape1_out" + std::to_string(i)))
1398+
->assert_is_op_output("reshape2")
1399+
->assert_is_op_nth_input("concat", "X", i)
1400+
->AsIntermediate());
1401+
1402+
nodes.push_back(
1403+
pattern->NewNode(GetNodeName("box_var_out" + std::to_string(i)))
1404+
->assert_is_op_output("density_prior_box", "Variances")
1405+
->assert_is_op_input("reshape2", "X")
1406+
->AsIntermediate());
1407+
nodes.push_back(
1408+
pattern->NewNode(GetNodeName("reshape2" + std::to_string(i)))
1409+
->assert_is_op("reshape2"));
1410+
1411+
nodes.push_back(
1412+
pattern->NewNode(GetNodeName("reshape2_out" + std::to_string(i)))
1413+
->assert_is_op_output("reshape2")
1414+
->assert_is_op_nth_input("concat", "X", i)
1415+
->AsIntermediate());
1416+
}
1417+
1418+
auto concat_op1 = pattern->NewNode(GetNodeName("concat1"))
1419+
->assert_is_op("concat")
1420+
->assert_op_has_n_inputs("concat", times);
1421+
auto concat_out1 = pattern->NewNode(GetNodeName("concat1_out"))
1422+
->assert_is_op_output("concat")
1423+
->AsIntermediate();
1424+
1425+
auto concat_op2 = pattern->NewNode(GetNodeName("concat2"))
1426+
->assert_is_op("concat")
1427+
->assert_op_has_n_inputs("concat", times);
1428+
auto concat_out2 = pattern->NewNode(GetNodeName("concat2_out"))
1429+
->assert_is_op_output("concat")
1430+
->AsIntermediate();
1431+
1432+
auto box_coder_op = pattern->NewNode(GetNodeName("box_coder"))
1433+
->assert_is_op("box_coder")
1434+
->assert_op_has_n_inputs("box_coder", 3);
1435+
1436+
auto box_coder_out = pattern->NewNode(GetNodeName("box_coder_out"))
1437+
->assert_is_op_output("box_coder")
1438+
->AsIntermediate();
1439+
1440+
auto multiclass_nms_op = pattern->NewNode(GetNodeName("multiclass_nms"))
1441+
->assert_is_op("multiclass_nms")
1442+
->assert_op_has_n_inputs("multiclass_nms", 2);
1443+
1444+
auto multiclass_nms_out = pattern->NewNode(GetNodeName("multiclass_nms_out"))
1445+
->assert_is_op_output("multiclass_nms")
1446+
->AsOutput();
1447+
1448+
std::vector<PDNode *> reshape1_outs;
1449+
std::vector<PDNode *> reshape2_outs;
1450+
1451+
for (int i = 0; i < times; i++) {
1452+
conv_in[i]->AsInput();
1453+
// prior_box
1454+
nodes[i * kNumFields]->LinksFrom({conv_in[i]});
1455+
// prior_box box out
1456+
nodes[i * kNumFields + kPriorBoxLocOffset]->LinksFrom(
1457+
{nodes[i * kNumFields]});
1458+
// reshape
1459+
nodes[i * kNumFields + kReshape1Offset]->LinksFrom(
1460+
{nodes[i * kNumFields + kPriorBoxLocOffset]});
1461+
// reshape_out
1462+
nodes[i * kNumFields + kReshape1OutOffset]->LinksFrom(
1463+
{nodes[i * kNumFields + kReshape1Offset]});
1464+
1465+
nodes[i * kNumFields + kPriorBoxVarOffset]->LinksFrom(
1466+
{nodes[i * kNumFields]});
1467+
// reshape
1468+
nodes[i * kNumFields + kReshape2Offset]->LinksFrom(
1469+
{nodes[i * kNumFields + kPriorBoxVarOffset]});
1470+
// reshape_out
1471+
nodes[i * kNumFields + kReshape2OutOffset]->LinksFrom(
1472+
{nodes[i * kNumFields + kReshape2Offset]});
1473+
1474+
reshape1_outs.push_back(nodes[i * kNumFields + kReshape1OutOffset]);
1475+
reshape2_outs.push_back(nodes[i * kNumFields + kReshape2OutOffset]);
1476+
}
1477+
1478+
concat_op1->LinksFrom(reshape1_outs);
1479+
concat_op2->LinksFrom(reshape2_outs);
1480+
concat_out1->LinksFrom({concat_op1});
1481+
concat_out2->LinksFrom({concat_op2});
1482+
1483+
conv_in[kBoxCoderThirdInputOffset]->AsInput();
1484+
conv_in[kMultiClassSecondInputNmsOffset]->AsInput();
1485+
1486+
box_coder_op->LinksFrom(
1487+
{concat_out1, concat_out2, conv_in[kBoxCoderThirdInputOffset]});
1488+
box_coder_out->LinksFrom({box_coder_op});
1489+
1490+
multiclass_nms_op
1491+
->LinksFrom({box_coder_out, conv_in[kMultiClassSecondInputNmsOffset]})
1492+
.LinksTo({multiclass_nms_out});
1493+
1494+
return multiclass_nms_out;
1495+
}
1496+
13671497
} // namespace ir
13681498
} // namespace framework
13691499
} // namespace paddle

paddle/fluid/framework/ir/graph_pattern_detector.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,11 @@
1818
#include <gtest/gtest_prod.h>
1919
#endif
2020

21+
#include <memory>
2122
#include <numeric>
2223
#include <string>
24+
#include <unordered_map>
25+
#include <unordered_set>
2326
#include <utility>
2427
#include <vector>
2528
#include "paddle/fluid/framework/ir/graph.h"
@@ -781,6 +784,21 @@ struct TransposeFlattenConcat : public PatternBase {
781784
}
782785
};
783786

787+
struct AnakinDetectionPattern : public PatternBase {
788+
AnakinDetectionPattern(PDPattern* pattern, const std::string& name_scope)
789+
: PatternBase(pattern, name_scope, "anakin_detect_pattern") {}
790+
791+
PDNode* operator()(std::vector<PDNode*> conv_inputs, int times);
792+
793+
std::string GetNodeName(const std::string& op_type) {
794+
return PDNodeName(name_scope_, repr_, id_, op_type);
795+
}
796+
797+
PDNode* GetPDNode(const std::string& op_type) {
798+
return pattern->RetrieveNode(GetNodeName(op_type));
799+
}
800+
};
801+
784802
} // namespace patterns
785803

786804
// Link two ir::Nodes from each other.

0 commit comments

Comments
 (0)