
Conversation


@LeiWang1999 LeiWang1999 commented Nov 12, 2025

This pull request improves TVM's layout inference and buffer aliasing logic, focusing on more robust handling of buffer layouts, better support for buffer aliasing, and new reshape capabilities for layouts. The changes ensure that all buffers sharing the same underlying storage variable (Var) have consistent, correct layouts even when their shapes differ, by propagating and reshaping layouts as needed. In addition, new methods support reshaping Layout and Fragment objects, and the handling of buffer regions in reduction operations is generalized.

Enhancements to Layout Inference and Buffer Aliasing:

  • Introduced logic to track all buffers sharing the same storage variable (Var) and propagate inferred layouts (reshaping if needed) to ensure consistency across all alias buffers. This includes a new finalization step after inference to guarantee all alias buffers have a layout.
  • Replaced the single buffer_data_to_buffer_ map with a buffer_data_to_buffers_ map to handle multiple buffers per storage variable, and updated all relevant code to use this new structure.
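In outline, the alias-tracking structure behaves like a multimap from storage vars to buffer views. The following is a minimal sketch in plain Python with illustrative names only; the real pass works on tir::Var/tir::Buffer objects and calls Layout::Reshape rather than tagging tuples:

```python
from collections import defaultdict

# Var -> list of buffer views; a "buffer" here is just (name, shape) over a
# shared data var. Illustrative stand-in for Map<Var, Array<Buffer>>.
buffer_data_to_buffers = defaultdict(list)

def collect(name, data, shape):
    # Record a buffer view, avoiding duplicates per storage var.
    views = buffer_data_to_buffers[data]
    if all(v[0] != name for v in views):
        views.append((name, tuple(shape)))

def propagate_alias(data, inferred_shape, layout_map):
    # Give every alias of `data` a layout consistent with the inferred one,
    # reshaping (element count preserved) when the alias shape differs.
    for name, shape in buffer_data_to_buffers[data]:
        if shape == inferred_shape:
            layout_map[name] = ("layout", shape)
        else:
            layout_map[name] = ("reshaped", shape)

collect("A_2d", "ptr", (4, 4))
collect("A_1d", "ptr", (16,))
layouts = {}
propagate_alias("ptr", (4, 4), layouts)
```

Here a layout inferred for the (4, 4) view is propagated to the aliasing (16,) view in reshaped form, mirroring the finalization step described above.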

New Layout Reshape Capabilities:

  • Added a virtual Reshape method to the LayoutNode and FragmentNode classes, with implementations that allow reshaping a layout to a new shape while preserving the total number of elements. This is used to propagate layouts to alias buffers with different shapes.
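The shape-preserving reshape can be pictured as composing the old layout with a row-major flat-index correspondence. A minimal plain-Python sketch of that idea (the real method builds the same mapping symbolically over PrimExpr indices and validates the product equality with an arith::Analyzer):

```python
import math

def to_flat(idx, shape):
    # row-major flat index of a multi-dimensional index
    flat = 0
    for i, s in zip(idx, shape):
        flat = flat * s + i
    return flat

def from_flat(flat, shape):
    # inverse of to_flat for the same row-major convention
    idx = []
    for s in reversed(shape):
        idx.append(flat % s)
        flat //= s
    return tuple(reversed(idx))

def reshape_layout(forward, old_shape, new_shape):
    # The reshaped layout evaluates the old layout at the old index that
    # occupies the same flat position as the new index.
    assert math.prod(old_shape) == math.prod(new_shape)  # element count preserved
    return lambda new_idx: forward(from_flat(to_flat(new_idx, new_shape), old_shape))

identity = lambda idx: idx            # trivial layout on shape (4, 4)
f = reshape_layout(identity, (4, 4), (2, 8))
```

For example, new index (1, 3) has flat position 11, which corresponds to old index (2, 3) in the (4, 4) shape.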

Generalization of Buffer Region Handling in Reduce Operators:

  • Refactored the ReduceOp constructor to accept and normalize various region types (including BufferRegion, BufferLoad, and tl.region calls) for source and destination, storing the original regions for further use.
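A toy model of that normalization, using hypothetical Python stand-ins for the TIR node types (the actual C++ helper additionally reconstructs regions from tl.region calls via RegionOp, omitted here):

```python
from dataclasses import dataclass

@dataclass
class BufferRegion:
    buffer: str
    ranges: list          # list of (start, extent) per dimension

@dataclass
class BufferLoad:
    buffer: str
    indices: list         # ints, or ("ramp", base, lanes) for stride-1 vectors

def normalize_to_buffer_region(arg):
    # Already a region: pass through unchanged.
    if isinstance(arg, BufferRegion):
        return arg
    # Element or vector load: turn each index into a range.
    if isinstance(arg, BufferLoad):
        ranges = []
        for idx in arg.indices:
            if isinstance(idx, tuple) and idx[0] == "ramp":
                _, base, lanes = idx
                ranges.append((base, lanes))   # vector index -> extent = lanes
            else:
                ranges.append((idx, 1))        # scalar index -> unit extent
        return BufferRegion(arg.buffer, ranges)
    raise TypeError(f"unsupported region argument: {arg!r}")

region = normalize_to_buffer_region(BufferLoad("A", [2, ("ramp", 0, 8)]))
```

Normalizing everything to one region form lets the constructor derive src/dst buffers uniformly regardless of how the caller expressed the operand.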

Improvements to Layout Output Shape and Inference:

  • Improved the fallback logic in LayoutNode::OutputShape() to handle cases where the analyzer cannot form an interval set, ensuring safe extents and avoiding out-of-bounds errors.
  • Minor improvements and bug fixes in layout inference and reduction operator logic.
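The fallback described above reduces to a small decision rule. A stand-alone sketch in plain Python (the real code queries TVM's const-int-bound analysis on forward_index_[i]; names here are illustrative):

```python
NEG_INF, POS_INF = float("-inf"), float("inf")

def safe_extent(const_bound, input_extent):
    """Pick a safe dimension extent when interval analysis fails.

    const_bound: (min_value, max_value) from const-int-bound analysis;
    input_extent: the known input dimension, or None if unavailable.
    """
    lo, hi = const_bound
    if lo != NEG_INF and hi != POS_INF and lo >= 0:
        return hi - lo + 1          # bounded and non-negative: extent = max - min + 1
    # last-resort conservative fallback to avoid out-of-bounds extents
    return input_extent if input_extent is not None else 1
```

For instance, a bitwise index expression with proven bounds [0, 15] yields extent 16; with unbounded analysis the rule falls back to the known input extent.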

Submodule Update:

  • Updated the 3rdparty/tvm submodule to a newer commit.

Summary by CodeRabbit

  • New Features

    • Reshape support for layouts and fragments that preserves element alignment
    • Reduce operations now handle explicit buffer regions for more uniform region semantics
    • Alias-aware layout propagation to keep layouts consistent across aliased buffers
  • Bug Fixes

    • Safer output-shape inference fallback to avoid out-of-bounds/crash scenarios
  • Tests

    • Added tests for reshape, layout-transform, and reduce-after-reshape (torch-based references)
  • Chores

    • Updated third-party submodule pointer (no functional changes)

- Updated the `LayoutNode` class to include a new `Reshape` method, allowing for dynamic reshaping of layouts based on input shapes.
- Enhanced the `OutputShape` method to provide better handling of cases where the analyzer cannot form an `IntervalSet`, implementing fallback mechanisms to ensure safe extents.
- Refactored the `ReduceOpNode` to utilize `BufferRegion` for improved memory handling during reduction operations.
- Added tests for reshaping functionality and layout transformations to ensure correctness and performance in various scenarios.
@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀


coderabbitai bot commented Nov 12, 2025

Walkthrough

Added Layout and Fragment Reshape methods and implementations; normalized Reduce arguments into BufferRegion and stored srcRegion_/dstRegion_; refactored layout inference to map data Vars to multiple Buffers and propagate layouts across aliases; bumped TVM submodule pointer; added reshape and reduce-after-reshape tests and adjusted reduce intrinsics to use buffer-region wrappers.

Changes

Cohort / File(s) / Summary

  • Submodule Update — 3rdparty/tvm
    TVM submodule pointer advanced; no functional changes in this diff.
  • Layout Reshape Support — src/layout/layout.h, src/layout/layout.cc
    Added Reshape(const Array<PrimExpr>&, arith::Analyzer*) to LayoutNode and FragmentNode; implemented reshape with a fast-path equality check, product validation (using the analyzer), flat-index mapping and substitution, and fragment thread-range adjustments. Also adjusted the OutputShape fallback to use const_int_bound defensively.
  • Reduce BufferRegion Normalization — src/op/reduce.h, src/op/reduce.cc
    Added a NormalizeToBufferRegion helper; ReduceOpNode now stores srcRegion_ and dstRegion_; the constructor normalizes src/dst arguments into BufferRegion and derives src/dst from those regions; updated includes to use region.h.
  • Layout Inference Alias Propagation — src/transform/layout_inference.cc
    Replaced the Var→Buffer mapping with Var→Array<Buffer> (buffer_data_to_buffers_); added a GetBufferMap() helper; propagates inferred layouts to alias buffers (reshaping when shapes differ) and adds a final alias-propagation pass; updated buffer lookup helpers and logging.
  • Tests & Reduce Wrapper — testing/python/language/test_tilelang_language_reshape.py, tilelang/language/reduce.py
    Added reshape and reduce-after-reshape tests (fragment/shared/swizzle cases, validated against torch references); updated reduce intrinsics to obtain buffer regions via buffer_to_tile_region wrappers.
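The reduce-after-reshape tests validate against torch references (e.g. torch.max after a reshape). The same oracle can be sketched dependency-free in plain Python — reshape a flat buffer row-major, then reduce each row; the names here are illustrative, not taken from the test file:

```python
def reshape_2d(flat, rows, cols):
    # row-major view of a flat buffer as a rows x cols matrix
    assert len(flat) == rows * cols
    return [flat[r * cols:(r + 1) * cols] for r in range(rows)]

def reduce_max_rows(mat):
    # mirrors torch.max(x, dim=-1).values on a 2-D tensor
    return [max(row) for row in mat]

flat = [3, 1, 4, 1, 5, 9, 2, 6]
result = reduce_max_rows(reshape_2d(flat, 2, 4))  # per-row max after reshape
```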

Sequence Diagram(s)

sequenceDiagram
    participant Caller
    participant LayoutNode
    participant Analyzer
    participant Indexer
    Caller->>LayoutNode: Reshape(new_shape, analyzer)
    LayoutNode->>LayoutNode: if shapes equal → return self
    alt shapes differ
        LayoutNode->>Analyzer: validate product equality
        Analyzer-->>LayoutNode: equality confirmed
        LayoutNode->>Indexer: build flat-index → map new indices to old
        Indexer-->>LayoutNode: substituted forward_index_
        LayoutNode-->>Caller: return new Layout/Fragment (with thread-range updates if Fragment)
    end
sequenceDiagram
    participant ReduceCtor as ReduceOp ctor
    participant Normalizer as NormalizeToBufferRegion
    ReduceCtor->>Normalizer: normalize src arg
    alt arg is BufferRegion
        Normalizer-->>ReduceCtor: return same BufferRegion
    else arg is BufferLoad (Ramp allowed)
        Normalizer->>Normalizer: convert indices → Range(s)
        Normalizer-->>ReduceCtor: return BufferRegion
    else arg is tl.region call
        Normalizer->>Normalizer: reconstruct BufferRegion via RegionOp
        Normalizer-->>ReduceCtor: return BufferRegion
    else unsupported
        Normalizer-->>ReduceCtor: fatal error
    end
    ReduceCtor->>ReduceCtor: store srcRegion_, derive src
    ReduceCtor->>Normalizer: normalize dst arg
    Normalizer-->>ReduceCtor: return dstRegion_
    ReduceCtor->>ReduceCtor: store dstRegion_, derive dst
sequenceDiagram
    actor User
    participant LayoutInference
    participant BufferCollector as BufferUseDefCollector
    participant LayoutNode
    User->>LayoutInference: request layout inference
    LayoutInference->>BufferCollector: collect buffers & aliases
    BufferCollector-->>LayoutInference: Map Var → Array[Buffer]
    LayoutInference->>LayoutNode: infer layout for representative buffer
    LayoutNode-->>LayoutInference: return layout
    LayoutInference->>LayoutInference: propagate layout to alias buffers
    loop per alias buffer
        LayoutInference->>LayoutNode: call Reshape(alias_layout, new_shape)
        LayoutNode-->>LayoutInference: return reshaped layout
        LayoutInference->>LayoutInference: enqueue alias for further inference if needed
    end
    LayoutInference-->>User: layouts finalized

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • Focus areas:
    • Flat-index computation, substitution and bounds handling in LayoutNode::Reshape and FragmentNode::Reshape.
    • Correct adjustment of forward_thread_ and thread_range_ in Fragment::Reshape.
    • Alias-propagation correctness and termination conditions in layout_inference.cc (cycles, enqueue logic).
    • NormalizeToBufferRegion handling of BufferLoad/Ramp and tl.region call shapes and failure modes.
    • Call sites updated to use the Var→Array<Buffer> mapping and GetBufferMap().


Poem

I hopped through indices, soft and fleet,
Flattened shapes beneath my feet,
Regions met and aliases too,
Threads aligned — the mapping true,
A rabbit’s reshape, quick and sweet! 🐰

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning — docstring coverage is 8.51%, below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve coverage.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed — check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed — the title '[Enhancement] Support Layout/Fragment Reshape' directly and accurately summarizes the main change of this PR: adding reshape functionality to the Layout and Fragment classes.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 53243be and 29fd82c.

📒 Files selected for processing (1)
  • src/layout/layout.cc (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/layout/layout.cc (1)
src/layout/layout.h (1)
  • Inverse (87-145)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (2)
src/transform/layout_inference.cc (2)

319-380: Consider extracting shape equality check into a helper.

The finalization step correctly ensures all alias buffers have consistent layouts. However, the shape equality check (lines 338-349 and 360-371) is duplicated from the propagate_alias lambda (lines 129-139).

Consider extracting this pattern into a helper method:

bool CheckShapeEquality(const Array<PrimExpr> &shape1, const Array<PrimExpr> &shape2, arith::Analyzer *analyzer) const {
  if (shape1.size() != shape2.size()) return false;
  for (size_t i = 0; i < shape1.size(); ++i) {
    if (!analyzer->CanProveEqual(shape1[i], shape2[i])) {
      return false;
    }
  }
  return true;
}

This would reduce duplication and make the code more maintainable.


638-696: Buffer collection from Load/Store is correct but contains duplication.

Both VisitExpr_(BufferLoadNode*) and VisitStmt_(BufferStoreNode*) correctly collect buffers and check for duplicates using same_as. The DLOG statements are helpful for debugging.

The logic between BufferLoad and BufferStore visitors is nearly identical. Consider extracting a helper:

void CollectBuffer(const Buffer &buffer) {
  if (!buffer.defined() || !buffer->data.defined()) return;
  
  if (buffer_data_to_buffers_.count(buffer->data)) {
    auto buffers = buffer_data_to_buffers_[buffer->data];
    bool found = std::find_if(buffers.begin(), buffers.end(),
                              [&](const Buffer &buf) { return buf.same_as(buffer); })
                 != buffers.end();
    if (!found) {
      buffers.push_back(buffer);
      buffer_data_to_buffers_.Set(buffer->data, buffers);
      DLOG(INFO) << "[LayoutInference] added buffer " << buffer 
                 << " data = " << buffer->data.get();
    }
  } else {
    buffer_data_to_buffers_.Set(buffer->data, {buffer});
    DLOG(INFO) << "[LayoutInference] new buffer " << buffer 
               << " data = " << buffer->data.get();
  }
}
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 30d8ded and e266482.

📒 Files selected for processing (8)
  • 3rdparty/tvm (1 hunks)
  • src/layout/layout.cc (2 hunks)
  • src/layout/layout.h (2 hunks)
  • src/op/reduce.cc (3 hunks)
  • src/op/reduce.h (2 hunks)
  • src/transform/layout_inference.cc (11 hunks)
  • testing/python/language/test_tilelang_language_reshape.py (2 hunks)
  • tilelang/language/reduce.py (5 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-12T09:47:46.474Z
Learnt from: kurisu6912
Repo: tile-ai/tilelang PR: 794
File: tilelang/transform/add_bufstore_wrapper.py:30-33
Timestamp: 2025-09-12T09:47:46.474Z
Learning: In TVM's PyStmtExprMutator, visit_block_ methods typically call super().visit_block_(op) to process child nodes and update internal state, but return the original op when the block itself doesn't need transformation. The pattern `return op` is correct for blocks that serve as containers where mutations happen at deeper levels.

Applied to files:

  • src/transform/layout_inference.cc
🧬 Code graph analysis (7)
src/op/reduce.h (1)
tilelang/language/fill.py (1)
  • clear (50-74)
tilelang/language/reduce.py (1)
tilelang/language/utils.py (1)
  • buffer_to_tile_region (30-42)
src/layout/layout.h (1)
src/layout/layout.cc (6)
  • Layout (58-70)
  • Layout (72-75)
  • Reshape (300-369)
  • Reshape (300-301)
  • Reshape (371-442)
  • Reshape (371-372)
testing/python/language/test_tilelang_language_reshape.py (4)
tilelang/language/allocate.py (2)
  • alloc_shared (27-42)
  • alloc_fragment (59-70)
tilelang/language/customize.py (1)
  • reshape (39-49)
tilelang/profiler/__init__.py (1)
  • assert_allclose (77-146)
tilelang/intrinsics/mma_layout.py (1)
  • make_mma_swizzle_layout (210-223)
src/layout/layout.cc (1)
src/layout/layout.h (1)
  • Inverse (87-145)
src/transform/layout_inference.cc (3)
src/op/operator.cc (4)
  • ParseOperator (32-41)
  • ParseOperator (32-32)
  • ParseOperator (55-61)
  • ParseOperator (55-55)
src/transform/multi_version_buffer_rewriter.cc (2)
  • call (443-473)
  • call (443-444)
src/transform/storage_access.cc (2)
  • VisitStmt_ (64-93)
  • VisitStmt_ (64-64)
src/op/reduce.cc (2)
src/op/gemm_py.cc (2)
  • NormalizeToBufferRegion (25-72)
  • NormalizeToBufferRegion (25-26)
src/op/gemm.cc (2)
  • NormalizeToBufferRegion (53-100)
  • NormalizeToBufferRegion (53-54)
🔇 Additional comments (22)
3rdparty/tvm (1)

1-1: Clarify necessity and rationale for the submodule update.

The PR couples a TVM submodule update with internal reshape and layout inference changes. While the enriched summary suggests the update may provide needed API compatibility, the specific reasons for updating TVM are not clear from the PR context.

To ensure this change is well-justified and safe:

  1. Document the rationale: Why is this TVM commit required for the reshape functionality? Are there specific APIs or bug fixes needed?
  2. Consider separation: If possible, submodule updates should be in separate commits to simplify reviews and enable independent rollbacks.
  3. Verify the commit: Confirm that the new commit (1b54bb0148381ff990e72625d10784fedd369768) exists in the official TVM repository and is stable/released.

Please provide clarity on whether this update is a hard requirement for the reshape feature or if it can be deferred to a separate PR.

src/transform/layout_inference.cc (7)

120-161: LGTM with observation: Alias propagation logic is well-structured.

The propagate_alias lambda correctly handles shape mismatches via Reshape and ensures layout consistency across buffers sharing the same storage variable. The shape equality check is thorough, and the enqueueing logic properly updates the BFS queue for affected buffers.

One observation: The lambda captures analyzer_ by reference from the enclosing class, which is safe here since the lambda is invoked immediately within the same scope.


197-198: Consistent alias propagation across all update paths.

The calls to propagate_alias are correctly placed after updating layout_map to ensure aliases are synchronized whenever layouts change. This maintains consistency across all code paths (fragment containment, equality checks, and new insertions).

Also applies to: 207-213


440-450: GetBufferMap usage is appropriate but mark for future refactoring.

The GetBufferMap() helper returns the first buffer for each storage variable, which is acceptable for backward compatibility with APIs expecting a single buffer per variable. The TODO comment correctly indicates this should be phased out in the future.

Ensure all callers of GetBufferMap() understand they're getting only the first buffer from potentially multiple aliases.


458-458: Buffer access resolution correctly handles multi-buffer scenario.

The updates to ParseOperator (line 458) and getBufferFromAccessPtr (lines 514-538) properly adapt to the new buffer_data_to_buffers_ structure. Returning the first buffer is appropriate when a single buffer is needed for API compatibility.

The warning message at lines 522-523 is helpful for debugging unexpected argument types.

Also applies to: 514-538


574-625: Block annotation processing correctly handles aliased buffers.

The update to VisitStmt_(BlockNode*) properly:

  1. Collects buffers into buffer_data_to_buffers_ before processing annotations
  2. Applies annotated layouts to all buffers sharing a storage variable
  3. Reshapes layouts when buffer shapes differ from annotation shapes

The comment at lines 585-587 helpfully clarifies the visit order. The shape equality check here is part of the previously noted duplication pattern.


747-767: Alias-aware component grouping is well-implemented.

The additional union logic (lines 747-767) correctly handles the case where multiple Buffer objects share the same storage variable. By collecting and unioning all inference indices across aliased buffers, the code ensures that operations on any alias are treated as part of the same connected component.

The use of std::sort followed by std::unique is the standard idiom for deduplication.


698-698: Core structural change enables alias tracking.

The change from Map<Var, Buffer> to Map<Var, Array<Buffer>> is the foundational modification that enables tracking multiple buffers per storage variable. All other changes in this file support and utilize this new structure.

tilelang/language/reduce.py (2)

6-6: Import addition is appropriate.

The import of buffer_to_tile_region from tilelang.language.utils is correctly placed and necessary for the subsequent changes.


52-60: Consistent usage of buffer_to_tile_region across all reduction pathways.

The replacement of direct buffer access pointers with buffer_to_tile_region wrapper calls is applied consistently across all four reduction scenarios:

  1. shared → shared (lines 55-56)
  2. shared → fragment (lines 70-71)
  3. fragment → shared (lines 83-84)
  4. fragment → fragment (lines 94-95)

This aligns with the region-aware approach introduced in the C++ reduce operator (src/op/reduce.cc), ensuring proper tile-region addressing for the tl.reduce intrinsic.

Also applies to: 67-75, 80-89, 91-99

src/layout/layout.h (1)

45-46: Reshape method declarations are properly structured.

The virtual Reshape method added to LayoutNode (lines 45-46) follows the existing pattern for other virtual methods like Inverse and InverseWithLevel. The non-virtual declaration in FragmentNode (line 89) is appropriate since FragmentNode is marked as final.

Both methods accept an arith::Analyzer* parameter for shape product verification, which is consistent with other layout transformation methods.

Also applies to: 89-89

src/op/reduce.h (2)

85-90: Region members are appropriately added and exposed.

The addition of BufferRegion srcRegion_ and dstRegion_ members (lines 85-86) properly captures the original regions used to construct the reduce operation. The explanatory comment is helpful, and the naming follows existing conventions (trailing underscore for private/internal members).

The existing members dim, type, and clear retain their semantics.


99-100: Reflection bindings correctly expose region members.

The read-only reflection bindings for srcRegion and dstRegion (lines 99-100) properly expose these members to the Python API, maintaining consistency with the existing reflection pattern for src and dst.

testing/python/language/test_tilelang_language_reshape.py (4)

4-4: Torch import is appropriate for reference implementations.

The addition of import torch (line 4) is necessary for the new test reference programs that use torch.max and tensor reshaping for validation.


133-174: Fragment reshape test provides good coverage.

The reshape_fragment_test and run_reshape_fragment functions (lines 133-174) properly test reshaping of fragment buffers, including:

  • Allocation of shared and fragment memory
  • Copy operations between memory scopes
  • Reshape of fragment buffers
  • Validation against torch reference

The test correctly exercises the newly added reshape functionality for fragment layouts.


177-218: Layout transform reshape test validates swizzled layouts.

The reshape_layout_transform_shared test (lines 177-218) adds important coverage for reshaping buffers with annotated swizzled layouts. The use of make_mma_swizzle_layout ensures that reshape works correctly with non-trivial layout transformations.


221-262: Reduce after reshape test validates end-to-end integration.

The reduce_after_reshape_test (lines 221-262) provides crucial integration testing by combining reshape with reduction operations:

  1. Reshape a 1D fragment to 2D
  2. Apply reduce_max along a dimension
  3. Validate against torch.max reference

This test ensures the region-aware reduction changes work correctly with reshaped buffers, validating the broader PR objectives.

src/op/reduce.cc (3)

17-17: Region header inclusion is necessary.

The addition of #include "region.h" (line 17) is required for the RegionOp usage in NormalizeToBufferRegion.


25-64: NormalizeToBufferRegion implementation is correct and follows existing patterns.

The NormalizeToBufferRegion function (lines 25-64) properly handles three cases:

  1. BufferRegion: Returns directly (lines 30-32)
  2. BufferLoad: Converts indices to ranges, handling both Ramp nodes (stride-1 vectorization) and scalar indices (lines 36-51)
  3. tl.region calls: Reconstructs BufferRegion via RegionOp (lines 56-59)

The implementation mirrors similar functions in src/op/gemm.cc and src/op/gemm_py.cc (see relevant snippets), ensuring consistency across operators. The error message (line 62) is clear.

Note: The function does not handle builtin.tvm_access_ptr() unlike the GEMM variants, which may be intentional for reduce operations.


66-78: ReduceOp constructor correctly adopts region-aware approach.

The updated constructor (lines 66-78) properly:

  1. Normalizes source and destination arguments to BufferRegion (lines 69-70)
  2. Stores the normalized regions in srcRegion_ and dstRegion_
  3. Derives src and dst buffers from the regions (lines 71-72)
  4. Preserves existing logic for dim, type, and clear (lines 73-76)

This change enables reduce operations to work uniformly with different argument forms (BufferRegion, BufferLoad, tl.region calls), which is essential for reshape support.

src/layout/layout.cc (2)

300-369: LayoutNode::Reshape implementation is correct and well-structured.

The LayoutNode::Reshape implementation (lines 300-369) correctly:

  1. Fast path: Returns self if shapes are equal (lines 302-305)
  2. Shape validation: Verifies total element count is preserved using analyzer (lines 307-324)
  3. Index remapping:
    • Computes flat index from new shape indices (lines 332-341)
    • Converts flat index back to original shape indices (lines 347-356)
    • Substitutes original indices into forward_index_ (lines 358-367)

The algorithm properly handles row-major layout assumptions and preserves the forward mapping semantics.


371-442: FragmentNode::Reshape correctly extends LayoutNode::Reshape for thread-aware layouts.

The FragmentNode::Reshape implementation (lines 371-442) properly extends the base reshape logic:

  1. Includes the fast path and shape validation (lines 373-393)
  2. Uses the same flat-index remapping approach (lines 395-418)
  3. Additionally handles thread mapping: Substitutes into forward_thread_ (lines 430-434)
  4. Preserves thread_range: Re-binds thread_range if defined (lines 438-440)

The implementation correctly maintains all fragment-specific properties (thread mapping, replication extent, thread range) while reshaping the spatial layout.
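As a rough functional picture of that thread-aware reshape (plain Python, illustrative names only): both the spatial index map and the thread map are composed with the same new-index → old-index flat correspondence, while the thread range itself is left untouched.

```python
def to_flat(idx, shape):
    # row-major flat index
    flat = 0
    for i, s in zip(idx, shape):
        flat = flat * s + i
    return flat

def from_flat(flat, shape):
    idx = []
    for s in reversed(shape):
        idx.append(flat % s)
        flat //= s
    return tuple(reversed(idx))

def reshape_fragment(forward_index, forward_thread, old_shape, new_shape):
    # Compose both the spatial map and the thread map with the
    # new-index -> old-index flat correspondence; thread range is unchanged.
    remap = lambda idx: from_flat(to_flat(idx, new_shape), old_shape)
    return (lambda idx: forward_index(remap(idx)),
            lambda idx: forward_thread(remap(idx)))

# 16 elements over 4 threads, 4 consecutive elements per thread, old shape (16,)
fi = lambda idx: (idx[0] % 4,)   # register slot within the owning thread
ft = lambda idx: idx[0] // 4     # owning thread
new_fi, new_ft = reshape_fragment(fi, ft, (16,), (4, 4))
```

E.g. new index (2, 1) is flat position 9, so it keeps thread 2 and register slot 1 from the original (16,) fragment.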

Comment on lines 104 to 122
if (arith::is_neg_inf(ist.min()) && arith::is_pos_inf(ist.max())) {
  // Analyzer couldn't form an IntervalSet (e.g. bitwise ops).
  // Fall back to ConstIntBound to derive a safe extent.
  auto cib = analyzer.const_int_bound(forward_index_[i]);
  if (cib->min_value != arith::ConstIntBound::kNegInf &&
      cib->max_value != arith::ConstIntBound::kPosInf &&
      cib->min_value >= 0) {
    // extent = max - min + 1, using 64-bit integer literal
    ret.Set(i, Integer(cib->max_value - cib->min_value + 1));
  } else {
    // Last-resort conservative fallback to avoid OOB/crash.
    // Prefer to keep dimension from known input_size_ if available.
    if (i < input_size_.size()) {
      ret.Set(i, input_size_[i]);
    } else {
      ret.Set(i, Integer(1));
    }
  }
} else {

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Improved fallback logic prevents crashes but may be overly conservative.

The updated OutputShape() fallback (lines 104-122) improves robustness when the analyzer cannot form an IntervalSet:

Improvements:

  1. Uses const_int_bound to derive safe extents when possible (lines 107-112)
  2. Falls back to input_size_[i] or 1 as last resort (lines 115-120)
  3. Preserves the CHECK(is_one(ist.min())) assertion for the normal path (line 123)

Concern: The fallback to input_size_[i] or 1 (lines 116-119) is conservative but may produce incorrect shapes for complex layouts. Consider whether a runtime error or more sophisticated analysis would be preferable to silently using potentially incorrect dimensions.

Verify that the fallback behavior is acceptable for all layout types, especially when forward_index_[i] involves bitwise operations or other constructs the analyzer cannot handle.


🏁 Script executed:

#!/bin/bash
# Search for layouts that might trigger the fallback path
rg -n -C3 'OutputShape|forward_index_' --type=cpp | grep -E '(bitwise|&|\||\^)'



🏁 Script executed:

#!/bin/bash
# Search for forward_index_ definitions and assignments to see what expressions are stored
rg -n 'forward_index_\s*=' --type=cpp -A 2 | head -100



🏁 Script executed:

#!/bin/bash
# Look for actual bitwise operations in layout files
rg -nP '\s+\d+\s*(&|\||\^|<<|>>)\s*\d+' --type=cpp src/layout/



🏁 Script executed:

#!/bin/bash
# Find test files related to layout or OutputShape
fd -e cc -e h 'test|layout' --type f | grep -i test | head -20



🏁 Script executed:

#!/bin/bash
# Search for calls to OutputShape to see how result is used
rg -n 'OutputShape\(\)' --type=cpp -A 2 -B 1 | head -80



🏁 Script executed:

#!/bin/bash
# Find input_size_ definition and its type
rg -n 'input_size_' --type=cpp -B 2 -A 2 | grep -E '(Array|vector|input_size_|PrimExpr)' | head -40



🏁 Script executed:

#!/bin/bash
# Check context around the ICHECK that uses OutputShape
rg -n 'as_const_int.*OutputShape' --type=cpp -B 5 -A 5



🏁 Script executed:

#!/bin/bash
# Search for input_size_ member variable declaration in header files
rg -n 'input_size_' --type=h -B 1 -A 1



🏁 Script executed:

#!/bin/bash
# Look for other places where OutputShape result is used without as_const_int conversion
rg -n 'OutputShape\(\)\[' --type=cpp -B 2 -A 2



🏁 Script executed:

#!/bin/bash
# Search for Layout/Fragment construction to see what input_size values are passed
rg -n 'Layout\(|Fragment\(' --type=cpp -A 2 | head -60



🏁 Script executed:

#!/bin/bash
# Check if input_size_ elements are checked/validated to be const before being returned
rg -n 'input_size_\[' --type=cpp -B 3 | head -50



🏁 Script executed:

#!/bin/bash
# Look at the full context of layout_inference.cc lines 836-842 to understand the crash scenario
rg -n 'frag_reg_num' --type=cpp -B 8 -A 5



Fix the OutputShape() fallback to ensure it always returns const integers, or validate input_size_ at construction time.

The fallback logic (lines 104-122) has a critical bug: it returns input_size_[i] (line 117), which may contain non-const PrimExpr since input_size_ is populated from sources like buffer->shape (see layout_reducer.cc:216-217). This violates the implicit contract that OutputShape() returns const dimensions.

Specifically, layout_inference.cc:839 will crash with an ICHECK failure:

for (auto i : frag.value()->OutputShape()) {
  auto pci = as_const_int(i);
  ICHECK(pci != nullptr);  // crashes if OutputShape() returns symbolic expr
  frag_reg_num *= *pci;
}

Fix: Either (1) validate that input_size_ elements are const at construction time, (2) ensure the fallback path only returns proven const dimensions (e.g., always Integer(1) instead of input_size_[i]), or (3) add a precondition guard preventing this fallback when input_size_ contains symbolic expressions.

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (1)
src/layout/layout.cc (1)

104-122: Critical issue from previous review remains unresolved.

The fallback at line 117 still returns input_size_[i], which may contain non-const PrimExpr values (e.g., from buffer->shape). This violates the implicit contract that OutputShape() returns const dimensions and will cause the ICHECK(pci != nullptr) at layout_inference.cc:839 to fail when as_const_int(i) returns nullptr for symbolic expressions.

The previous review recommended one of three fixes:

  1. Validate that input_size_ elements are const at construction time
  2. Always return proven const dimensions in the fallback (e.g., Integer(1) instead of input_size_[i])
  3. Add a precondition guard preventing this fallback when input_size_ contains symbolic expressions

Consider applying this conservative fix to eliminate the crash risk:

       } else {
         // Last-resort conservative fallback to avoid OOB/crash
-        // Prefer to keep dimension from known input_size_ if available.
-        if (i < input_size_.size()) {
-          ret.Set(i, input_size_[i]);
-        } else {
-          ret.Set(i, Integer(1));
-        }
+        // Fallback to 1 since we cannot guarantee input_size_[i] is const.
+        ret.Set(i, Integer(1));
       }

Alternatively, validate at construction time that all elements of input_size_ are const integers when the layout will be used in contexts requiring const dimensions.

🧹 Nitpick comments (5)
src/layout/layout.cc (2)

358-368: Consider optimizing the substitution loop.

The nested loop applies Substitute once per input dimension for each forward expression, and every Substitute call re-traverses the whole expression tree and allocates an intermediate copy, so the cost grows as O(dims × expression size) per forward expression.

Apply this refactor to build the substitution map once and apply it to all expressions:

   // Step 5. Substitute original indices into forward_index_
+  Map<Var, PrimExpr> vmap;
+  for (size_t i = 0; i < InputShape().size(); ++i) {
+    vmap.Set(InputPlaceholder(i), original_indices[i]);
+  }
+  
   Array<PrimExpr> new_forward_index;
   for (const auto &fwd_expr : forward_index_) {
-    PrimExpr substituted = fwd_expr;
-    // Replace each InputPlaceholder(i) with original_indices[i]
-    for (size_t i = 0; i < InputShape().size(); ++i) {
-      substituted =
-          Substitute(substituted, {{InputPlaceholder(i), original_indices[i]}});
-    }
-    new_forward_index.push_back(substituted);
+    new_forward_index.push_back(Substitute(fwd_expr, vmap));
   }

421-435: Consider optimizing the substitution loops.

Similar to the LayoutNode::Reshape implementation, the substitution is performed iteratively in nested loops (lines 425-427 for forward_index_ and lines 432-434 for forward_thread_). Building the substitution map once and applying it to all expressions would be more efficient.

Apply this refactor:

   // 4) Substitute old placeholders with expressions of new indices
+  Map<Var, PrimExpr> vmap;
+  for (size_t i = 0; i < InputShape().size(); ++i) {
+    vmap.Set(InputPlaceholder(i), orig_indices[i]);
+  }
+  
   Array<PrimExpr> new_forward_index;
   for (const auto &e : forward_index_) {
-    PrimExpr cur = e;
-    for (size_t i = 0; i < InputShape().size(); ++i) {
-      cur = Substitute(cur, {{InputPlaceholder(i), orig_indices[i]}});
-    }
-    new_forward_index.push_back(cur);
+    new_forward_index.push_back(Substitute(e, vmap));
   }

-  PrimExpr new_forward_thread = forward_thread_;
-  for (size_t i = 0; i < InputShape().size(); ++i) {
-    new_forward_thread = Substitute(new_forward_thread,
-                                    {{InputPlaceholder(i), orig_indices[i]}});
-  }
+  PrimExpr new_forward_thread = Substitute(forward_thread_, vmap);
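The payoff of building the map once is simultaneous, single-pass substitution. This can be sketched outside TVM with a toy expression encoding (nested tuples standing in for TVM's expression nodes; none of these names come from the codebase):

```python
# Toy expressions: ("var", name) | ("add", lhs, rhs)
def substitute(expr, vmap):
    """One traversal, substituting every variable in vmap simultaneously."""
    kind = expr[0]
    if kind == "var":
        return vmap.get(expr[1], expr)
    if kind == "add":
        return ("add", substitute(expr[1], vmap), substitute(expr[2], vmap))
    return expr

# i0 + i1, with both placeholders replaced in a single pass.
expr = ("add", ("var", "i0"), ("var", "i1"))
vmap = {"i0": ("var", "j0"), "i1": ("var", "j1")}
print(substitute(expr, vmap))
```

Simultaneous substitution also sidesteps ordering hazards: sequential per-variable passes can re-substitute into expressions introduced by an earlier replacement, whereas a single map application cannot.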
src/transform/layout_inference.cc (3)

120-161: Consider extracting the reshape logic into a helper method to reduce duplication.

The layout propagation and reshape logic in this lambda is duplicated in the final alias propagation pass (lines 319-357) and in VisitStmt_(BlockNode) (lines 579-598). Extracting this into a private helper method would improve maintainability and reduce the risk of inconsistencies across these three locations.

Example signature:

Layout ReshapeLayoutIfNeeded(const Layout& src_layout, 
                              const Buffer& target_buffer,
                              arith::Analyzer* analyzer) const {
  bool shapes_equal = src_layout->InputShape().size() == target_buffer->shape.size();
  if (shapes_equal) {
    for (size_t i = 0; i < src_layout->InputShape().size(); ++i) {
      if (!analyzer->CanProveEqual(src_layout->InputShape()[i], target_buffer->shape[i])) {
        shapes_equal = false;
        break;
      }
    }
  }
  return shapes_equal ? src_layout : src_layout->Reshape(target_buffer->shape, analyzer);
}

615-643: Consider optimizing the duplicate check for better performance.

The linear search for duplicates (lines 621-634) is O(n) for each buffer insertion, which could become expensive with many alias buffers. While the current implementation is correct, consider using a helper map or set to track buffer pointers for O(1) duplicate detection if performance becomes an issue.

Example optimization:

// Add as a class member:
std::unordered_set<const BufferNode*> seen_buffers_;

// In the visitor:
if (seen_buffers_.insert(op->buffer.get()).second) {
  // Buffer was not seen before, add it
  if (buffer_data_to_buffers_.count(op->buffer->data)) {
    auto buffers = buffer_data_to_buffers_[op->buffer->data];
    buffers.push_back(op->buffer);
    buffer_data_to_buffers_.Set(op->buffer->data, buffers);
  } else {
    buffer_data_to_buffers_.Set(op->buffer->data, {op->buffer});
  }
}

645-673: Consider extracting the common buffer collection logic.

The VisitStmt_(BufferStoreNode) implementation is nearly identical to VisitExpr_(BufferLoadNode) (lines 615-643). Consider extracting a private helper method to handle buffer collection and duplicate checking to reduce code duplication.

Example helper:

void CollectBuffer(const Buffer& buffer, const char* source) {
  if (!buffer.defined() || !buffer->data.defined()) return;
  
  if (buffer_data_to_buffers_.count(buffer->data)) {
    auto buffers = buffer_data_to_buffers_[buffer->data];
    bool found = false;
    for (const auto &buf : buffers) {
      if (buf.same_as(buffer)) {
        found = true;
        break;
      }
    }
    if (!found) {
      buffers.push_back(buffer);
      buffer_data_to_buffers_.Set(buffer->data, buffers);
      DLOG(INFO) << "[LayoutInference] " << source << ": added buffer " << buffer;
    }
  } else {
    buffer_data_to_buffers_.Set(buffer->data, {buffer});
    DLOG(INFO) << "[LayoutInference] " << source << ": new buffer " << buffer;
  }
}
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2a76b53 and 53243be.

📒 Files selected for processing (2)
  • src/layout/layout.cc (2 hunks)
  • src/transform/layout_inference.cc (11 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-12T09:47:46.474Z
Learnt from: kurisu6912
Repo: tile-ai/tilelang PR: 794
File: tilelang/transform/add_bufstore_wrapper.py:30-33
Timestamp: 2025-09-12T09:47:46.474Z
Learning: In TVM's PyStmtExprMutator, visit_block_ methods typically call super().visit_block_(op) to process child nodes and update internal state, but return the original op when the block itself doesn't need transformation. The pattern `return op` is correct for blocks that serve as containers where mutations happen at deeper levels.

Applied to files:

  • src/transform/layout_inference.cc
🧬 Code graph analysis (2)
src/layout/layout.cc (2)
tilelang/language/tir/op.py (3)
  • ret (1866-1879)
  • floordiv (3180-3199)
  • floormod (3202-3221)
src/layout/layout.h (1)
  • Inverse (87-145)
src/transform/layout_inference.cc (3)
src/op/operator.cc (4)
  • ParseOperator (32-41)
  • ParseOperator (32-32)
  • ParseOperator (55-61)
  • ParseOperator (55-55)
src/transform/multi_version_buffer_rewriter.cc (2)
  • call (443-473)
  • call (443-444)
src/transform/storage_access.cc (2)
  • VisitStmt_ (64-93)
  • VisitStmt_ (64-64)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
🔇 Additional comments (8)
src/layout/layout.cc (1)

300-443: Reshape implementations are mathematically sound and well-structured.

Both LayoutNode::Reshape and FragmentNode::Reshape follow the correct algorithm:

  1. Validate that total element count is preserved
  2. Compute flat index from new shape indices
  3. Recover original shape indices from flat index
  4. Substitute into forward mappings

The implementations properly handle:

  • Fast-path optimization for identical shapes
  • Analyzer parameter with fallback to local analyzer
  • Helpful error messages with diagnostic information
  • Fragment-specific concerns (forward_thread_, thread_range_)
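The index arithmetic behind steps 2 and 3 can be mirrored for plain index tuples. The helper below is a hypothetical sketch assuming row-major flattening, as in the C++ implementation; step 4 (substituting the recovered indices into forward_index_) has no analogue for concrete tuples:

```python
from math import prod

def reshape_index(new_idx, new_shape, old_shape):
    """Map an index under new_shape to the equivalent index under old_shape.

    Step 1: element counts must match.
    Step 2: flatten new_idx row-major into a single flat offset.
    Step 3: unflatten the offset into old_shape coordinates.
    """
    assert prod(new_shape) == prod(old_shape), "reshape must preserve size"
    flat = 0
    for i, n in zip(new_idx, new_shape):
        flat = flat * n + i
    old_idx = []
    for n in reversed(old_shape):
        old_idx.append(flat % n)
        flat //= n
    return tuple(reversed(old_idx))

# Index (1, 2) in a 2x4 view is flat offset 6, i.e. (6,) in the 1-D view.
print(reshape_index((1, 2), (2, 4), (8,)))
```

Round-tripping through the flat offset is what makes the mapping a bijection, which is why the element-count check in step 1 is sufficient for correctness.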
src/transform/layout_inference.cc (7)

14-14: LGTM: Standard library include for sorting operations.

The <algorithm> header is correctly added to support std::sort and std::unique operations used in the enhanced InferInFreeMode logic (lines 737-738).


319-357: LGTM: Finalization step ensures layout completeness for alias buffers.

This final propagation pass (step 4) correctly ensures that all buffers sharing the same storage Var have consistent layouts, addressing a key objective of this PR. The logic properly handles reshaping when buffer shapes differ.


417-427: LGTM: Transitional helper for backward compatibility.

The GetBufferMap() helper appropriately provides a representative buffer map for existing APIs that expect Map<Var, Buffer>. The TODO comment correctly identifies this as temporary until the codebase fully migrates to multi-buffer support.


491-515: LGTM: Correctly adapted for multi-buffer support.

The method now properly retrieves the first buffer from the array of buffers sharing the same storage Var, consistent with the GetBufferMap() strategy. Upgrading the logging from DLOG to LOG(WARNING) gives this edge case better visibility.


562-601: LGTM: Correctly applies annotated layouts to all alias buffers.

The implementation properly iterates over all buffers sharing the same storage Var and applies the annotated layout with appropriate reshaping. The order of operations (visiting the body first to collect all buffers, then applying layouts) is correct.


724-744: LGTM: Enhanced component merging correctly handles buffer aliases.

This addition properly unions operators that share buffers with the same underlying storage Var, which is essential for handling reshape and alias scenarios. The use of std::sort and std::unique to deduplicate the merged indices is correct and efficient.


140-142: No action required—Reshape error handling is properly implemented.

After examining the Reshape() implementations and all call sites, the concerns are addressed:

  1. Edge cases are handled: Both LayoutNode::Reshape() and FragmentNode::Reshape() use ICHECK() to validate that the product of the input shape equals the product of the new shape. If shapes are incompatible or have different total sizes, the check fails and terminates cleanly.

  2. Exception handling is not needed: Reshape() does not throw exceptions. Instead, it uses ICHECK() assertions for contract validation. This is appropriate for an internal API—validation failures are logic errors, not runtime exceptions requiring try-catch handling.

  3. Analyzer's inability to prove equality is safe: If the analyzer cannot prove shape product equality, the ICHECK() will fail (terminating the program) rather than proceeding with an incorrect reshape. This conservative approach prevents silent data corruption.

All three call sites (lines 142, 353, 596 in layout_inference.cc) properly pass the analyzer and handle the returned Layout object without needing exception handlers.

@LeiWang1999 LeiWang1999 merged commit 4370309 into tile-ai:main Nov 12, 2025
6 checks passed
RubiaCx pushed a commit to RubiaCx/tilelang that referenced this pull request Nov 24, 2025
* Update layout handling and introduce reshape functionality

- Updated the `LayoutNode` class to include a new `Reshape` method, allowing for dynamic reshaping of layouts based on input shapes.
- Enhanced the `OutputShape` method to provide better handling of cases where the analyzer cannot form an `IntervalSet`, implementing fallback mechanisms to ensure safe extents.
- Refactored the `ReduceOpNode` to utilize `BufferRegion` for improved memory handling during reduction operations.
- Added tests for reshaping functionality and layout transformations to ensure correctness and performance in various scenarios.

* lint fix

* Revert tvm submodule pointer to 1815c3e0b6ec4ead36370bbd1562025d8529017c; keep src unchanged

* Update tvm submodule to commit f0bbd3bf741413c35c389ba5dedd5be206000ad1

* Update tvm submodule to commit f0bbd3bf741413c35c389ba5dedd5be206000ad1

* remove useless prove

* remove comment

---------

Co-authored-by: tilelang-bot <bot@tilelang>
