@@ -114,7 +114,8 @@ void AddOutputVar(const std::unordered_set<Node*>& output_vars,
114114// var node are from internal nodes
115115std::unique_ptr<Graph> CreateNewSubGraph (const GraphNodeSet& cluster,
116116 const GraphNodeSet& cluster_internals,
117- const GraphNodeSet& cluster_inputs) {
117+ const GraphNodeSet& cluster_inputs,
118+ const GraphNodeSet& cluster_outputs) {
118119 // Graph's constructor must has one parameter, and in our code,
119120 // the ProgramDesc is useless, so here we pass a temporary object.
120121 auto subgraph = std::make_unique<Graph>(framework::ProgramDesc ());
@@ -127,7 +128,12 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
127128
128129 std::unordered_map<Node*, Node*> old_var2new_var;
129130 for (auto * var : cluster_internals) {
130- auto sub_node = subgraph->CreateVarNode (var->Var ());
131+ Node* sub_node;
132+ if (var->Var () == nullptr ) {
133+ sub_node = subgraph->CreateEmptyNode (var->Name (), var->NodeType ());
134+ } else {
135+ sub_node = subgraph->CreateVarNode (var->Var ());
136+ }
131137 old_var2new_var[var] = sub_node;
132138 }
133139
@@ -140,7 +146,7 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
140146 for (auto * var : op->inputs ) {
141147 if (cluster_internals.count (var)) {
142148 old_op2new_op[op]->inputs .emplace_back (old_var2new_var[var]);
143- } else if (cluster_inputs.count (var)) {
149+ } else if (cluster_inputs.count (var) && var-> Var () != nullptr ) {
144150 if (var->Var ()->IsParameter ()) {
145151 // Parameters have been preserved in scope, compared to feed var,
146152 // param just need add new var and don't need add feed op.
@@ -157,7 +163,7 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
157163 for (auto * var : op->outputs ) {
158164 if (cluster_internals.count (var)) {
159165 old_op2new_op[op]->outputs .emplace_back (old_var2new_var[var]);
160- } else {
166+ } else if (cluster_outputs. count (var) && var-> Var () != nullptr ) {
161167 // Create new output var node to guarantee the independency of
162168 // subgraph. In other words, the subgraph has no connection with
163169 // other graph, even the input graph.
@@ -239,14 +245,20 @@ Node* AddSpecialOpToGraph(const GraphNodeSet& cluster_inputs,
239245 framework::OpDesc special_op_desc;
240246 special_op_desc.SetType (kCinnLaunchOp );
241247 std::vector<std::string> input_names;
242- std::transform (cluster_inputs.begin (), cluster_inputs.end (),
243- std::back_inserter (input_names),
244- [](Node* n) { return n->Name (); });
248+ std::for_each (cluster_inputs.begin (), cluster_inputs.end (),
249+ [&input_names](Node* n) {
250+ if (n->Var () != nullptr ) {
251+ input_names.emplace_back (n->Name ());
252+ }
253+ });
245254 special_op_desc.SetInput (" X" , input_names);
246255 std::vector<std::string> output_names;
247- std::transform (cluster_outputs.begin (), cluster_outputs.end (),
248- std::back_inserter (output_names),
249- [](Node* n) { return n->Name (); });
256+ std::for_each (cluster_outputs.begin (), cluster_outputs.end (),
257+ [&output_names](Node* n) {
258+ if (n->Var () != nullptr ) {
259+ output_names.emplace_back (n->Name ());
260+ }
261+ });
250262 special_op_desc.SetOutput (" Out" , output_names);
251263 special_op_desc.SetAttr (kCompilationKey , compilation_key);
252264 special_op_desc.Flush ();
@@ -362,8 +374,8 @@ void SearchAllSubgraphs(Graph* graph) {
362374 &cluster_internals);
363375 // Create a new subgraph according to the found cluster and
364376 // save it in CinnCompiler
365- std::string compilation_key = cinn_compiler->AddGraph (
366- CreateNewSubGraph ( cluster_set, cluster_internals, cluster_inputs));
377+ std::string compilation_key = cinn_compiler->AddGraph (CreateNewSubGraph (
378+ cluster_set, cluster_internals, cluster_inputs, cluster_outputs ));
367379 // Replace the found cluster to a new special op node
368380 ReplaceSubGraphWithSpecialOpNode (cluster_set, cluster_inputs,
369381 cluster_outputs, cluster_internals,
0 commit comments