@@ -105,13 +105,15 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
105105 " required for layout inference." ;
106106
107107 // Run InferLayout
108- std::cerr << " [RunInferStep] working on " << cur_infer_id << std::endl;
109- auto updates = next->InferLayout (
110- LayoutInferArgs{target_, thread_bounds, layout_map}, level);
108+ DLOG (INFO) << " [RunInferStep] working on " << cur_infer_id << ' \n ' ;
109+ auto updates =
110+ next->InferLayout (LayoutInferArgs{target_, thread_bounds, layout_map,
111+ &analyzer_, buffer_oob},
112+ level);
111113 // Process the returned updates
112114 for (const auto &[buffer, layout] : updates) {
113115 DLOG (INFO) << " consider update " << buffer << " as "
114- << layout->DebugOutput () << std::endl ;
116+ << layout->DebugOutput () << ' \n ' ;
115117
116118 // Basic validity checks
117119 ICHECK (buffer.defined ()) << " InferLayout returned an undefined buffer." ;
@@ -142,7 +144,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
142144 inner_analyzer)) {
143145 layout_map.Set (buffer, layout);
144146 DLOG (INFO) << " layout broadcast from "
145- << src_layout->DebugOutput () << " , accepted" << std::endl ;
147+ << src_layout->DebugOutput () << " , accepted" << ' \n ' ;
146148 continue ;
147149 }
148150 }
@@ -154,7 +156,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
154156 } else {
155157 // Otherwise, update map
156158 layout_map.Set (buffer, layout);
157- DLOG (INFO) << " new layout accepted" << std::endl ;
159+ DLOG (INFO) << " new layout accepted" << ' \n ' ;
158160 if (!update_queue)
159161 continue ;
160162
@@ -214,9 +216,9 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
214216 << " Size mismatch: buffer_oob_vec_ and infer_list_ must match in "
215217 " length." ;
216218
217- std::cerr << " [InferLayout] all participating operators:" << std::endl ;
219+ DLOG (INFO) << " [InferLayout] all participating operators:" << ' \n ' ;
218220 for (int i = 0 ; i < infer_list_stmt_.size (); ++i) {
219- std::cerr << " op " << i << " :" << infer_list_stmt_[i] << std::endl ;
221+ DLOG (INFO) << " op " << i << " :" << infer_list_stmt_[i] << ' \n ' ;
220222 }
221223
222224 // If needed, you can also check that annotated_layout_map_ is not empty, or
@@ -480,11 +482,11 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
480482 void InferInFreeMode (LayoutMap &layout_map,
481483 const LayoutMap &strict_layout_map) {
482484
483- DLOG (INFO) << " Enforced layout maps:" << std::endl ;
485+ DLOG (INFO) << " Enforced layout maps:" << ' \n ' ;
484486 for (auto &&[k, v] : layout_map) {
485- DLOG (INFO) << " " << k << " : " << v->DebugOutput () << std::endl ;
487+ DLOG (INFO) << " " << k << " : " << v->DebugOutput () << ' \n ' ;
486488 }
487- DLOG (INFO) << std::endl ;
489+ DLOG (INFO) << ' \n ' ;
488490
489491 // Group operators into connected components
490492 UnionFind<int > uf;
@@ -522,59 +524,52 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
522524
523525 for (auto &&[root, members] : components) {
524526 DLOG (INFO) << " ======================= processing component " << root
525- << std::endl ;
527+ << ' \n ' ;
526528 decltype (infer_list_) best_infer_list;
527529 LayoutMap best_layout_map;
528530 int64_t min_reg_num = INT64_MAX;
529531 int min_reg_num_infer_root = -1 ;
532+
533+ // Try each member as the root of inference for this component
530534 for (int attempt_infer_root : members) {
531535 DLOG (INFO) << " ----------------------- try root " << attempt_infer_root
532- << std::endl ;
533- // backup infer_list_ in class member
536+ << ' \n ' ;
537+ // Backup the current infer_list_ state
534538 auto back_infer_list = BackupInferList ();
535- // create temporarily used layout_map, new handle so that it copies on
536- // write
539+ // Copy the current layout_map for temporary use
537540 LayoutMap tmp_layout_map = layout_map;
538- // infer from attempt_infer_root in free mode
539541 bool do_update = true ;
540542 try {
543+ // Run inference starting from attempt_infer_root
541544 RunInferStep (attempt_infer_root, InferLevel::kFree , true ,
542545 tmp_layout_map, strict_layout_map, q, in_queue);
543546 FinishInferQueue (InferLevel::kFree , tmp_layout_map, strict_layout_map,
544547 q, in_queue);
545- // Silly workaround: we have no clue if single root will iterate over
546- // the entire component, since the InferLayout implementations have
547- // complicated conditioning inside and we know nothing about it.
548- // This would constantly result in incomplete layouts for buffers in
549- // this component. Instead of trying all combinations of root
550- // selection order, we simply go through all other loops in order
551- // after the first search from attempt_infer_root.
548+
549+ // After the first search, run inference for all other members in
550+ // order
552551 for (int other_infer_root : members) {
553552 if (other_infer_root != attempt_infer_root) {
554553 RunInferStep (other_infer_root, InferLevel::kFree , true ,
555554 tmp_layout_map, strict_layout_map, q, in_queue);
556- // must also be kFree here to avoid conflicts.
557555 FinishInferQueue (InferLevel::kFree , tmp_layout_map,
558556 strict_layout_map, q, in_queue);
559557 }
560558 }
561- } catch (LayoutConflictException e) {
562- // such an order fails, try others
559+ } catch (const LayoutConflictException &e) {
563560 do_update = false ;
564561 DLOG (INFO) << " attempt failed due to LayoutConflictException "
565- << e.what () << std::endl;
566- } catch (NormalizeIterException e) {
567- // such an order encounters iterators that is not normalizable, try
568- // others e.g. i * 576 % 2048
562+ << e.what () << ' \n ' ;
563+ } catch (const NormalizeIterException &e) {
569564 do_update = false ;
570565 DLOG (INFO) << " attempt failed due to NormalizeIterException "
571- << e.what () << std::endl ;
566+ << e.what () << ' \n ' ;
572567 }
573568
574569 if (do_update) {
575- // compute total register number
570+ // Compute the total register number for this layout
576571 int64_t reg_num = 0 ;
577- for (auto & &[buffer, layout] : tmp_layout_map) {
572+ for (const auto &[buffer, layout] : tmp_layout_map) {
578573 if (auto frag = layout.as <Fragment>()) {
579574 int64_t frag_reg_num = 1 ;
580575 for (auto i : frag.value ()->OutputShape ()) {
@@ -585,24 +580,24 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
585580 reg_num += frag_reg_num;
586581 }
587582 }
588- // if it's any better, update the best_* storage
583+ // Update the best plan if this one uses fewer registers
589584 if (reg_num < min_reg_num) {
590- best_infer_list = std::move (infer_list_);
585+ best_infer_list =
586+ BackupInferList (); // Use backup to avoid moving out infer_list_
591587 best_layout_map = tmp_layout_map;
592588 min_reg_num = reg_num;
593589 min_reg_num_infer_root = attempt_infer_root;
594590 }
595591 }
596- // recover stateful infer_list_, head on next
592+ // Restore infer_list_ state for the next attempt
597593 infer_list_ = std::move (back_infer_list);
598594 }
599- ICHECK (min_reg_num < INT64_MAX)
600- << " no available layout found" << std::endl;
601- // now apply the best plan for this component
595+ ICHECK (min_reg_num < INT64_MAX) << " no available layout found" << ' \n ' ;
596+ // Apply the best plan for this component
602597 infer_list_ = std::move (best_infer_list);
603598 layout_map = best_layout_map;
604599 DLOG (INFO) << " [InferInFreeMode] Final selection is attempt_infer_root = "
605- << min_reg_num_infer_root << std::endl ;
600+ << min_reg_num_infer_root << ' \n ' ;
606601 }
607602 }
608603};
@@ -625,7 +620,7 @@ class LayoutInferencer : public IRMutatorWithAnalyzer {
625620 LayoutInferencer (const LayoutInferenceResult &result,
626621 bool skip_thread_partition, arith::Analyzer *analyzer)
627622 : arith::IRMutatorWithAnalyzer(analyzer), result_(result),
628- skip_thread_partition_ (skip_thread_partition) {};
623+ skip_thread_partition_ (skip_thread_partition){};
629624
630625 using arith::IRMutatorWithAnalyzer::IRMutatorWithAnalyzer;
631626
0 commit comments