@@ -12,7 +12,7 @@ void inst_combine_pass::visit_store(lir::store* s) {
1212 //
1313 // (
1414 // ssa_temp_0 = a,
15- // b = ssa_temp_1 ,
15+ // ssa_temp_1 = b ,
1616 // call(ssa_temp_2, ssa_temp_0, ssa_temp_1)
1717 // )
1818 //
@@ -86,26 +86,67 @@ void inst_combine_pass::visit_compare(lir::compare* c) {
8686 }
8787}
8888
89+ void inst_combine_pass::visit_call (lir::call* c) {
90+ if (c->get_func_kind () != lir::call::kind::key_cmp) {
91+ return ;
92+ }
93+ if (c->get_function_name () != " key_eq" ) {
94+ return ;
95+ }
96+
97+ const auto & left = c->get_arguments ()[0 ];
98+ const auto & right = c->get_arguments ()[1 ];
99+
100+ // record this case:
101+ //
102+ // a.key_eq(b.getParent())
103+ // -->
104+ // (
105+ // getParent(ssa_temp_0, b),
106+ // a = ssa_temp_0
107+ // )
108+ //
109+ // and optimize this case to:
110+ //
111+ // getParent(a, b)
112+ //
113+ if (left.kind ==lir::inst_value_kind::variable &&
114+ right.kind ==lir::inst_value_kind::variable) {
115+ variable_reference_graph[left.content ].insert ({right.content , c});
116+ variable_reference_graph[right.content ].insert ({left.content , c});
117+ }
118+ }
119+
89120bool inst_combine_pass::run () {
90- for (auto impl : ctx->rule_impls ) {
91- scan (impl);
92- inst_elimination_worker ().copy (impl);
121+ for (auto impl : ctx->rule_impls ) {
122+ run_on_single_impl (impl);
93123 }
94- for (auto impl : ctx->database_get_table ) {
95- scan (impl);
96- inst_elimination_worker ().copy (impl);
124+ for (auto impl : ctx->database_get_table ) {
125+ run_on_single_impl (impl);
97126 }
98- for (auto impl : ctx->schema_get_field ) {
99- scan (impl);
100- inst_elimination_worker ().copy (impl);
127+ for (auto impl : ctx->schema_get_field ) {
128+ run_on_single_impl (impl);
101129 }
102- for (auto impl : ctx->schema_data_constraint_impls ) {
103- scan (impl);
104- inst_elimination_worker ().copy (impl);
130+ for (auto impl : ctx->schema_data_constraint_impls ) {
131+ run_on_single_impl (impl);
105132 }
106133 return true ;
107134}
108135
136+ void inst_combine_pass::run_on_single_impl (souffle_rule_impl* b) {
137+ auto worker = inst_elimination_worker ();
138+ size_t pass_run_count = 0 ;
139+ const size_t max_pass_run_count = 16 ;
140+ scan (b);
141+ worker.copy (b);
142+ ++ pass_run_count;
143+ while (worker.get_eliminated_count () && pass_run_count < max_pass_run_count) {
144+ scan (b);
145+ worker.copy (b);
146+ ++ pass_run_count;
147+ }
148+ }
149+
109150void inst_combine_pass::scan (souffle_rule_impl* b) {
110151 variable_reference_graph.clear ();
111152 b->get_block ()->accept (this );
@@ -265,6 +306,7 @@ void inst_elimination_worker::visit_block(lir::block* node) {
265306 for (auto i : node->get_content ()) {
266307 // skip eliminated instruction
267308 if (i->get_flag_eliminated ()) {
309+ ++ eliminated_count;
268310 continue ;
269311 }
270312
@@ -338,6 +380,8 @@ void inst_elimination_worker::visit_aggregator(lir::aggregator* node) {
338380}
339381
340382void inst_elimination_worker::copy (souffle_rule_impl* impl) {
383+ eliminated_count = 0 ;
384+ blk.clear ();
341385 auto impl_blk = new lir::block (impl->get_block ()->get_location ());
342386
343387 blk.push_back (impl_blk);
@@ -354,4 +398,69 @@ void inst_elimination_worker::copy(souffle_rule_impl* impl) {
354398 delete impl_blk;
355399}
356400
401+ void replace_find_call::visit_block (lir::block* node) {
402+ bool has_find_call = false ;
403+ for (auto i : node->get_content ()) {
404+ if (i->get_kind () != lir::inst_kind::inst_call) {
405+ continue ;
406+ }
407+ auto call = reinterpret_cast <lir::call*>(i);
408+ if (call->get_func_kind () == lir::call::kind::find &&
409+ call->get_function_name () == " find" ) {
410+ has_find_call = true ;
411+ break ;
412+ }
413+ }
414+
415+ if (has_find_call) {
416+ std::vector<lir::inst*> new_content;
417+ for (auto i : node->get_content ()) {
418+ if (i->get_kind () != lir::inst_kind::inst_call) {
419+ new_content.push_back (i);
420+ continue ;
421+ }
422+
423+ auto call = reinterpret_cast <lir::call*>(i);
424+ if (call->get_func_kind () != lir::call::kind::find ||
425+ call->get_function_name () != " find" ) {
426+ new_content.push_back (i);
427+ continue ;
428+ }
429+
430+ auto dst = call->get_return ();
431+ auto arg0 = call->get_arguments ()[0 ];
432+ auto arg1 = call->get_arguments ()[1 ];
433+ auto new_block = new lir::block (call->get_location ());
434+ new_block->set_use_comma ();
435+ new_content.push_back (new_block);
436+
437+ new_block->add_new_content (new lir::store (arg0, dst, call->get_location ()));
438+ new_block->add_new_content (new lir::store (arg1, arg0, call->get_location ()));
439+
440+ delete i;
441+ }
442+ node->get_mutable_content ().swap (new_content);
443+ } else {
444+ for (auto i : node->get_content ()) {
445+ i->accept (this );
446+ }
447+ }
448+ }
449+
450+ bool replace_find_call::run () {
451+ for (auto impl : ctx->rule_impls ) {
452+ impl->get_block ()->accept (this );
453+ }
454+ for (auto impl : ctx->database_get_table ) {
455+ impl->get_block ()->accept (this );
456+ }
457+ for (auto impl : ctx->schema_get_field ) {
458+ impl->get_block ()->accept (this );
459+ }
460+ for (auto impl : ctx->schema_data_constraint_impls ) {
461+ impl->get_block ()->accept (this );
462+ }
463+ return true ;
464+ }
465+
357466}
0 commit comments