Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/op/parallel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,6 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
}
}
});

if (read_source_buffer.defined() && allow_layout_propgate) {
loop_layout_ = compute_loop_layout_from_buffer(read_source_buffer);
}
Expand Down
22 changes: 20 additions & 2 deletions src/transform/layout_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,23 @@ class LayoutInferencer : public IRMutatorWithAnalyzer {
// A_local[i] = A_global[i]
// Here, A_local is a register-local buffer held independently by each
// thread, so explicit thread binding is not required.
//
bool store_into_local = false;
PostOrderVisit(root, [&](const ObjectRef &obj) {
if (const auto *store = obj.as<BufferStoreNode>()) {
if (store->buffer.scope() == "local") {
store_into_local = true;
}
// if the case is like:
// for i in T.Parallel(1024):
// A_local[i] = B_global[i]
// A_frag[i] = A_global[i]
// exception will be raise in Parallel::LayoutInference
}
});
// This check if for the loop that only manuplates "local" buffers,
Comment on lines +722 to +735
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Skip partition only when stores are exclusively register‑local; include local.fragment.

Current flag store_into_local becomes true on any local store, even if there are also non‑local stores. This over‑disables partitioning. It also ignores local.fragment, which is register‑local. Track both local and non‑local stores, then compute an exclusivity flag.

-      bool store_into_local = false;
-      PostOrderVisit(root, [&](const ObjectRef &obj) {
-        if (const auto *store = obj.as<BufferStoreNode>()) {
-          if (store->buffer.scope() == "local") {
-            store_into_local = true;
-          }
-          // if the case is like:
-          // for i in T.Parallel(1024):
-          //     A_local[i] = B_global[i]
-          //     A_frag[i] = A_global[i]
-          // exception will be raise in Parallel::LayoutInference
-        }
-      });
+      bool has_store_local = false;
+      bool has_store_non_local = false;
+      PostOrderVisit(root, [&](const ObjectRef &obj) {
+        if (const auto *store = obj.as<BufferStoreNode>()) {
+          const String scope = store->buffer.scope();
+          if (scope == "local" || scope == "local.fragment") {
+            has_store_local = true;
+          } else {
+            has_store_non_local = true;
+          }
+          // Mixed local/non-local stores should be handled by Parallel::LayoutInference.
+        }
+      });
+      const bool store_local_only = has_store_local && !has_store_non_local;

Optional: compute all loop-locality booleans (stores/loads) in a single PostOrderVisit to avoid two full walks.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
bool store_into_local = false;
PostOrderVisit(root, [&](const ObjectRef &obj) {
if (const auto *store = obj.as<BufferStoreNode>()) {
if (store->buffer.scope() == "local") {
store_into_local = true;
}
// if the case is like:
// for i in T.Parallel(1024):
// A_local[i] = B_global[i]
// A_frag[i] = A_global[i]
// exception will be raise in Parallel::LayoutInference
}
});
// This check if for the loop that only manuplates "local" buffers,
bool has_store_local = false;
bool has_store_non_local = false;
PostOrderVisit(root, [&](const ObjectRef &obj) {
if (const auto *store = obj.as<BufferStoreNode>()) {
const String scope = store->buffer.scope();
if (scope == "local" || scope == "local.fragment") {
has_store_local = true;
} else {
has_store_non_local = true;
}
// Mixed local/non-local stores should be handled by Parallel::LayoutInference.
}
});
const bool store_local_only = has_store_local && !has_store_non_local;
// This check if for the loop that only manuplates "local" buffers,

// for i in T.Parallel(1024):
// A_local[i] = B_local[i]
// Though this might be illegal
// We use PostOrderVisit to detect whether the loop only manuplates
// "local" buffers, which indicates register usage and justifies skipping
// thread binding.
Expand All @@ -738,7 +754,9 @@ class LayoutInferencer : public IRMutatorWithAnalyzer {

auto loop_layout = result_.for_map[root];
// FIXME: tell in-Parallel and out-of-Parallel `local`s apart
bool parallel_loop = !skip_thread_partition_ && !local_register_only;
// NOTE(lei): a bit ugly, we should rethink about this part in future.
bool parallel_loop =
!skip_thread_partition_ && !local_register_only && !store_into_local;

if (parallel_loop) {
for_node =
Expand Down