@@ -49,11 +49,28 @@ bool IsTarget(const proto::OpDesc& op_desc) {
49
49
return false ;
50
50
}
51
51
52
- void prune_impl (const proto::ProgramDesc& input, proto::ProgramDesc* output,
53
- int block_id) {
54
- // TODO(tonyyang-svail):
55
- // - will change to use multiple blocks for RNN op and Cond Op
52
+ int GetSubBlockIndex (const proto::OpDesc& op_desc) {
53
+ for (auto & attr : op_desc.attrs ()) {
54
+ if (attr.type () == proto::AttrType::BLOCK) {
55
+ PADDLE_ENFORCE (attr.has_block_idx ());
56
+ return attr.block_idx ();
57
+ }
58
+ }
59
+ return -1 ;
60
+ }
61
+
62
+ bool HasSubBlock (const proto::OpDesc& op_desc) {
63
+ return GetSubBlockIndex (op_desc) > 0 ;
64
+ }
56
65
66
+ // block_id is the idx of the current block in the input desc
67
+ // parent_block_id is the idx of the parent of the current block
68
+ // in the output desc, -1 means the current block is global block
69
+ // dependent_vars is passed recursively from the parent block to
70
+ // the child block to help pruning
71
+ void prune_impl (const proto::ProgramDesc& input, proto::ProgramDesc* output,
72
+ int block_id, int parent_block_id,
73
+ std::set<std::string>& dependent_vars) {
57
74
auto & block = input.blocks (block_id);
58
75
auto & ops = block.ops ();
59
76
@@ -72,19 +89,16 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
72
89
expect_fetch = (op_desc.type () == kFetchOpType );
73
90
}
74
91
75
- std::set<std::string> dependent_vars;
76
92
std::vector<bool > should_run;
77
93
for (auto op_iter = ops.rbegin (); op_iter != ops.rend (); ++op_iter) {
78
94
auto & op_desc = *op_iter;
79
-
80
95
if (IsTarget (op_desc) || HasDependentVar (op_desc, dependent_vars)) {
81
96
// insert its input to the dependency graph
82
97
for (auto & var : op_desc.inputs ()) {
83
98
for (auto & argu : var.arguments ()) {
84
99
dependent_vars.insert (argu);
85
100
}
86
101
}
87
-
88
102
should_run.push_back (true );
89
103
} else {
90
104
should_run.push_back (false );
@@ -95,45 +109,81 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
95
109
// we reverse the should_run vector
96
110
std::reverse (should_run.begin (), should_run.end ());
97
111
98
- *output = input;
99
- auto * op_field = output->mutable_blocks (block_id)->mutable_ops ();
112
+ // copy the current block from input to output
113
+ auto * block_field = output->mutable_blocks ();
114
+ *block_field->Add () = input.blocks (block_id);
115
+
116
+ int output_block_id = output->blocks_size () - 1 ;
117
+ auto * output_block = output->mutable_blocks (output_block_id);
118
+ output_block->set_idx (output_block_id);
119
+ output_block->set_parent_idx (parent_block_id);
120
+
121
+ auto * op_field = output_block->mutable_ops ();
100
122
op_field->Clear ();
101
123
for (size_t i = 0 ; i < should_run.size (); ++i) {
102
124
if (should_run[i]) {
103
- *op_field->Add () = input.blocks (block_id).ops (i);
125
+ auto * op = op_field->Add ();
126
+ *op = input.blocks (block_id).ops (i);
127
+ if (HasSubBlock (*op)) {
128
+ // create sub_block_dependent_vars here to help prune the sub block
129
+ std::set<std::string> sub_block_dependent_vars;
130
+ for (auto & var : op->inputs ()) {
131
+ for (auto & argu : var.arguments ()) {
132
+ sub_block_dependent_vars.insert (argu);
133
+ }
134
+ }
135
+ for (auto & var : op->outputs ()) {
136
+ for (auto & argu : var.arguments ()) {
137
+ sub_block_dependent_vars.insert (argu);
138
+ }
139
+ }
140
+ // GetSubBlockIndex(*op) is the idx of the sub_block in the input desc
141
+ // output_block_id is the idx of the current block in the output desc
142
+ prune_impl (input, output, GetSubBlockIndex (*op), output_block_id,
143
+ sub_block_dependent_vars);
144
+ }
104
145
}
105
146
}
106
147
107
148
// remove the VarDescs in BlockDesc that are not referenced in
108
149
// the pruned OpDescs
109
150
std::unordered_map<std::string, proto::VarDesc> var_map;
110
- auto * var_field = output->mutable_blocks (block_id )->mutable_vars ();
151
+ auto * var_field = output->mutable_blocks (output_block_id )->mutable_vars ();
111
152
for (const auto & var : *var_field) {
112
153
var_map[var.name ()] = var;
113
154
}
114
155
115
- var_field-> Clear () ;
156
+ std::set<std::string> var_names ;
116
157
for (const auto & op : *op_field) {
117
- // add VarDescs of all input arguments for each OpDesc
118
158
auto & input_field = op.inputs ();
119
159
for (auto & input_var : input_field) {
120
160
for (auto & arg : input_var.arguments ()) {
121
- *var_field->Add () = var_map[arg];
161
+ if (var_map.count (arg) != 0 ) {
162
+ var_names.insert (arg);
163
+ }
122
164
}
123
165
}
124
- // add VarDescs of all output arguments for each OpDesc
125
166
auto & output_field = op.outputs ();
126
167
for (auto & output_var : output_field) {
127
168
for (auto & arg : output_var.arguments ()) {
128
- *var_field->Add () = var_map[arg];
169
+ if (var_map.count (arg) != 0 ) {
170
+ var_names.insert (arg);
171
+ }
129
172
}
130
173
}
131
174
}
175
+
176
+ var_field->Clear ();
177
+ for (const auto & name : var_names) {
178
+ *var_field->Add () = var_map[name];
179
+ }
132
180
}
133
181
134
182
// TODO(fengjiayi): Prune() could be inplaced to avoid unnecessary copies
135
183
void Prune(const proto::ProgramDesc& input, proto::ProgramDesc* output) {
  // Start from an empty block list; prune_impl appends each block it emits.
  output->clear_blocks();

  // Dependency set for block 0; it starts empty and is filled in by
  // prune_impl as it visits the block's ops.
  std::set<std::string> root_dependent_vars;

  // Block 0 of the input is pruned first; -1 marks it as having no parent.
  prune_impl(input, output, /*block_id=*/0, /*parent_block_id=*/-1,
             root_dependent_vars);
}
138
188
139
189
void inference_optimize_impl (const proto::ProgramDesc& input,
0 commit comments