[Enhancement] Support Layout/Fragment Reshape #1241
- Updated the `LayoutNode` class to include a new `Reshape` method, allowing for dynamic reshaping of layouts based on input shapes.
- Enhanced the `OutputShape` method to provide better handling of cases where the analyzer cannot form an `IntervalSet`, implementing fallback mechanisms to ensure safe extents.
- Refactored the `ReduceOpNode` to utilize `BufferRegion` for improved memory handling during reduction operations.
- Added tests for reshaping functionality and layout transformations to ensure correctness and performance in various scenarios.
Walkthrough

Added Layout and Fragment Reshape methods and implementations; normalized Reduce arguments into BufferRegion and stored srcRegion_/dstRegion_; refactored layout inference to map data Vars to multiple Buffers and propagate layouts across aliases; bumped the TVM submodule pointer; added reshape and reduce-after-reshape tests and adjusted reduce intrinsics to use buffer-region wrappers.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Caller
    participant LayoutNode
    participant Analyzer
    participant Indexer
    Caller->>LayoutNode: Reshape(new_shape, analyzer)
    LayoutNode->>LayoutNode: if shapes equal → return self
    alt shapes differ
        LayoutNode->>Analyzer: validate product equality
        Analyzer-->>LayoutNode: equality confirmed
        LayoutNode->>Indexer: build flat-index → map new indices to old
        Indexer-->>LayoutNode: substituted forward_index_
        LayoutNode-->>Caller: return new Layout/Fragment (with thread-range updates if Fragment)
    end
```
```mermaid
sequenceDiagram
    participant ReduceCtor as ReduceOp ctor
    participant Normalizer as NormalizeToBufferRegion
    ReduceCtor->>Normalizer: normalize src arg
    alt arg is BufferRegion
        Normalizer-->>ReduceCtor: return same BufferRegion
    else arg is BufferLoad (Ramp allowed)
        Normalizer->>Normalizer: convert indices → Range(s)
        Normalizer-->>ReduceCtor: return BufferRegion
    else arg is tl.region call
        Normalizer->>Normalizer: reconstruct BufferRegion via RegionOp
        Normalizer-->>ReduceCtor: return BufferRegion
    else unsupported
        Normalizer-->>ReduceCtor: fatal error
    end
    ReduceCtor->>ReduceCtor: store srcRegion_, derive src
    ReduceCtor->>Normalizer: normalize dst arg
    Normalizer-->>ReduceCtor: return dstRegion_
    ReduceCtor->>ReduceCtor: store dstRegion_, derive dst
```
```mermaid
sequenceDiagram
    actor User
    participant LayoutInference
    participant BufferCollector as BufferUseDefCollector
    participant LayoutNode
    User->>LayoutInference: request layout inference
    LayoutInference->>BufferCollector: collect buffers & aliases
    BufferCollector-->>LayoutInference: Map Var → Array[Buffer]
    LayoutInference->>LayoutNode: infer layout for representative buffer
    LayoutNode-->>LayoutInference: return layout
    LayoutInference->>LayoutInference: propagate layout to alias buffers
    loop per alias buffer
        LayoutInference->>LayoutNode: call Reshape(alias_layout, new_shape)
        LayoutNode-->>LayoutInference: return reshaped layout
        LayoutInference->>LayoutInference: enqueue alias for further inference if needed
    end
    LayoutInference-->>User: layouts finalized
```
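The alias-propagation loop in the diagram above can be sketched as a toy model. Everything here is an illustrative stand-in for the actual TileLang data structures: the dict models `buffer_data_to_buffers_`, the `row_major` function stands in for an inferred layout, and `reshape_layout` mimics the flat-index remapping that `Reshape` performs.

```python
# Toy model of alias propagation: buffers sharing one storage var receive the
# representative buffer's layout, reshaped to their own shape. All names and
# structures are hypothetical stand-ins, not the real TileLang API.
from math import prod

def reshape_layout(layout, old_shape, new_shape):
    """Wrap `layout` (a function of old-shape indices) so it accepts
    new-shape indices, assuming row-major flattening."""
    assert prod(old_shape) == prod(new_shape), "element count must match"
    def reshaped(*new_idx):
        # flatten the new-shape indices row-major
        flat = 0
        for i, n in zip(new_idx, new_shape):
            flat = flat * n + i
        # unflatten into old-shape indices
        old_idx = []
        for n in reversed(old_shape):
            flat, r = divmod(flat, n)
            old_idx.append(r)
        return layout(*reversed(old_idx))
    return reshaped

def row_major(shape):
    def f(*idx):
        flat = 0
        for i, n in zip(idx, shape):
            flat = flat * n + i
        return flat
    return f

# storage var -> list of (buffer_name, shape); first entry is representative
buffer_data_to_buffers = {"A_data": [("A", (4, 6)), ("A_view", (2, 12))]}
layout_map = {}

for var, buffers in buffer_data_to_buffers.items():
    rep_name, rep_shape = buffers[0]
    layout_map[rep_name] = row_major(rep_shape)   # "inferred" layout
    for name, shape in buffers[1:]:               # propagate to aliases
        if shape != rep_shape:
            layout_map[name] = reshape_layout(layout_map[rep_name],
                                              rep_shape, shape)
```

For a row-major representative layout, the reshaped alias maps any index to the same flat offset, e.g. `layout_map["A_view"](1, 5) == 17`, matching `layout_map["A"](2, 5)`.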
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 1
🧹 Nitpick comments (2)
src/transform/layout_inference.cc (2)
319-380: Consider extracting the shape equality check into a helper.

The finalization step correctly ensures all alias buffers have consistent layouts. However, the shape equality check (lines 338-349 and 360-371) is duplicated from the `propagate_alias` lambda (lines 129-139). Consider extracting this pattern into a helper method:

```cpp
bool CheckShapeEquality(const Array<PrimExpr> &shape1,
                        const Array<PrimExpr> &shape2,
                        arith::Analyzer *analyzer) const {
  if (shape1.size() != shape2.size())
    return false;
  for (size_t i = 0; i < shape1.size(); ++i) {
    if (!analyzer->CanProveEqual(shape1[i], shape2[i])) {
      return false;
    }
  }
  return true;
}
```

This would reduce duplication and make the code more maintainable.
638-696: Buffer collection from Load/Store is correct but contains duplication.

Both `VisitExpr_(BufferLoadNode*)` and `VisitStmt_(BufferStoreNode*)` correctly collect buffers and check for duplicates using `same_as`. The DLOG statements are helpful for debugging. The logic between the BufferLoad and BufferStore visitors is nearly identical; consider extracting a helper:

```cpp
void CollectBuffer(const Buffer &buffer) {
  if (!buffer.defined() || !buffer->data.defined())
    return;
  if (buffer_data_to_buffers_.count(buffer->data)) {
    auto buffers = buffer_data_to_buffers_[buffer->data];
    bool found = std::find_if(buffers.begin(), buffers.end(),
                              [&](const Buffer &buf) {
                                return buf.same_as(buffer);
                              }) != buffers.end();
    if (!found) {
      buffers.push_back(buffer);
      buffer_data_to_buffers_.Set(buffer->data, buffers);
      DLOG(INFO) << "[LayoutInference] added buffer " << buffer
                 << " data = " << buffer->data.get();
    }
  } else {
    buffer_data_to_buffers_.Set(buffer->data, {buffer});
    DLOG(INFO) << "[LayoutInference] new buffer " << buffer
               << " data = " << buffer->data.get();
  }
}
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)

- 3rdparty/tvm (1 hunks)
- src/layout/layout.cc (2 hunks)
- src/layout/layout.h (2 hunks)
- src/op/reduce.cc (3 hunks)
- src/op/reduce.h (2 hunks)
- src/transform/layout_inference.cc (11 hunks)
- testing/python/language/test_tilelang_language_reshape.py (2 hunks)
- tilelang/language/reduce.py (5 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-12T09:47:46.474Z
Learnt from: kurisu6912
Repo: tile-ai/tilelang PR: 794
File: tilelang/transform/add_bufstore_wrapper.py:30-33
Timestamp: 2025-09-12T09:47:46.474Z
Learning: In TVM's PyStmtExprMutator, visit_block_ methods typically call super().visit_block_(op) to process child nodes and update internal state, but return the original op when the block itself doesn't need transformation. The pattern `return op` is correct for blocks that serve as containers where mutations happen at deeper levels.
Applied to files:
src/transform/layout_inference.cc
🧬 Code graph analysis (7)
src/op/reduce.h (1)
tilelang/language/fill.py (1)
- clear (50-74)
tilelang/language/reduce.py (1)
tilelang/language/utils.py (1)
- buffer_to_tile_region (30-42)
src/layout/layout.h (1)
src/layout/layout.cc (6)
- Layout (58-70), Layout (72-75), Reshape (300-369), Reshape (300-301), Reshape (371-442), Reshape (371-372)
testing/python/language/test_tilelang_language_reshape.py (4)
tilelang/language/allocate.py (2)
- alloc_shared (27-42), alloc_fragment (59-70)

tilelang/language/customize.py (1)
- reshape (39-49)

tilelang/profiler/__init__.py (1)
- assert_allclose (77-146)

tilelang/intrinsics/mma_layout.py (1)
- make_mma_swizzle_layout (210-223)
src/layout/layout.cc (1)
src/layout/layout.h (1)
- Inverse (87-145)
src/transform/layout_inference.cc (3)
src/op/operator.cc (4)
- ParseOperator (32-41), ParseOperator (32-32), ParseOperator (55-61), ParseOperator (55-55)

src/transform/multi_version_buffer_rewriter.cc (2)
- call (443-473), call (443-444)

src/transform/storage_access.cc (2)
- VisitStmt_ (64-93), VisitStmt_ (64-64)
src/op/reduce.cc (2)
src/op/gemm_py.cc (2)
- NormalizeToBufferRegion (25-72), NormalizeToBufferRegion (25-26)

src/op/gemm.cc (2)
- NormalizeToBufferRegion (53-100), NormalizeToBufferRegion (53-54)
🔇 Additional comments (22)
3rdparty/tvm (1)
1-1: Clarify necessity and rationale for the submodule update.

The PR couples a TVM submodule update with internal reshape and layout inference changes. While the enriched summary suggests the update may provide needed API compatibility, the specific reasons for updating TVM are not clear from the PR context.
To ensure this change is well-justified and safe:
- Document the rationale: Why is this TVM commit required for the reshape functionality? Are there specific APIs or bug fixes needed?
- Consider separation: If possible, submodule updates should be in separate commits to simplify reviews and enable independent rollbacks.
- Verify the commit: Confirm that the new commit (1b54bb0148381ff990e72625d10784fedd369768) exists in the official TVM repository and is stable/released.

Please provide clarity on whether this update is a hard requirement for the reshape feature or if it can be deferred to a separate PR.
src/transform/layout_inference.cc (7)
120-161: LGTM with observation: Alias propagation logic is well-structured.

The `propagate_alias` lambda correctly handles shape mismatches via `Reshape` and ensures layout consistency across buffers sharing the same storage variable. The shape equality check is thorough, and the enqueueing logic properly updates the BFS queue for affected buffers.

One observation: the lambda captures `analyzer_` by reference from the enclosing class, which is safe here since the lambda is invoked immediately within the same scope.
197-198: Consistent alias propagation across all update paths.

The calls to `propagate_alias` are correctly placed after updating `layout_map` to ensure aliases are synchronized whenever layouts change. This maintains consistency across all code paths (fragment containment, equality checks, and new insertions).

Also applies to: 207-213
440-450: GetBufferMap usage is appropriate but mark for future refactoring.

The `GetBufferMap()` helper returns the first buffer for each storage variable, which is acceptable for backward compatibility with APIs expecting a single buffer per variable. The TODO comment correctly indicates this should be phased out in the future.

Ensure all callers of `GetBufferMap()` understand they're getting only the first buffer from potentially multiple aliases.
458-458: Buffer access resolution correctly handles the multi-buffer scenario.

The updates to `ParseOperator` (line 458) and `getBufferFromAccessPtr` (lines 514-538) properly adapt to the new `buffer_data_to_buffers_` structure. Returning the first buffer is appropriate when a single buffer is needed for API compatibility. The warning message at lines 522-523 is helpful for debugging unexpected argument types.

Also applies to: 514-538
574-625: Block annotation processing correctly handles aliased buffers.

The update to `VisitStmt_(BlockNode*)` properly:

- Collects buffers into `buffer_data_to_buffers_` before processing annotations
- Applies annotated layouts to all buffers sharing a storage variable
- Reshapes layouts when buffer shapes differ from annotation shapes

The comment at lines 585-587 helpfully clarifies the visit order. The shape equality check here is part of the previously noted duplication pattern.
747-767: Alias-aware component grouping is well-implemented.

The additional union logic (lines 747-767) correctly handles the case where multiple `Buffer` objects share the same storage variable. By collecting and unioning all inference indices across aliased buffers, the code ensures that operations on any alias are treated as part of the same connected component. The use of `std::sort` followed by `std::unique` is the standard idiom for deduplication.
698-698: Core structural change enables alias tracking.

The change from `Map<Var, Buffer>` to `Map<Var, Array<Buffer>>` is the foundational modification that enables tracking multiple buffers per storage variable. All other changes in this file support and utilize this new structure.

tilelang/language/reduce.py (2)
6-6: Import addition is appropriate.

The import of `buffer_to_tile_region` from `tilelang.language.utils` is correctly placed and necessary for the subsequent changes.
52-60: Consistent usage of buffer_to_tile_region across all reduction pathways.

The replacement of direct buffer access pointers with `buffer_to_tile_region` wrapper calls is applied consistently across all four reduction scenarios:
- shared → shared (lines 55-56)
- shared → fragment (lines 70-71)
- fragment → shared (lines 83-84)
- fragment → fragment (lines 94-95)
This aligns with the region-aware approach introduced in the C++ reduce operator (src/op/reduce.cc), ensuring proper tile-region addressing for the `tl.reduce` intrinsic.

Also applies to: 67-75, 80-89, 91-99
src/layout/layout.h (1)
45-46: Reshape method declarations are properly structured.

The virtual `Reshape` method added to `LayoutNode` (lines 45-46) follows the existing pattern for other virtual methods like `Inverse` and `InverseWithLevel`. The non-virtual declaration in `FragmentNode` (line 89) is appropriate since `FragmentNode` is marked as final. Both methods accept an `arith::Analyzer*` parameter for shape product verification, which is consistent with other layout transformation methods.

Also applies to: 89-89
src/op/reduce.h (2)
85-90: Region members are appropriately added and exposed.

The addition of the `BufferRegion srcRegion_` and `dstRegion_` members (lines 85-86) properly captures the original regions used to construct the reduce operation. The explanatory comment is helpful, and the naming follows existing conventions (trailing underscore for private/internal members). The existing members `dim`, `type`, and `clear` retain their semantics.
99-100: Reflection bindings correctly expose region members.

The read-only reflection bindings for `srcRegion` and `dstRegion` (lines 99-100) properly expose these members to the Python API, maintaining consistency with the existing reflection pattern for `src` and `dst`.

testing/python/language/test_tilelang_language_reshape.py (4)
4-4: Torch import is appropriate for reference implementations.

The addition of `import torch` (line 4) is necessary for the new test reference programs that use `torch.max` and tensor reshaping for validation.
133-174: Fragment reshape test provides good coverage.

The `reshape_fragment_test` and `run_reshape_fragment` functions (lines 133-174) properly test reshaping of fragment buffers, including:
- Allocation of shared and fragment memory
- Copy operations between memory scopes
- Reshape of fragment buffers
- Validation against torch reference
The test correctly exercises the newly added reshape functionality for fragment layouts.
177-218: Layout transform reshape test validates swizzled layouts.

The `reshape_layout_transform_shared` test (lines 177-218) adds important coverage for reshaping buffers with annotated swizzled layouts. The use of `make_mma_swizzle_layout` ensures that reshape works correctly with non-trivial layout transformations.
221-262: Reduce after reshape test validates end-to-end integration.

The `reduce_after_reshape_test` (lines 221-262) provides crucial integration testing by combining reshape with reduction operations:
- Reshape a 1D fragment to 2D
- Apply reduce_max along a dimension
- Validate against a `torch.max` reference

This test ensures the region-aware reduction changes work correctly with reshaped buffers, validating the broader PR objectives.
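The reference semantics this test checks can be sketched with NumPy in place of torch; the shapes below are illustrative, not the ones used in the actual test file:

```python
import numpy as np

# Reference semantics of reduce-after-reshape: reshape a 1D buffer to 2D,
# then take the max along the last dimension (shapes are hypothetical).
x = np.arange(12, dtype=np.float32)   # 1D "fragment" of 12 elements
x2d = x.reshape(3, 4)                 # reshape 1D -> 2D
reduced = x2d.max(axis=1)             # reduce_max along the last dim
```

With these inputs each row of `x2d` is a consecutive run of four values, so the reduction picks the last element of each row.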
src/op/reduce.cc (3)
17-17: Region header inclusion is necessary.

The addition of `#include "region.h"` (line 17) is required for the `RegionOp` usage in `NormalizeToBufferRegion`.
25-64: NormalizeToBufferRegion implementation is correct and follows existing patterns.

The `NormalizeToBufferRegion` function (lines 25-64) properly handles three cases:

- BufferRegion: returned directly (lines 30-32)
- BufferLoad: indices converted to ranges, handling both `Ramp` nodes (stride-1 vectorization) and scalar indices (lines 36-51)
- tl.region calls: BufferRegion reconstructed via RegionOp (lines 56-59)

The implementation mirrors similar functions in `src/op/gemm.cc` and `src/op/gemm_py.cc` (see relevant snippets), ensuring consistency across operators. The error message (line 62) is clear.

Note: the function does not handle `builtin.tvm_access_ptr()` unlike the GEMM variants, which may be intentional for reduce operations.
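The three-way dispatch can be illustrated with a small sketch; the classes below are hypothetical stand-ins for TVM's BufferRegion/BufferLoad/tl.region nodes, not the real C++ API:

```python
# Minimal sketch of the normalization dispatch. All classes are hypothetical
# placeholders; ranges are (min, extent) pairs.
from dataclasses import dataclass

@dataclass
class BufferRegion:
    buffer: str
    ranges: list          # list of (min, extent)

@dataclass
class BufferLoad:
    buffer: str
    indices: list         # ints, or ("ramp", base, lanes) for vectorized loads

@dataclass
class RegionCall:         # stands in for a tl.region call
    buffer: str
    ranges: list

def normalize_to_buffer_region(arg):
    if isinstance(arg, BufferRegion):
        return arg                                    # case 1: already a region
    if isinstance(arg, BufferLoad):
        ranges = []
        for idx in arg.indices:
            if isinstance(idx, tuple) and idx[0] == "ramp":
                _, base, lanes = idx
                ranges.append((base, lanes))          # stride-1 Ramp -> extent
            else:
                ranges.append((idx, 1))               # scalar index -> extent 1
        return BufferRegion(arg.buffer, ranges)       # case 2: load -> region
    if isinstance(arg, RegionCall):
        return BufferRegion(arg.buffer, arg.ranges)   # case 3: rebuild region
    raise TypeError(f"unsupported reduce argument: {arg!r}")
```

The key design point mirrored here is that every accepted argument form collapses to one canonical region type before the constructor stores it.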
66-78: ReduceOp constructor correctly adopts the region-aware approach.

The updated constructor (lines 66-78) properly:

- Normalizes source and destination arguments to BufferRegion (lines 69-70)
- Stores the normalized regions in `srcRegion_` and `dstRegion_`
- Derives the `src` and `dst` buffers from the regions (lines 71-72)
- Preserves existing logic for `dim`, `type`, and `clear` (lines 73-76)

This change enables reduce operations to work uniformly with different argument forms (BufferRegion, BufferLoad, tl.region calls), which is essential for reshape support.
src/layout/layout.cc (2)
300-369: LayoutNode::Reshape implementation is correct and well-structured.

The `LayoutNode::Reshape` implementation (lines 300-369) correctly:

- Fast path: returns self if shapes are equal (lines 302-305)
- Shape validation: verifies the total element count is preserved using the analyzer (lines 307-324)
- Index remapping: computes the flat index from new-shape indices (lines 332-341), converts the flat index back to original-shape indices (lines 347-356), and substitutes the original indices into `forward_index_` (lines 358-367)

The algorithm properly handles row-major layout assumptions and preserves the forward mapping semantics.
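The flat-index remapping can be checked numerically. A sketch assuming row-major flattening, with an arbitrary permutation standing in for `forward_index_`: an element must map to the same layout target whether addressed through the old shape or the reshaped one.

```python
import numpy as np

# Check the flat-index remapping property of Reshape (row-major assumption).
old_shape, new_shape = (4, 6), (2, 3, 4)
assert np.prod(old_shape) == np.prod(new_shape)   # product equality check

# Arbitrary bijective layout over the old shape (stand-in for forward_index_).
perm = np.random.default_rng(0).permutation(int(np.prod(old_shape)))

def forward_index(i, j):
    return perm[i * old_shape[1] + j]

def reshaped_forward_index(*new_idx):
    flat = np.ravel_multi_index(new_idx, new_shape)    # flatten new indices
    i, j = np.unravel_index(flat, old_shape)           # unflatten to old shape
    return forward_index(i, j)                         # substitute into layout

# Every element agrees between the old-shape and new-shape addressing.
for flat in range(int(np.prod(new_shape))):
    new_idx = np.unravel_index(flat, new_shape)
    old_idx = np.unravel_index(flat, old_shape)
    assert reshaped_forward_index(*new_idx) == forward_index(*old_idx)
```

Because both flattenings are row-major, the reshaped layout of any index is just the original layout evaluated at the same flat offset.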
371-442: FragmentNode::Reshape correctly extends LayoutNode::Reshape for thread-aware layouts.

The `FragmentNode::Reshape` implementation (lines 371-442) properly extends the base reshape logic:

- Includes the fast path and shape validation (lines 373-393)
- Uses the same flat-index remapping approach (lines 395-418)
- Additionally handles thread mapping: substitutes into `forward_thread_` (lines 430-434)
- Preserves thread_range: re-binds `thread_range` if defined (lines 438-440)

The implementation correctly maintains all fragment-specific properties (thread mapping, replication extent, thread range) while reshaping the spatial layout.
```diff
 if (arith::is_neg_inf(ist.min()) && arith::is_pos_inf(ist.max())) {
-  // X-OR Expression
-  ret.Set(i, input_size_[i]);
+  // Analyzer couldn't form an IntervalSet (e.g. bitwise ops).
+  // Fall back to ConstIntBound to derive a safe extent.
+  auto cib = analyzer.const_int_bound(forward_index_[i]);
+  if (cib->min_value != arith::ConstIntBound::kNegInf &&
+      cib->max_value != arith::ConstIntBound::kPosInf &&
+      cib->min_value >= 0) {
+    // extent = max - min + 1, using 64-bit integer literal
+    ret.Set(i, Integer(cib->max_value - cib->min_value + 1));
+  } else {
+    // Last-resort conservative fallback to avoid OOB/crash
+    // Prefer to keep dimension from known input_size_ if available.
+    if (i < input_size_.size()) {
+      ret.Set(i, input_size_[i]);
+    } else {
+      ret.Set(i, Integer(1));
+    }
+  }
 } else {
```
🧩 Analysis chain
Improved fallback logic prevents crashes but may be overly conservative.

The updated OutputShape() fallback (lines 104-122) improves robustness when the analyzer cannot form an IntervalSet.

Improvements:

- Uses `const_int_bound` to derive safe extents when possible (lines 107-112)
- Falls back to `input_size_[i]` or `1` as a last resort (lines 115-120)
- Preserves the `CHECK(is_one(ist.min()))` assertion for the normal path (line 123)

Concern: the fallback to `input_size_[i]` or `1` (lines 116-119) is conservative but may produce incorrect shapes for complex layouts. Consider whether a runtime error or more sophisticated analysis would be preferable to silently using potentially incorrect dimensions.

Verify that the fallback behavior is acceptable for all layout types, especially when `forward_index_[i]` involves bitwise operations or other constructs the analyzer cannot handle.
🏁 Script executed:

```bash
#!/bin/bash
# Search for layouts that might trigger the fallback path
rg -n -C3 'OutputShape|forward_index_' --type=cpp | grep -E '(bitwise|&|\||\^)'
```

Length of output: 2812

🏁 Script executed:

```bash
#!/bin/bash
# Search for forward_index_ definitions and assignments to see what expressions are stored
rg -n 'forward_index_\s*=' --type=cpp -A 2 | head -100
```

Length of output: 539

🏁 Script executed:

```bash
#!/bin/bash
# Look for actual bitwise operations in layout files
rg -nP '\s+\d+\s*(&|\||\^|<<|>>)\s*\d+' --type=cpp src/layout/
```

Length of output: 0

🏁 Script executed:

```bash
#!/bin/bash
# Find test files related to layout or OutputShape
fd -e cc -e h 'test|layout' --type f | grep -i test | head -20
```

Length of output: 0

🏁 Script executed:

```bash
#!/bin/bash
# Search for calls to OutputShape to see how result is used
rg -n 'OutputShape\(\)' --type=cpp -A 2 -B 1 | head -80
```

Length of output: 4699

🏁 Script executed:

```bash
#!/bin/bash
# Find input_size_ definition and its type
rg -n 'input_size_' --type=cpp -B 2 -A 2 | grep -E '(Array|vector|input_size_|PrimExpr)' | head -40
```

Length of output: 2304

🏁 Script executed:

```bash
#!/bin/bash
# Check context around the ICHECK that uses OutputShape
rg -n 'as_const_int.*OutputShape' --type=cpp -B 5 -A 5
```

Length of output: 632

🏁 Script executed:

```bash
#!/bin/bash
# Search for input_size_ member variable declaration in header files
rg -n 'input_size_' --type=h -B 1 -A 1
```

Length of output: 402

🏁 Script executed:

```bash
#!/bin/bash
# Look for other places where OutputShape result is used without as_const_int conversion
rg -n 'OutputShape\(\)\[' --type=cpp -B 2 -A 2
```

Length of output: 1455

🏁 Script executed:

```bash
#!/bin/bash
# Search for Layout/Fragment construction to see what input_size values are passed
rg -n 'Layout\(|Fragment\(' --type=cpp -A 2 | head -60
```

Length of output: 3885

🏁 Script executed:

```bash
#!/bin/bash
# Check if input_size_ elements are checked/validated to be const before being returned
rg -n 'input_size_\[' --type=cpp -B 3 | head -50
```

Length of output: 848

🏁 Script executed:

```bash
#!/bin/bash
# Look at the full context of layout_inference.cc lines 836-842 to understand the crash scenario
rg -n 'frag_reg_num' --type=cpp -B 8 -A 5
```

Length of output: 1497
Fix the OutputShape() fallback to ensure it always returns const integers, or validate input_size_ at construction time.

The fallback logic (lines 104-122) has a critical bug: it returns `input_size_[i]` (line 117), which may contain non-const `PrimExpr` since `input_size_` is populated from sources like `buffer->shape` (see layout_reducer.cc:216-217). This violates the implicit contract that `OutputShape()` returns const dimensions. Specifically, layout_inference.cc:839 will crash with an ICHECK failure:

```cpp
for (auto i : frag.value()->OutputShape()) {
  auto pci = as_const_int(i);
  ICHECK(pci != nullptr); // crashes if OutputShape() returns symbolic expr
  frag_reg_num *= *pci;
}
```

Fix: either (1) validate that `input_size_` elements are const at construction time, (2) ensure the fallback path only returns proven const dimensions (e.g., always `Integer(1)` instead of `input_size_[i]`), or (3) add a precondition guard preventing this fallback when `input_size_` contains symbolic expressions.
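Option (2) can be sketched in isolation; `safe_extent` and the bounds tuple below are illustrative stand-ins, not the actual analyzer API:

```python
# Sketch of a fallback that only ever yields constant extents (option 2 of
# the suggested fixes). `bounds` stands in for the analyzer's ConstIntBound.
NEG_INF, POS_INF = float("-inf"), float("inf")

def safe_extent(bounds):
    """Return a provably-constant extent from (min, max) bounds, or 1."""
    lo, hi = bounds
    if lo != NEG_INF and hi != POS_INF and lo >= 0:
        return hi - lo + 1        # extent = max - min + 1
    return 1                      # last resort: constant, never symbolic
```

Returning a constant `1` instead of a possibly-symbolic `input_size_[i]` keeps the `as_const_int` contract intact at the cost of a conservative shape.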
Actionable comments posted: 0
♻️ Duplicate comments (1)
src/layout/layout.cc (1)
104-122: Critical issue from previous review remains unresolved.

The fallback at line 117 still returns `input_size_[i]`, which may contain non-const `PrimExpr` values (e.g., from `buffer->shape`). This violates the implicit contract that `OutputShape()` returns const dimensions and will cause the `ICHECK(pci != nullptr)` at layout_inference.cc:839 to fail when `as_const_int(i)` returns nullptr for symbolic expressions.

The previous review recommended one of three fixes:

- Validate that `input_size_` elements are const at construction time
- Always return proven const dimensions in the fallback (e.g., `Integer(1)` instead of `input_size_[i]`)
- Add a precondition guard preventing this fallback when `input_size_` contains symbolic expressions

Consider applying this conservative fix to eliminate the crash risk:

```diff
 } else {
   // Last-resort conservative fallback to avoid OOB/crash
-  // Prefer to keep dimension from known input_size_ if available.
-  if (i < input_size_.size()) {
-    ret.Set(i, input_size_[i]);
-  } else {
-    ret.Set(i, Integer(1));
-  }
+  // Fallback to 1 since we cannot guarantee input_size_[i] is const.
+  ret.Set(i, Integer(1));
 }
```

Alternatively, validate at construction time that all elements of `input_size_` are const integers when the layout will be used in contexts requiring const dimensions.
🧹 Nitpick comments (5)
src/layout/layout.cc (2)
358-368: Consider optimizing the substitution loop.

The nested loop iterates over each forward expression and applies `Substitute` once per input dimension. This creates intermediate expression trees on each iteration, which could be inefficient for layouts with many dimensions. Apply this refactor to build the substitution map once and apply it to all expressions:

```diff
 // Step 5. Substitute original indices into forward_index_
+Map<Var, PrimExpr> vmap;
+for (size_t i = 0; i < InputShape().size(); ++i) {
+  vmap.Set(InputPlaceholder(i), original_indices[i]);
+}
 Array<PrimExpr> new_forward_index;
 for (const auto &fwd_expr : forward_index_) {
-  PrimExpr substituted = fwd_expr;
-  // Replace each InputPlaceholder(i) with original_indices[i]
-  for (size_t i = 0; i < InputShape().size(); ++i) {
-    substituted =
-        Substitute(substituted, {{InputPlaceholder(i), original_indices[i]}});
-  }
-  new_forward_index.push_back(substituted);
+  new_forward_index.push_back(Substitute(fwd_expr, vmap));
 }
```
421-435: Consider optimizing the substitution loops.

Similar to the LayoutNode::Reshape implementation, the substitution is performed iteratively in nested loops (lines 425-427 for `forward_index_` and lines 432-434 for `forward_thread_`). Building the substitution map once and applying it to all expressions would be more efficient. Apply this refactor:

```diff
 // 4) Substitute old placeholders with expressions of new indices
+Map<Var, PrimExpr> vmap;
+for (size_t i = 0; i < InputShape().size(); ++i) {
+  vmap.Set(InputPlaceholder(i), orig_indices[i]);
+}
 Array<PrimExpr> new_forward_index;
 for (const auto &e : forward_index_) {
-  PrimExpr cur = e;
-  for (size_t i = 0; i < InputShape().size(); ++i) {
-    cur = Substitute(cur, {{InputPlaceholder(i), orig_indices[i]}});
-  }
-  new_forward_index.push_back(cur);
+  new_forward_index.push_back(Substitute(e, vmap));
 }
-PrimExpr new_forward_thread = forward_thread_;
-for (size_t i = 0; i < InputShape().size(); ++i) {
-  new_forward_thread = Substitute(new_forward_thread,
-                                  {{InputPlaceholder(i), orig_indices[i]}});
-}
+PrimExpr new_forward_thread = Substitute(forward_thread_, vmap);
```

src/transform/layout_inference.cc (3)
120-161: Consider extracting the reshape logic into a helper method to reduce duplication.

The layout propagation and reshape logic in this lambda is duplicated in the final alias propagation pass (lines 319-357) and in `VisitStmt_(BlockNode)` (lines 579-598). Extracting this into a private helper method would improve maintainability and reduce the risk of inconsistencies across these three locations. Example signature:

```cpp
Layout ReshapeLayoutIfNeeded(const Layout &src_layout,
                             const Buffer &target_buffer,
                             arith::Analyzer *analyzer) const {
  bool shapes_equal =
      src_layout->InputShape().size() == target_buffer->shape.size();
  if (shapes_equal) {
    for (size_t i = 0; i < src_layout->InputShape().size(); ++i) {
      if (!analyzer->CanProveEqual(src_layout->InputShape()[i],
                                   target_buffer->shape[i])) {
        shapes_equal = false;
        break;
      }
    }
  }
  return shapes_equal ? src_layout
                      : src_layout->Reshape(target_buffer->shape, analyzer);
}
```
615-643: Consider optimizing the duplicate check for better performance.

The linear search for duplicates (lines 621-634) is O(n) for each buffer insertion, which could become expensive with many alias buffers. While the current implementation is correct, consider using a helper map or set to track buffer pointers for O(1) duplicate detection if performance becomes an issue. Example optimization:

```cpp
// Add as a class member:
std::unordered_set<const BufferNode *> seen_buffers_;

// In the visitor:
if (seen_buffers_.insert(op->buffer.get()).second) {
  // Buffer was not seen before, add it
  if (buffer_data_to_buffers_.count(op->buffer->data)) {
    auto buffers = buffer_data_to_buffers_[op->buffer->data];
    buffers.push_back(op->buffer);
    buffer_data_to_buffers_.Set(op->buffer->data, buffers);
  } else {
    buffer_data_to_buffers_.Set(op->buffer->data, {op->buffer});
  }
}
```
645-673: Consider extracting the common buffer collection logic.

The `VisitStmt_(BufferStoreNode)` implementation is nearly identical to `VisitExpr_(BufferLoadNode)` (lines 615-643). Consider extracting a private helper method to handle buffer collection and duplicate checking to reduce code duplication. Example helper:

```cpp
void CollectBuffer(const Buffer &buffer, const char *source) {
  if (!buffer.defined() || !buffer->data.defined())
    return;
  if (buffer_data_to_buffers_.count(buffer->data)) {
    auto buffers = buffer_data_to_buffers_[buffer->data];
    bool found = false;
    for (const auto &buf : buffers) {
      if (buf.same_as(buffer)) {
        found = true;
        break;
      }
    }
    if (!found) {
      buffers.push_back(buffer);
      buffer_data_to_buffers_.Set(buffer->data, buffers);
      DLOG(INFO) << "[LayoutInference] " << source << ": added buffer "
                 << buffer;
    }
  } else {
    buffer_data_to_buffers_.Set(buffer->data, {buffer});
    DLOG(INFO) << "[LayoutInference] " << source << ": new buffer " << buffer;
  }
}
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- src/layout/layout.cc (2 hunks)
- src/transform/layout_inference.cc (11 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/layout/layout.cc (2)
tilelang/language/tir/op.py (3)
- ret (1866-1879), floordiv (3180-3199), floormod (3202-3221)

src/layout/layout.h (1)
- Inverse (87-145)
src/transform/layout_inference.cc (3)
src/op/operator.cc (4)
- ParseOperator (32-41), ParseOperator (32-32), ParseOperator (55-61), ParseOperator (55-55)

src/transform/multi_version_buffer_rewriter.cc (2)
- call (443-473), call (443-444)

src/transform/storage_access.cc (2)
- VisitStmt_ (64-93), VisitStmt_ (64-64)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
🔇 Additional comments (8)
src/layout/layout.cc (1)
300-443: Reshape implementations are mathematically sound and well-structured.

Both `LayoutNode::Reshape` and `FragmentNode::Reshape` follow the correct algorithm:
- Validate that total element count is preserved
- Compute flat index from new shape indices
- Recover original shape indices from flat index
- Substitute into forward mappings
The implementations properly handle:
- Fast-path optimization for identical shapes
- Analyzer parameter with fallback to local analyzer
- Helpful error messages with diagnostic information
- Fragment-specific concerns (forward_thread_, thread_range_)
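The index remapping this comment describes can be sketched in Python (illustrative only; the real implementation builds TIR expressions with `floordiv`/`floormod` rather than operating on concrete integers):

```python
from functools import reduce
from operator import mul


def reshape_indices(new_indices, new_shape, old_shape):
    """Map indices under new_shape to indices under old_shape.

    Mirrors the Reshape algorithm: flatten row-major under the new
    shape, then unflatten row-major under the old shape.
    """
    assert reduce(mul, new_shape, 1) == reduce(mul, old_shape, 1), \
        "reshape must preserve the total element count"
    # Flat index under the new shape (row-major).
    flat = 0
    for idx, dim in zip(new_indices, new_shape):
        flat = flat * dim + idx
    # Recover indices under the old shape via floordiv/floormod.
    old_indices = []
    for dim in reversed(old_shape):
        old_indices.append(flat % dim)
        flat //= dim
    return list(reversed(old_indices))


# element (1, 2) of a 2x4 view has flat index 6, i.e. (3, 0) in a 4x2 layout
print(reshape_indices([1, 2], [2, 4], [4, 2]))  # -> [3, 0]
```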
src/transform/layout_inference.cc (7)
14-14: LGTM: Standard library include for sorting operations. The `<algorithm>` header is correctly added to support the `std::sort` and `std::unique` operations used in the enhanced `InferInFreeMode` logic (lines 737-738).
319-357: LGTM: Finalization step ensures layout completeness for alias buffers. This final propagation pass (step 4) correctly ensures that all buffers sharing the same storage Var have consistent layouts, addressing a key objective of this PR. The logic properly handles reshaping when buffer shapes differ.
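In sketch form, that finalization pass might look like the following (Python pseudocode with illustrative names, not the actual C++ API):

```python
from collections import namedtuple
from math import prod

Buffer = namedtuple("Buffer", ["name", "shape"])


def finalize_alias_layouts(data_to_buffers, layout_map, reshape):
    """For each storage Var, propagate one inferred layout to every
    alias buffer, reshaping when an alias has a different shape."""
    for data_var, buffers in data_to_buffers.items():
        # Find a representative buffer that already has a layout.
        rep = next((b for b in buffers if b in layout_map), None)
        if rep is None:
            continue  # nothing was inferred for this storage Var
        for buf in buffers:
            if buf in layout_map:
                continue
            # Aliases must cover the same number of elements.
            assert prod(buf.shape) == prod(rep.shape)
            layout_map[buf] = reshape(layout_map[rep], buf.shape)
    return layout_map


a2d = Buffer("A_2d", (4, 2))
a1d = Buffer("A_flat", (8,))


def reshape(layout, shape):  # stand-in for Layout::Reshape
    return f"{layout}->reshaped{shape}"


out = finalize_alias_layouts({"A_data": [a2d, a1d]}, {a2d: "layout_4x2"}, reshape)
print(out[a1d])  # -> layout_4x2->reshaped(8,)
```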
417-427: LGTM: Transitional helper for backward compatibility. The `GetBufferMap()` helper appropriately provides a representative buffer map for existing APIs that expect `Map<Var, Buffer>`. The TODO comment correctly identifies this as temporary until the codebase fully migrates to multi-buffer support.
491-515: LGTM: Correctly adapted for multi-buffer support. The method now properly retrieves the first buffer from the array of buffers sharing the same storage Var, consistent with the `GetBufferMap()` strategy. The improved logging, from `DLOG` to `LOG(WARNING)`, provides better visibility for this edge case.
562-601: LGTM: Correctly applies annotated layouts to all alias buffers. The implementation properly iterates over all buffers sharing the same storage Var and applies the annotated layout with appropriate reshaping. The order of operations (visiting the body first to collect all buffers, then applying layouts) is correct.
724-744: LGTM: Enhanced component merging correctly handles buffer aliases. This addition properly unions operators that share buffers with the same underlying storage Var, which is essential for handling reshape and alias scenarios. The use of `std::sort` and `std::unique` to deduplicate the merged indices is correct and efficient.
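The merge-and-dedupe step amounts to a sorted set union; a Python stand-in for the C++ `std::sort` + `std::unique` idiom:

```python
def merge_components(component_a, component_b):
    """Union two lists of operator indices into one sorted,
    duplicate-free component, mirroring sort + unique in C++."""
    return sorted(set(component_a) | set(component_b))


# operators whose buffers are backed by the same storage Var get their
# inference components unioned into one
print(merge_components([0, 1, 3], [1, 3, 5]))  # -> [0, 1, 3, 5]
```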
140-142: No action required: Reshape error handling is properly implemented. After examining the `Reshape()` implementations and all call sites, the concerns are addressed:
- Edge cases are handled: Both `LayoutNode::Reshape()` and `FragmentNode::Reshape()` use `ICHECK()` to validate that the product of the input shape equals the product of the new shape. If shapes are incompatible or have different total sizes, the check fails and terminates cleanly.
- Exception handling is not needed: `Reshape()` does not throw exceptions. Instead, it uses `ICHECK()` assertions for contract validation. This is appropriate for an internal API: validation failures are logic errors, not runtime exceptions requiring try-catch handling.
- The analyzer's inability to prove equality is safe: If the analyzer cannot prove shape product equality, the `ICHECK()` will fail (terminating the program) rather than proceeding with an incorrect reshape. This conservative approach prevents silent data corruption.

All three call sites (lines 142, 353, 596 in layout_inference.cc) properly pass the analyzer and handle the returned `Layout` object without needing exception handlers.
* Update layout handling and introduce reshape functionality
  - Updated the `LayoutNode` class to include a new `Reshape` method, allowing for dynamic reshaping of layouts based on input shapes.
  - Enhanced the `OutputShape` method to provide better handling of cases where the analyzer cannot form an `IntervalSet`, implementing fallback mechanisms to ensure safe extents.
  - Refactored the `ReduceOpNode` to utilize `BufferRegion` for improved memory handling during reduction operations.
  - Added tests for reshaping functionality and layout transformations to ensure correctness and performance in various scenarios.
* lint fix
* Revert tvm submodule pointer to 1815c3e0b6ec4ead36370bbd1562025d8529017c; keep src unchanged
* Update tvm submodule to commit f0bbd3bf741413c35c389ba5dedd5be206000ad1
* Update tvm submodule to commit f0bbd3bf741413c35c389ba5dedd5be206000ad1
* remove useless prove
* remove comment

---------

Co-authored-by: tilelang-bot <bot@tilelang>
This pull request introduces significant improvements to TVM's layout inference and buffer aliasing logic, with a focus on more robust handling of buffer layouts, enhanced support for buffer aliasing, and new reshape capabilities for layouts. The changes ensure that all buffers sharing the same underlying storage variable (Var) have consistent and correct layouts, even if their shapes differ, by propagating and reshaping layouts as needed. Additionally, new methods are added to support reshaping `Layout` and `Fragment` objects, and the handling of buffer regions in reduction operations is generalized.

Enhancements to Layout Inference and Buffer Aliasing:
- Layout inference now tracks all buffers that share the same underlying storage variable (Var) and propagates inferred layouts (reshaping if needed) to ensure consistency across all alias buffers. This includes a new finalization step after inference to guarantee all alias buffers have a layout. [1] [2] [3] [4] [5]
- Replaced the `buffer_data_to_buffer_` map with a `buffer_data_to_buffers_` map to handle multiple buffers per storage variable, and updated all relevant code to use this new structure. [1] [2]

New Layout Reshape Capabilities:
- Added a `Reshape` method to the `LayoutNode` and `FragmentNode` classes, with implementations that allow reshaping a layout to a new shape while preserving the total number of elements. This is used to propagate layouts to alias buffers with different shapes. [1] [2] [3]

Generalization of Buffer Region Handling in Reduce Operators:
- Updated the `ReduceOp` constructor to accept and normalize various region types (including `BufferRegion`, `BufferLoad`, and `tl.region` calls) for source and destination, storing the original regions for further use. [1] [2] [3]

Improvements to Layout Output Shape and Inference:
- Improved `LayoutNode::OutputShape()` to handle cases where the analyzer cannot form an interval set, ensuring safe extents and avoiding out-of-bounds errors.

Submodule Update:
- Updated the `3rdparty/tvm` submodule to a newer commit.

Summary by CodeRabbit
- New Features
- Bug Fixes
- Tests
- Chores