
Conversation


@LeiWang1999 LeiWang1999 commented Sep 5, 2025

  • Updated `CopyNode` to introduce separate handling for 1D bulk load/store operations, including new methods for checking and lowering these operations.
  • Modified `InferLayout` and `GetCopyInst` to accommodate additional parameters for layout maps and analyzers.
  • Enhanced `AtomicAddNode` and `FillNode` to utilize the updated layout inference logic.
  • Improved buffer out-of-bounds checks during layout inference to ensure safe memory access.

This update improves the efficiency and correctness of memory operations in the TileLang framework.

Summary by CodeRabbit

  • New Features

    • Added explicit 1D bulk-copy support for memory transfers.
  • Improvements

    • Layout inference now uses richer analysis context and per-operation out-of-bounds tracking to improve lowering decisions.
    • Copy lowering unified to handle 1D and 2D paths with consistent address/stride computation and layout-aware swizzling.
    • Public APIs updated to carry layout/analysis context for copy and layout inference.
  • Refactor

    • Simplified warp-specialization barrier and thread-partition logic.
  • Documentation

    • Minor wording update in copy indices description.


coderabbitai bot commented Sep 5, 2025

Walkthrough

Adds an arith::Analyzer pointer and a buffer_oob flag to LayoutInferArgs and propagates them through layout-inference and lowering call sites; implements explicit 1D bulk-copy support and checks in CopyNode, refactors 2D bulk-copy addressing to use layout_map, introduces per-op OOB tracking in layout inference, and removes WgMMA gating from warp-specialized rewriting.

Changes

| Cohort / File(s) | Summary |
|---|---|
| **Layout inference plumbing**<br>`src/op/operator.h`, `src/transform/layout_inference.cc`, `src/op/atomic_add.cc`, `src/op/fill.cc` | Added `arith::Analyzer *analyzer` and `bool buffer_oob` to `LayoutInferArgs`; updated `InferLayout` call sites to pass the analyzer and `buffer_oob`; `BufferUseDefCollector` now records per-op OOB and thread bounds; added using-declarations for analyzer-aware visitor/mutator ctors. |
| **Copy op: 1D bulk and API expansion**<br>`src/op/copy.h`, `src/op/copy.cc` | Added `kBulkLoad1D`/`kBulkStore1D`; extended `GetCopyInst` to accept `layout_map`, `arith::Analyzer*`, and `buffer_oob`; added `CheckBulkLoad1D`/`CheckBulkStore1D`/`CheckBulkCopy1D` and `LowerBulkCopy1D`; `Copy::Lower` dispatches to 1D/2D bulk, LDS/SM, or normal paths; refactored offset/stride addressing to use `layout_map` and updated swizzle decisions. |
| **Operator lowers updated**<br>`src/op/atomic_add.cc`, `src/op/fill.cc` | Expanded calls to `par_op->InferLayout` in lowering paths to include the analyzer and `buffer_oob` in the `LayoutInferArgs` initializer. |
| **Warp-specialized rewriter simplification**<br>`src/transform/warp_specialized_rewriter.cc` | Removed WgMMA-specific gating, the constructor flag, and the `onlyHasWgMMA` accessor from `WSCodeEmitter`; simplified barrier-arrival and thread-partition logic to a uniform non-WgMMA path. |
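
For orientation, the plumbing above amounts to roughly the following shape for LayoutInferArgs. This is a reader's sketch assembled from the walkthrough and the brace-initializers quoted in the review comments below, not the actual contents of src/op/operator.h; exact types, ordering, and defaults may differ.

// Sketch only: field names follow the initializer lists
// {target, thread_bounds, layout_map, &analyzer, buffer_oob, buffer_remap}
// quoted in the review; treat types and defaults as assumptions.
struct LayoutInferArgs {
  Target target;                       // compilation target
  Range thread_bounds;                 // per-op thread bounds recorded by the collector
  LayoutMap layout_map;                // layouts inferred so far
  arith::Analyzer *analyzer = nullptr; // analysis context; may be absent
  bool buffer_oob = false;             // true if this op may access out of bounds
  Map<Buffer, Buffer> buffer_remap;    // optional remapped buffers
};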

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant LI as LayoutInferencer
  participant AZ as Analyzer
  participant Op as OpNode
  participant LM as LayoutMap

  LI->>AZ: Analyze bounds for current op
  LI->>LI: compute per-op buffer_oob
  LI->>Op: InferLayout({target, thread_bounds, LM, AZ, buffer_oob})
  Op-->>LI: return LayoutMap
  LI->>LI: store per-op thread/OOB context
sequenceDiagram
  autonumber
  participant Copy as CopyNode
  participant LA as LowerArgs
  participant AZ as Analyzer
  participant Inst as CopyInstSelector

  Copy->>Inst: GetCopyInst(target, disable_tma_lower, layout_map, AZ, buffer_oob)
  alt kBulkLoad1D / kBulkStore1D
    Copy->>Copy: LowerBulkCopy1D(LA, AZ, inst)
  else kBulkLoad / kBulkStore
    Copy->>Copy: LowerBulkCopy(LA, AZ, inst)
  else kLDSM / kSTSM
    Copy->>Copy: LowerLDSMCopy(LA, inst)
  else kNormal
    Copy->>Copy: LowerNormalCopy(LA, AZ)
  end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes


Poem

I nibble bytes and map the land,
I probe each bound with careful hand,
One-dimensional hops, a tidy leap,
WgMMA's veil I let it sleep,
Analyzer hums — layouts sing and keep. 🐇✨


github-actions bot commented Sep 5, 2025

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀


@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @LeiWang1999, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces specialized handling for 1D Tensor Memory Access (TMA) operations within the TileLang framework, aiming to optimize memory transfers for contiguous data. It refactors the CopyNode to better manage different types of bulk copies and enhances the layout inference system with more detailed analytical capabilities, including explicit buffer out-of-bounds checks. Additionally, it includes a cleanup of the warp specialization rewriter.

Highlights

  • Introduction of 1D TMA Operations: Added specific support for 1D Tensor Memory Access (TMA) bulk load and store operations (kBulkLoad1D, kBulkStore1D) to improve efficiency for contiguous memory transfers.
  • Refactored CopyNode Logic: The CopyNode now includes dedicated methods (CheckBulkLoad1D, CheckBulkStore1D, LowerBulkCopy1D) to check for and lower 1D bulk copies, separating this logic from general bulk copy handling.
  • Enhanced Layout Inference: The InferLayout and GetCopyInst functions, along with the LayoutInferArgs structure, have been extended to include arith::Analyzer and buffer_oob parameters, allowing for more precise layout inference and out-of-bounds checking.
  • Improved Buffer OOB Checks: Out-of-bounds checks for buffers are now explicitly incorporated into the layout inference process, particularly for CopyNode operations, ensuring safer memory access.
  • Updated Operator Inference: AtomicAddNode and FillNode have been updated to utilize the new InferLayout parameters, ensuring they benefit from the enhanced layout inference capabilities.
  • Warp Specialization Refactoring: Simplified the WSCodeEmitter and WarpSpecializedRewriter by removing the only_has_wgmma_ flag and related logic, streamlining the warp specialization pass.
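
To make the new copy-instruction selection concrete, here is a rough sketch of the enum and the dispatch described in the highlights above and in the walkthrough's sequence diagram. The member and method names come from this PR's summaries; the enumerator ordering, values, and the exact Lower() control flow are assumptions, not the code under review.

// Illustrative sketch, not the real src/op/copy.h definition.
enum class CopyInst {
  kNormal,      // element-wise SIMT copy
  kLDSM,        // ldmatrix-based copy
  kSTSM,        // stmatrix-based copy
  kBulkLoad,    // multi-dimensional TMA load
  kBulkStore,   // multi-dimensional TMA store
  kBulkLoad1D,  // new: 1D TMA load for fully contiguous transfers
  kBulkStore1D, // new: 1D TMA store for fully contiguous transfers
};

// Dispatch shape inside Copy::Lower(), per the sequence diagram above.
if (inst == CopyInst::kBulkLoad1D || inst == CopyInst::kBulkStore1D) {
  return LowerBulkCopy1D(T, analyzer, inst);
} else if (inst == CopyInst::kBulkLoad || inst == CopyInst::kBulkStore) {
  return LowerBulkCopy(T, analyzer, inst);
} else if (inst == CopyInst::kLDSM || inst == CopyInst::kSTSM) {
  return LowerLDSMCopy(T, inst);
} else {
  return LowerNormalCopy(T, analyzer);
}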


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for 1D TMA (Tensor Memory Access) bulk copy operations, which can improve performance in appropriate cases. The changes include new functions to check for and lower 1D TMA, updates to the CopyInst enum, and modifications to layout inference to handle these new cases. While the changes are generally good and include some nice refactoring, I've identified a critical issue with potential for a crash due to unsafe map access, and a high-severity bug in the logic for checking 1D bulk store conditions. I've also noted a couple of places where documentation was removed, which could impact future maintainability.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 6

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
src/op/atomic_add.cc (1)

289-307: address_of applied to a conditional rvalue can generate invalid IR

address_of expects a BufferLoad (lvalue). Wrapping dst_value in if_then_else before address_of makes it a non-lvalue expression. Guard the atomic call with a combined predicate instead of taking address_of on a conditional.

-  PrimExpr dst_value = BufferLoad(dst, dst_indices);
-  if (dst_predicate.defined())
-    dst_value = if_then_else(dst_predicate, dst_value, make_zero(dst->dtype));
-
-  Call address_of_value =
-      tvm::tir::Call(DataType::Handle(), builtin::address_of(), {dst_value});
+  PrimExpr dst_load = BufferLoad(dst, dst_indices);
+  Call address_of_value =
+      tvm::tir::Call(DataType::Handle(), builtin::address_of(), {dst_load});
   new_args.push_back(address_of_value);
   new_args.push_back(src_value);
 
-  Call atomicadd_call =
-      tvm::tir::Call(dst->dtype, builtin::call_extern(), new_args);
-
-  Stmt body = tvm::tir::Evaluate(atomicadd_call);
+  Call atomicadd_call =
+      tvm::tir::Call(dst->dtype, builtin::call_extern(), new_args);
+  Stmt body = tvm::tir::Evaluate(atomicadd_call);
+  // Combine source/destination predicates to guard the atomic op
+  PrimExpr call_pred = Bool(true);
+  if (src_predicate.defined()) call_pred = And(call_pred, src_predicate);
+  if (dst_predicate.defined()) call_pred = And(call_pred, dst_predicate);
+  if (!(is_one(call_pred))) {
+    body = IfThenElse(call_pred, body);
+  }
src/transform/layout_inference.cc (1)

88-110: Include buffer_remap in LayoutInferArgs initializer
In transform/layout_inference.cc (lines 88–110), the call to

next->InferLayout(LayoutInferArgs{target_, thread_bounds, layout_map, &analyzer_, buffer_oob}, level);

omits the buffer_remap map, defaulting it to empty. Many operator-level InferLayout implementations (e.g. fill, atomic_add, reduce, copy, finalize_reducer, gemm_sp) read T.buffer_remap and will misbehave. Change the call to pass the current buffer_remap, e.g.:

auto updates = next->InferLayout(
    LayoutInferArgs{target_, thread_bounds, layout_map, &analyzer_, buffer_oob, buffer_remap}, level);
src/op/copy.cc (1)

1121-1161: Avoid .at() on layout_map; fall back cleanly when no layout is annotated.

.at() will CHECK-fail if the shared buffer has no layout entry.

-  auto shared_layout = T.layout_map.at(shared_tensor);
-  if (!shared_layout.defined()) {
+  Layout shared_layout;
+  if (T.layout_map.count(shared_tensor)) {
+    shared_layout = T.layout_map[shared_tensor];
+  }
+  if (!shared_layout.defined()) {
     desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE);
🧹 Nitpick comments (7)
src/transform/warp_specialized_rewriter.cc (1)

1226-1229: Nit: drop the explicit false — it’s the default.

The extra boolean obscures intent and adds churn. Prefer the defaulted ctor.

Apply:

-WSCodeEmitter consumer(false, thread_iv_, buffer_data_to_buffer_, marker,
-                       false);
+WSCodeEmitter consumer(false, thread_iv_, buffer_data_to_buffer_, marker);
src/op/operator.h (1)

48-51: Make analyzer pointer explicitly nullable and document buffer_oob semantics

Avoid accidental uninitialized use and clarify meaning/expectations of buffer_oob (true = may be OOB). Recommend default-init and brief comment.

-  arith::Analyzer *analyzer;
-  bool buffer_oob = false;
+  // Optional analysis context; nullptr means no analyzer available
+  arith::Analyzer* analyzer = nullptr;
+  // True if this op may access out-of-bounds (used to steer conservative lowering)
+  bool buffer_oob = false;
src/op/copy.h (2)

156-175: Public 1D bulk-copy checks: add brief doc on preconditions

Short comment on contiguity/stride assumptions (e.g., unit stride in either src or dst) will help callers.
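
For example, a precondition note along these lines could sit above the declarations. The wording is only a suggestion and restates the checks implemented in CheckBulkCopy1D as quoted later in this review; adjust it to the actual semantics.

/*!
 * \brief Whether this copy can take the 1D bulk (TMA) path.
 *
 * Expected preconditions:
 *  - the shared-side buffer has no remapped/swizzled layout in layout_map;
 *  - the accessed global region is contiguous: trailing dimensions cover the
 *    full buffer extent starting at 0, and any leading dimensions have extent 1;
 *  - the shared and global regions contain the same number of elements;
 *  - the access is provably in bounds (1D TMA cannot pad out-of-bounds elements).
 */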


189-193: Add = false default for buffer_oob in GetCopyInst declaration
In src/op/copy.h (line 189), change the last parameter from

bool buffer_oob

to

bool buffer_oob = false

All call sites now pass 4 or 5 arguments (no legacy 2-arg uses), so this default reduces churn.

src/op/copy.cc (3)

199-206: Fix doc typo: “Create s” → “Create indices”.

Small nit in the brief line.

- * \brief Create s for the copy operation.
+ * \brief Create indices for the copy operation.

996-1005: Address computation refactor looks good, but add small robustness guard.

The stride/offset derivation is fine. Add early return if shared_indices.size()==0 to avoid accidental 0-d edge cases.

-  PrimExpr shared_offset = 0;
-  for (size_t i = 0; i < shared_indices.size(); i++) {
+  PrimExpr shared_offset = 0;
+  if (shared_indices.empty()) return LowerNormalCopy(T, analyzer);
+  for (size_t i = 0; i < shared_indices.size(); i++) {
     shared_offset += shared_indices[i] * shared_strides[i];
   }

Also applies to: 1019-1025


727-729: Pass actual buffer_oob to InferLayout
Replace the hard-coded false with T.buffer_oob so the OOB flag isn’t dropped:

-      {T.target, T.thread_bounds, T.layout_map, analyzer, false, T.buffer_remap},
+      {T.target, T.thread_bounds, T.layout_map, analyzer, T.buffer_oob, T.buffer_remap},
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro


📥 Commits

Reviewing files that changed from the base of the PR and between 6e0c350 and 65bbb80.

📒 Files selected for processing (7)
  • src/op/atomic_add.cc (1 hunks)
  • src/op/copy.cc (12 hunks)
  • src/op/copy.h (5 hunks)
  • src/op/fill.cc (2 hunks)
  • src/op/operator.h (1 hunks)
  • src/transform/layout_inference.cc (10 hunks)
  • src/transform/warp_specialized_rewriter.cc (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/transform/layout_inference.cc (1)
tilelang/language/copy.py (1)
  • copy (84-152)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: bot-task
  • GitHub Check: format-check
  • GitHub Check: format-check
🔇 Additional comments (16)
src/transform/warp_specialized_rewriter.cc (2)

770-778: Barrier-arrival now unconditional on consumer path — validate counts against init.

With WgMMA gating removed, consumer makeArriveBarrier(release_barrier_id) executes whenever release_after is true. Ensure the corresponding barrier’s arrive_thread_count (computed below) matches this arrival source to avoid hangs or early releases.

You can sanity-check by instrumenting a small IR containing both producer/consumer releases and verifying the number of barriers equals the number of arrive insertions per pattern. If helpful, I can draft a small IR snippet test.


1251-1258: Thread-count policy for barriers: double-check producer-vs-consumer arrival math.

Logic reads: producer-released ⇒ producer_thread_extent if SIMT copy else 1; otherwise ⇒ consumer_thread_extent. This matches TMA (single-arrive) vs SIMT semantics, but edge-cases to verify:

  • Bulk-copy-only path forces producer_thread_extent = 128; ensure kernels actually launch ≥ consumer+128 threads.
  • Mixed patterns where a barrier can be released by consumer in one stage and producer in another should still be consistent with released_barrier_ bookkeeping.

If desired, add a brief comment here to document the rationale.

I can add a unit-style IR test that asserts no mismatch between initialized counts and total arrives per barrier across producer/consumer paths.
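
For readers skimming the pass, the arrival-count policy described above boils down to roughly the following; the variable names mirror the review text rather than the actual source.

// Illustrative sketch of the per-barrier arrive-count selection.
PrimExpr arrive_thread_count;
if (released_by_producer) {
  // SIMT copies arrive once per producer thread; TMA bulk copies arrive once.
  arrive_thread_count = is_simt_copy ? producer_thread_extent : PrimExpr(1);
} else {
  // Consumer-released barriers expect one arrival per consumer thread.
  arrive_thread_count = consumer_thread_extent;
}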

src/op/atomic_add.cc (1)

364-365: Pass-through of analyzer and buffer_remap looks correct

Extending InferLayout args here is consistent with the new signature. No issues.

src/op/fill.cc (2)

173-175: LGTM: InferLayout arg expansion (fragment path)

Correctly threads analyzer, buffer_oob=false, and buffer_remap.


190-192: LGTM: InferLayout arg expansion (shared/global path)

Matches the new LayoutInferArgs layout.

src/transform/layout_inference.cc (5)

19-20: Include copy.h is necessary for CopyNode OOB logic

Good addition to enable type checks.


69-70: Using-declarations to inherit constructors are fine

Keeps analyzer-aware visitors/mutators consistent.


207-210: Alignment check for buffer_oob_vec_ with infer_list_

Good guard against index skew between vectors.


408-409: Parallel For bookkeeping includes buffer_oob false

Keeps vectors aligned. Good.


601-602: Using-declaration for IRMutatorWithAnalyzer

Looks good; consistent with visitor.

src/op/copy.h (4)

18-26: Enum extension for 1D bulk-copy looks sensible

Clear separation between 1D and multi-D TMA paths.


144-145: Marking InferLayout as override is correct

Matches TileOperatorNode vtable.


201-206: Adding LowerBulkCopy1D hook is good; ensure parity with 2D path on swizzle/oob

Confirm swizzle, l2_promotion, and oob_fill behaviors match LowerBulkCopy for consistency.


349-350: Conv2DIm2Col InferLayout override

Signature matches base. Good.

src/op/copy.cc (2)

405-407: Good: GetCopyInst now considers analyzer, layout_map, and buffer_oob during inference.

This keeps 1D TMA selection aligned with OOB knowledge at infer time.


520-542: 1D checks look consistent with 2D gating; OK after fixing CheckBulkCopy1D.

Call routing and parameterization are correct.

Comment on lines 473 to 519
bool CopyNode::CheckBulkCopy1D(const Buffer &global_tensor,
                               const Buffer &shared_tensor,
                               const Array<Range> &global_range,
                               const Array<Range> &shared_range,
                               const LayoutMap &layout_map,
                               arith::Analyzer *analyzer) const {

  // Step 1: check shared is contiguous
  bool shared_is_contiguous = true;
  if (layout_map.count(dst)) {
    shared_is_contiguous = false;
  }
  // Step 2: check global is contiguous
  bool global_is_contiguous = true;
  bool global_not_full_dim_encounter = false;
  for (int i = global_range.size() - 1; i >= 0; i--) {
    if (!global_not_full_dim_encounter) {
      if (!analyzer->CanProve(global_range[i]->extent ==
                                  global_tensor->shape[i] &&
                              global_range[i]->min == 0,
                              arith::ProofStrength::kSymbolicBound)) {
        global_not_full_dim_encounter = true;
      }
    } else {
      if (!analyzer->CanProve(global_range[i]->extent == 1,
                              arith::ProofStrength::kSymbolicBound)) {
        global_is_contiguous = false;
        break;
      }
    }
  }

  // Step 3: check element match and no OOB
  PrimExpr shared_elements = 1;
  for (size_t i = 0; i < shared_range.size(); i++) {
    shared_elements *= shared_range[i]->extent;
  }
  PrimExpr global_elements = 1;
  for (size_t i = 0; i < global_range.size(); i++) {
    global_elements *= global_range[i]->extent;
  }
  bool element_match =
      analyzer->CanProveEqual(shared_elements, global_elements);

  return (shared_is_contiguous && global_is_contiguous && element_match);
}


⚠️ Potential issue

Fix variable misuse and strengthen contiguity checks in CheckBulkCopy1D.

  • Uses dst field instead of the shared_tensor parameter when consulting layout_map.
  • Only checks contiguity on the global side; shared side should be checked too.
-  // Step 1: check shared is contiguous
-  bool shared_is_contiguous = true;
-  if (layout_map.count(dst)) {
-    shared_is_contiguous = false;
-  }
+  // Step 1: check shared is contiguous (no remapped/swizzled layout)
+  bool shared_is_contiguous = true;
+  if (layout_map.count(shared_tensor)) {
+    shared_is_contiguous = false;
+  }
@@
-  // Step 2: check global is contiguous
+  // Step 2: check global is contiguous
   bool global_is_contiguous = true;
   bool global_not_full_dim_encounter = false;
@@
   }
+
+  // Step 2b: check shared is contiguous by ranges (same rule as global)
+  bool shared_not_full_dim_encounter = false;
+  for (int i = shared_range.size() - 1; i >= 0; i--) {
+    if (!shared_not_full_dim_encounter) {
+      if (!analyzer->CanProve(shared_range[i]->extent == shared_tensor->shape[i] &&
+                                   shared_range[i]->min == 0,
+                               arith::ProofStrength::kSymbolicBound)) {
+        shared_not_full_dim_encounter = true;
+      }
+    } else {
+      if (!analyzer->CanProve(shared_range[i]->extent == 1,
+                              arith::ProofStrength::kSymbolicBound)) {
+        shared_is_contiguous = false;
+        break;
+      }
+    }
+  }
@@
-  return (shared_is_contiguous && global_is_contiguous && element_match);
+  return (shared_is_contiguous && global_is_contiguous && element_match);
🤖 Prompt for AI Agents
In src/op/copy.cc around lines 473-519, the code mistakenly uses dst when
querying layout_map and only checks contiguity for the global side; replace the
layout_map.count(dst) call with layout_map.count(shared_tensor) and add a
symmetric contiguity check for shared_tensor (mirror the global_is_contiguous
logic: walk shared_range from the last dimension to first, use
analyzer->CanProve to detect a full-dimension then require trailing extents ==
1, set shared_is_contiguous accordingly), keeping the existing element_match
check and return expression.

src/op/copy.cc Outdated
Comment on lines 621 to 636
CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower,
                               const LayoutMap &layout_map,
                               arith::Analyzer *analyzer,
                               bool buffer_oob = false) const {
  // disable_tma_lower is from pass_configs
  // when tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER is True,
  // we will not use tma for bulk load/store
  if (!disable_tma_lower && CheckBulkLoad(target)) {

  // 1d tma access can not support out of bound access
  if (!disable_tma_lower && !buffer_oob &&
      CheckBulkLoad1D(target, layout_map, analyzer)) {
    return CopyInst::kBulkLoad1D;
  } else if (!disable_tma_lower && !buffer_oob &&
             CheckBulkStore1D(target, layout_map, analyzer)) {
    return CopyInst::kBulkStore1D;
  } else if (!disable_tma_lower && CheckBulkLoad(target)) {

🛠️ Refactor suggestion

Make buffer_oob non-optional in GetCopyInst to prevent unsafe defaults.

Defaulting buffer_oob to false risks call sites forgetting to pass it and enabling 1D TMA when OOB exists.

-CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower,
-                               const LayoutMap &layout_map,
-                               arith::Analyzer *analyzer,
-                               bool buffer_oob = false) const {
+CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower,
+                               const LayoutMap &layout_map,
+                               arith::Analyzer *analyzer,
+                               bool buffer_oob) const {

Follow-up: remove the default from the declaration in the header as well.

🤖 Prompt for AI Agents
In src/op/copy.cc around lines 621 to 636, the GetCopyInst signature currently
defaults buffer_oob to false which can lead callers to unintentionally enable 1D
TMA when out-of-bounds is possible; remove the default (= false) from the
definition and also remove the default in the corresponding declaration in the
header file, then audit and update all call sites to pass an explicit buffer_oob
boolean value (true/false) appropriate for their context so no caller relies on
the implicit false.

Comment on lines 608 to 616
  WSCodeEmitter(bool is_emitting_producer, const IterVar &thread_iv,
                Map<Var, Buffer> buffer_data_to_buffer,
                const WarpSpecializedRoleMarker &marker,
                bool mbarrier_only = false, bool only_has_wgmma = false)
                bool mbarrier_only = false)
      : is_emitting_producer_(is_emitting_producer),
        buffer_data_to_buffer_(std::move(buffer_data_to_buffer)),
        marker_(marker), thread_var_(thread_iv->var),
        mbarrier_only_(mbarrier_only), only_has_wgmma_(only_has_wgmma) {}
        mbarrier_only_(mbarrier_only) {}


💡 Verification agent

🧩 Analysis chain

Constructor signature change looks good; confirm all call sites and remove dead WgMMA remnants.

Dropping the WgMMA flag simplifies the API. Please verify there are no lingering callers still passing the old parameter or referencing the removed accessor.

Run to confirm no stale uses remain:


🏁 Script executed:

#!/bin/bash
# Constructor/call sites
rg -nP --glob '!**/build/**' '\bWSCodeEmitter\s*\(' -C2
# Old API remnants
rg -nP --glob '!**/build/**' 'onlyHasWgMMA|has_wgmma|WgMMA' -C2

Length of output: 2330


Delete dead WgMMACollector and associated HasWgMMA code
The WSCodeEmitter constructor change is confirmed—no external calls use the old signature—but the WgMMACollector class and its HasWgMMA members (src/transform/warp_specialized_rewriter.cc:572–604) remain unreferenced and should be removed.

🤖 Prompt for AI Agents
In src/transform/warp_specialized_rewriter.cc around lines 572 to 604 (and
constructor change at 608–616), remove the dead WgMMACollector class and its
HasWgMMA members which are no longer referenced: delete the entire class
definition and any HasWgMMA helper methods/fields, remove or update any forward
declarations and includes that exist solely to support them, and run a quick
search to eliminate any lingering references; ensure the WSCodeEmitter
constructor and its usage remain unchanged and compile cleanly after the
removals.

- Updated `InferLayout` calls in `AtomicAddNode`, `CopyNode`, and `FillNode` to enhance code clarity by formatting parameters across multiple lines.
- Cleaned up whitespace and formatting in `copy.h` and `layout_inference.cc` to adhere to coding standards and improve maintainability.

This refactor aims to streamline the layout inference logic and improve overall code organization.

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
src/transform/layout_inference.cc (1)

319-332: Bug: const_int_bound queried with IterVar instead of Var

Use thread_var_->var when querying const_int_bound; current code won’t compile or will pick the wrong overload.

Apply:

-      if (analyzer_.const_int_bound.IsBound(thread_var_->var)) {
-        auto const_int_bound = analyzer_.const_int_bound(thread_var_);
+      if (analyzer_.const_int_bound.IsBound(thread_var_->var)) {
+        auto const_int_bound = analyzer_.const_int_bound(thread_var_->var);
@@
-        thread_bounds_vec_.push_back(Range::FromMinExtent(
-            IntImm(dtype, min_value), IntImm(dtype, extent)));
+        thread_bounds_vec_.push_back(Range::FromMinExtent(
+            IntImm(dtype, min_value), IntImm(dtype, extent)));

And similarly below:

-      if (thread_var_.defined() &&
-          analyzer_.const_int_bound.IsBound(thread_var_->var)) {
-        auto const_int_bound = analyzer_.const_int_bound(thread_var_);
+      if (thread_var_.defined() &&
+          analyzer_.const_int_bound.IsBound(thread_var_->var)) {
+        auto const_int_bound = analyzer_.const_int_bound(thread_var_->var);
src/op/copy.cc (2)

1089-1093: Potential OOB read in while-condition

Check index bounds before indexing shared_range.

Apply:

-    while (is_one(shared_range[s_range_idx]->extent) &&
-           s_range_idx < shared_range.size()) {
+    while (s_range_idx < shared_range.size() &&
+           is_one(shared_range[s_range_idx]->extent)) {
       s_range_idx++;
     }

1121-1141: Guard layout_map access; .at() throws when absent

Use count()/operator[] pattern before dereference.

Apply:

-  auto shared_layout = T.layout_map.at(shared_tensor);
-  if (!shared_layout.defined()) {
+  Layout shared_layout;
+  if (T.layout_map.count(shared_tensor)) {
+    shared_layout = T.layout_map[shared_tensor];
+  }
+  if (!shared_layout.defined()) {
♻️ Duplicate comments (5)
src/transform/layout_inference.cc (1)

333-361: OOB check misses lower-bound; add min >= 0 proofs for src/dst

Current logic only proves min+extent <= shape. Also prove min >= 0 to avoid enabling 1D TMA on negative starts.

Apply:

-        for (size_t i = 0; i < src_range.size(); i++) {
-          if (!analyzer_.CanProve(src_range[i]->min + src_range[i]->extent <=
-                                      src_tensor->shape[i],
-                                  arith::ProofStrength::kSymbolicBound)) {
-            src_oob = true;
-            break;
-          }
-        }
-        for (size_t i = 0; i < dst_range.size(); i++) {
-          if (!analyzer_.CanProve(dst_range[i]->min + dst_range[i]->extent <=
-                                      dst_tensor->shape[i],
-                                  arith::ProofStrength::kSymbolicBound)) {
-            dst_oob = true;
-            break;
-          }
-        }
+        for (size_t i = 0; i < src_range.size(); i++) {
+          bool upper_ok = analyzer_.CanProve(
+              src_range[i]->min + src_range[i]->extent <= src_tensor->shape[i],
+              arith::ProofStrength::kSymbolicBound);
+          bool lower_ok = analyzer_.CanProve(
+              src_range[i]->min >= 0, arith::ProofStrength::kSymbolicBound);
+          if (!(upper_ok && lower_ok)) {
+            src_oob = true;
+            break;
+          }
+        }
+        for (size_t i = 0; i < dst_range.size(); i++) {
+          bool upper_ok = analyzer_.CanProve(
+              dst_range[i]->min + dst_range[i]->extent <= dst_tensor->shape[i],
+              arith::ProofStrength::kSymbolicBound);
+          bool lower_ok = analyzer_.CanProve(
+              dst_range[i]->min >= 0, arith::ProofStrength::kSymbolicBound);
+          if (!(upper_ok && lower_ok)) {
+            dst_oob = true;
+            break;
+          }
+        }
src/op/copy.cc (4)

621-647: Remove default for buffer_oob in GetCopyInst definition

Header requires explicit buffer_oob; keep definition consistent to avoid unsafe defaults and mismatched redeclarations.

Apply:

-CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower,
-                               const LayoutMap &layout_map,
-                               arith::Analyzer *analyzer,
-                               bool buffer_oob = false) const {
+CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower,
+                               const LayoutMap &layout_map,
+                               arith::Analyzer *analyzer,
+                               bool buffer_oob) const {

661-674: Lower(): pass buffer_oob into GetCopyInst

Without this, 1D TMA may be selected despite OOB.

Apply:

-  auto copy_inst =
-      GetCopyInst(target, disable_tma_lower, T.layout_map, analyzer);
+  auto copy_inst =
+      GetCopyInst(target, disable_tma_lower, T.layout_map, analyzer, T.buffer_oob);

473-519: Fix shared tensor lookup and add shared contiguity scan

  • Uses dst instead of shared_tensor when checking layout_map.
  • Lacks symmetric contiguity check on the shared side.

Apply:

-  // Step 1: check shared is contiguous
-  bool shared_is_contiguous = true;
-  if (layout_map.count(dst)) {
-    shared_is_contiguous = false;
-  }
+  // Step 1: check shared is contiguous (no remapped/swizzled layout)
+  bool shared_is_contiguous = true;
+  if (layout_map.count(shared_tensor)) {
+    shared_is_contiguous = false;
+  }
@@
   // Step 2: check global is contiguous
   bool global_is_contiguous = true;
   bool global_not_full_dim_encounter = false;
@@
   }
+
+  // Step 2b: check shared is contiguous by ranges (same rule as global)
+  bool shared_not_full_dim_encounter = false;
+  for (int i = shared_range.size() - 1; i >= 0; i--) {
+    if (!shared_not_full_dim_encounter) {
+      if (!analyzer->CanProve(shared_range[i]->extent == shared_tensor->shape[i] &&
+                                   shared_range[i]->min == 0,
+                               arith::ProofStrength::kSymbolicBound)) {
+        shared_not_full_dim_encounter = true;
+      }
+    } else {
+      if (!analyzer->CanProve(shared_range[i]->extent == 1,
+                              arith::ProofStrength::kSymbolicBound)) {
+        shared_is_contiguous = false;
+        break;
+      }
+    }
+  }

1247-1320: LowerBulkCopy1D: honor buffer_remap and no-op on zero elements

  • Use T.buffer_remap for both tensors (aligns with LDSM path).
  • Early return if elements == 0.

Apply:

   bool is_load = copy_inst == CopyInst::kBulkLoad1D;
@@
-  auto shared_tensor = is_load ? dst : src;
-  auto global_tensor = is_load ? src : dst;
+  auto shared_tensor = is_load ? dst : src;
+  auto global_tensor = is_load ? src : dst;
+  if (T.buffer_remap.count(shared_tensor)) shared_tensor = T.buffer_remap[shared_tensor];
+  if (T.buffer_remap.count(global_tensor)) global_tensor = T.buffer_remap[global_tensor];
@@
-  PrimExpr elements = analyzer->Simplify(shared_elements);
+  PrimExpr elements = analyzer->Simplify(shared_elements);
+  if (analyzer->CanProveEqual(elements, 0)) {
+    return IfThenElse(EQ(T.thread_var, T.thread_bounds->min), Evaluate(0));
+  }
🧹 Nitpick comments (1)
src/op/copy.cc (1)

198-206: Fix doc typo

“Create s” → “Create indices”.

Apply:

- * \brief Create s for the copy operation.
+ * \brief Create indices for the copy operation.
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro


📥 Commits

Reviewing files that changed from the base of the PR and between 65bbb80 and 1a9f0ad.

📒 Files selected for processing (5)
  • src/op/atomic_add.cc (1 hunks)
  • src/op/copy.cc (12 hunks)
  • src/op/copy.h (5 hunks)
  • src/op/fill.cc (3 hunks)
  • src/transform/layout_inference.cc (10 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • src/op/atomic_add.cc
  • src/op/fill.cc
🧰 Additional context used
🧬 Code graph analysis (3)
src/op/copy.h (1)
src/op/copy.cc (16)
  • InferLayout (398-433)
  • InferLayout (398-399)
  • InferLayout (1611-1614)
  • InferLayout (1611-1612)
  • CheckBulkLoad1D (520-530)
  • CheckBulkLoad1D (520-521)
  • CheckBulkStore1D (532-542)
  • CheckBulkStore1D (532-533)
  • CheckBulkCopy1D (473-518)
  • CheckBulkCopy1D (473-478)
  • GetCopyInst (621-647)
  • GetCopyInst (621-624)
  • LowerBulkCopy (976-1245)
  • LowerBulkCopy (976-977)
  • LowerBulkCopy1D (1247-1320)
  • LowerBulkCopy1D (1247-1248)
src/transform/layout_inference.cc (1)
tilelang/language/copy.py (1)
  • copy (84-152)
src/op/copy.cc (1)
src/tl_templates/cuda/copy_sm90.h (12)
  • tma_load (18-27)
  • tma_load (43-60)
  • tma_load (64-82)
  • tma_load (86-104)
  • tma_load (107-126)
  • tma_load (130-150)
  • tma_store (175-182)
  • tma_store (185-195)
  • tma_store (198-209)
  • tma_store (212-223)
  • tma_store (226-238)
  • tma_store (241-253)
🔇 Additional comments (14)
src/transform/layout_inference.cc (6)

17-20: Include dependency looks correct

Including copy.h here is appropriate for CopyNode checks.


69-70: Constructor inheritance: OK

Inheriting IRVisitorWithAnalyzer constructors is fine.


209-212: Size consistency check: good

Ensuring buffer_oob_vec_ aligns with infer_list_ is necessary and correct.


410-410: Parallel For default OOB flag = false: OK

Initializing per-op OOB to false for kParallel nodes makes sense.


603-604: Constructor inheritance: OK

IRMutatorWithAnalyzer constructor inheritance is fine.


107-112: LayoutInferArgs initializer order is correct
The brace-init values—target_, thread_bounds, layout_map, &analyzer_, buffer_oob—exactly match the first five fields of LayoutInferArgs (in declaration order), so there’s no misbinding risk.

src/op/copy.h (6)

18-26: New 1D TMA enums: OK

Separating 1D from multi-dim variants is a good call.


144-146: Override annotation: OK

Marking InferLayout as override is correct.


157-177: 1D viability checks: OK (ensure consistent impl in .cc)

Interfaces look good; see .cc comments for impl fixes.


191-194: Signature requires explicit buffer_oob: good

Header removes any default. Make sure the .cc definition also has no default (see comment there).


203-207: LowerBulkCopy1D hook: OK

API surface looks good; see .cc for remap/zero-length refinements.


350-352: Conv2DIm2Col override: OK

Signature matches usage.

src/op/copy.cc (2)

398-408: InferLayout now gates by buffer_oob: OK

Passing analyzer and buffer_oob into GetCopyInst is correct.


726-729: LayoutInferArgs initializer arguments match struct field order
The call’s arguments (analyzer, false, T.buffer_remap) align with the struct’s fields (arith::Analyzer* analyzer, bool buffer_oob, Map<Buffer, Buffer> buffer_remap).

- Updated the condition in `CheckBulkCopy1D` to verify contiguity of `shared_tensor` instead of `dst`, ensuring correct handling of shared memory layouts during bulk copy operations.
- This change enhances the accuracy of memory operations in the TileLang framework.

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/op/copy.cc (1)

621-647: Make buffer_oob explicit in GetCopyInst (remove default).

Defaulting to false is unsafe; callers will silently enable 1D TMA under OOB.

-CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower,
-                               const LayoutMap &layout_map,
-                               arith::Analyzer *analyzer,
-                               bool buffer_oob = false) const {
+CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower,
+                               const LayoutMap &layout_map,
+                               arith::Analyzer *analyzer,
+                               bool buffer_oob) const {

Also remove the default in the header declaration and update all call sites to pass an explicit boolean.

♻️ Duplicate comments (2)
src/op/copy.cc (2)

473-519: Complete shared-side contiguity validation (and nice fix on layout_map var).

You correctly switched to layout_map.count(shared_tensor). However, contiguity is only proven for the global side; 1D TMA can still be mis-selected if shared_range isn’t contiguous-by-ranges (even when not swizzled). Mirror the global check for shared_range.

Apply this diff within CheckBulkCopy1D:

   // Step 1: check shared is contiguous
   bool shared_is_contiguous = true;
   if (layout_map.count(shared_tensor)) {
     shared_is_contiguous = false;
   }
   // Step 2: check global is contiguous
   bool global_is_contiguous = true;
   bool global_not_full_dim_encounter = false;
   for (int i = global_range.size() - 1; i >= 0; i--) {
     if (!global_not_full_dim_encounter) {
       if (!analyzer->CanProve(global_range[i]->extent ==
                                       global_tensor->shape[i] &&
                                   global_range[i]->min == 0,
                               arith::ProofStrength::kSymbolicBound)) {
         global_not_full_dim_encounter = true;
       }
     } else {
       if (!analyzer->CanProve(global_range[i]->extent == 1,
                               arith::ProofStrength::kSymbolicBound)) {
         global_is_contiguous = false;
         break;
       }
     }
   }
+
+  // Step 2b: check shared is contiguous by ranges (symmetric rule)
+  bool shared_not_full_dim_encounter = false;
+  for (int i = shared_range.size() - 1; i >= 0; i--) {
+    if (!shared_not_full_dim_encounter) {
+      if (!analyzer->CanProve(shared_range[i]->extent ==
+                                      shared_tensor->shape[i] &&
+                                  shared_range[i]->min == 0,
+                              arith::ProofStrength::kSymbolicBound)) {
+        shared_not_full_dim_encounter = true;
+      }
+    } else {
+      if (!analyzer->CanProve(shared_range[i]->extent == 1,
+                              arith::ProofStrength::kSymbolicBound)) {
+        shared_is_contiguous = false;
+        break;
+      }
+    }
+  }

667-676: Include buffer_oob in GetCopyInst call in Lower()
In src/op/copy.cc:667 the call to GetCopyInst omits T.buffer_oob, so OOB checks aren’t considered and 1D TMA may be mis-picked. Change to:

- auto copy_inst = GetCopyInst(target, disable_tma_lower, T.layout_map, analyzer);
+ auto copy_inst = GetCopyInst(target, disable_tma_lower, T.layout_map, analyzer, T.buffer_oob);
🧹 Nitpick comments (3)
src/op/copy.cc (3)

199-206: Fix typo in MakeIndices doc.

"Create s" -> "Create indices".

- * \brief Create s for the copy operation.
+ * \brief Create indices for the copy operation.

726-729: Propagate real buffer_oob into parallel layout inference.

Minor, but keeps the pipeline state consistent.

-      par_op->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer,
-                           false, T.buffer_remap},
+      par_op->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer,
+                           T.buffer_oob, T.buffer_remap},
                           level);

1116-1127: Avoid layout_map.at() without existence check.

Map::at may throw; use count() then operator[] to allow “undefined” layouts.

-  auto shared_layout = T.layout_map.at(shared_tensor);
+  Layout shared_layout;
+  if (T.layout_map.count(shared_tensor)) {
+    shared_layout = T.layout_map[shared_tensor];
+  }
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro


📥 Commits

Reviewing files that changed from the base of the PR and between 1a9f0ad and 1545fb1.

📒 Files selected for processing (1)
  • src/op/copy.cc (12 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/op/copy.cc (2)
src/tl_templates/cuda/copy_sm90.h (12)
  • tma_load (18-27)
  • tma_load (43-60)
  • tma_load (64-82)
  • tma_load (86-104)
  • tma_load (107-126)
  • tma_load (130-150)
  • tma_store (175-182)
  • tma_store (185-195)
  • tma_store (198-209)
  • tma_store (212-223)
  • tma_store (226-238)
  • tma_store (241-253)
tilelang/language/builtin.py (1)
  • tma_load (67-76)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: format-check
  • GitHub Check: bot-task
🔇 Additional comments (3)
src/op/copy.cc (3)

406-408: Good: GetCopyInst now uses analyzer and buffer_oob during layout inference.

This keeps 1D TMA gating consistent between inference and lowering.


520-543: LGTM: 1D load/store viability wrappers.

The scaffolding reuses base checks and correctly threads layout_map/analyzer.


996-1005: LGTM: explicit shared/global offsets for TMA addressing.

The stride/offset derivation is correct and simplifies subsequent address math.

Also applies to: 1019-1025

Comment on lines +1247 to +1320
Stmt CopyNode::LowerBulkCopy1D(const LowerArgs &T, arith::Analyzer *analyzer,
                               CopyInst copy_inst) const {
  ICHECK(copy_inst == CopyInst::kBulkLoad1D ||
         copy_inst == CopyInst::kBulkStore1D);

  // Add 1D TMA copy when the global and shared memory is contiguous
  // Check if shared_tensor->name is present in T.buffer_var_gemm
  // (Array<PrimExpr>) to avoid use 1D TMA copy for swizzled layout
  bool is_load = copy_inst == CopyInst::kBulkLoad1D;
  auto shared_range = is_load ? dst_range : src_range;
  auto global_range = is_load ? src_range : dst_range;
  auto shared_tensor = is_load ? dst : src;
  auto global_tensor = is_load ? src : dst;

  PrimExpr shared_elements = 1;
  for (size_t i = 0; i < shared_range.size(); i++) {
    shared_elements *= shared_range[i]->extent;
  }

  std::vector<PrimExpr> shared_strides;
  PrimExpr shared_stride = 1;
  for (size_t i = 0; i < shared_tensor->shape.size(); i++) {
    auto s = shared_tensor->shape[shared_tensor->shape.size() - i - 1];
    shared_strides.insert(shared_strides.begin(), shared_stride);
    shared_stride *= s;
  }

  Array<PrimExpr> shared_indices;
  for (auto r : shared_range)
    shared_indices.push_back(r->min);

  Array<PrimExpr> global_indices;
  for (auto r : global_range) {
    global_indices.push_back(r->min);
  }
  std::vector<PrimExpr> global_strides;
  PrimExpr global_stride = 1;
  for (size_t i = 0; i < global_tensor->shape.size(); i++) {
    auto s = global_tensor->shape[global_tensor->shape.size() - i - 1];
    global_strides.insert(global_strides.begin(), global_stride);
    global_stride *= s;
  }

  PrimExpr global_offset = 0;
  for (size_t i = 0; i < global_indices.size(); i++) {
    global_offset += global_indices[i] * global_strides[i];
  }

  PrimExpr shared_offset = 0;
  for (size_t i = 0; i < shared_indices.size(); i++) {
    shared_offset += shared_indices[i] * shared_strides[i];
  }

  PrimExpr elements = analyzer->Simplify(shared_elements);
  PrimExpr shared_addr = shared_tensor.access_ptr(
      is_load ? 2 : 1, DataType::Handle(), 1, shared_offset, elements);
  PrimExpr global_addr = global_tensor.access_ptr(
      is_load ? 1 : 2, DataType::Handle(), 1, global_offset, elements);
  Stmt tma_copy;
  if (is_load) {
    // the zero is a placeholder for mbarrier ids
    tma_copy = Evaluate(
        Call(DataType::Handle(), tma_load(),
             {shared_addr, global_addr, 0,
              elements * shared_tensor->dtype.bytes(), this->eviction_policy}));
  } else {
    tma_copy = Evaluate(
        Call(DataType::Handle(), tma_store(),
             {global_addr, shared_addr, elements * shared_tensor->dtype.bytes(),
              this->eviction_policy}));
  }
  tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy);
  return tma_copy;
}

⚠️ Potential issue

Fix 1D TMA lowering: wrong intrinsic signature, missing descriptor, no remap, no zero-length guard.

Current 1D path passes (shared_addr, global_addr, mbar, bytes, eviction) to tma_load/tma_store, but intrinsics expect (descriptor, [mbar], smem_ptr, crd0, cache_hint). Also ignores T.buffer_remap and doesn’t early-out for zero elements.

Apply this rewrite (minimal, descriptor-based, remap-aware, with a zero-length guard):

 Stmt CopyNode::LowerBulkCopy1D(const LowerArgs &T, arith::Analyzer *analyzer,
                                CopyInst copy_inst) const {
   ICHECK(copy_inst == CopyInst::kBulkLoad1D ||
          copy_inst == CopyInst::kBulkStore1D);

-  // Add 1D TMA copy when the global and shared memory is contiguous
-  // Check if shared_tensor->name is present in T.buffer_var_gemm
-  // (Array<PrimExpr>) to avoid use 1D TMA copy for swizzled layout
   bool is_load = copy_inst == CopyInst::kBulkLoad1D;
-  auto shared_range = is_load ? dst_range : src_range;
-  auto global_range = is_load ? src_range : dst_range;
-  auto shared_tensor = is_load ? dst : src;
-  auto global_tensor = is_load ? src : dst;
+  auto shared_range = is_load ? dst_range : src_range;
+  auto global_range = is_load ? src_range : dst_range;
+  Buffer shared_tensor = is_load ? dst : src;
+  Buffer global_tensor = is_load ? src : dst;
+  if (T.buffer_remap.count(shared_tensor)) shared_tensor = T.buffer_remap[shared_tensor];
+  if (T.buffer_remap.count(global_tensor)) global_tensor = T.buffer_remap[global_tensor];

-  PrimExpr shared_elements = 1;
+  PrimExpr shared_elements = 1;
   for (size_t i = 0; i < shared_range.size(); i++) {
     shared_elements *= shared_range[i]->extent;
   }
 
-  std::vector<PrimExpr> shared_strides;
-  PrimExpr shared_stride = 1;
+  std::vector<PrimExpr> shared_strides;
+  PrimExpr shared_stride = 1;
   for (size_t i = 0; i < shared_tensor->shape.size(); i++) {
     auto s = shared_tensor->shape[shared_tensor->shape.size() - i - 1];
     shared_strides.insert(shared_strides.begin(), shared_stride);
     shared_stride *= s;
   }
 
-  Array<PrimExpr> shared_indices;
+  Array<PrimExpr> shared_indices;
   for (auto r : shared_range)
     shared_indices.push_back(r->min);
 
-  Array<PrimExpr> global_indices;
+  Array<PrimExpr> global_indices;
   for (auto r : global_range) {
     global_indices.push_back(r->min);
   }
-  std::vector<PrimExpr> global_strides;
-  PrimExpr global_stride = 1;
+  std::vector<PrimExpr> global_strides;
+  PrimExpr global_stride = 1;
   for (size_t i = 0; i < global_tensor->shape.size(); i++) {
     auto s = global_tensor->shape[global_tensor->shape.size() - i - 1];
     global_strides.insert(global_strides.begin(), global_stride);
     global_stride *= s;
   }
 
-  PrimExpr global_offset = 0;
+  PrimExpr global_offset = 0;
   for (size_t i = 0; i < global_indices.size(); i++) {
     global_offset += global_indices[i] * global_strides[i];
   }
 
-  PrimExpr shared_offset = 0;
+  PrimExpr shared_offset = 0;
   for (size_t i = 0; i < shared_indices.size(); i++) {
     shared_offset += shared_indices[i] * shared_strides[i];
   }
 
-  PrimExpr elements = analyzer->Simplify(shared_elements);
-  PrimExpr shared_addr = shared_tensor.access_ptr(
-      is_load ? 2 : 1, DataType::Handle(), 1, shared_offset, elements);
-  PrimExpr global_addr = global_tensor.access_ptr(
-      is_load ? 1 : 2, DataType::Handle(), 1, global_offset, elements);
-  Stmt tma_copy;
-  if (is_load) {
-    // the zero is a placeholder for mbarrier ids
-    tma_copy = Evaluate(
-        Call(DataType::Handle(), tma_load(),
-             {shared_addr, global_addr, 0,
-              elements * shared_tensor->dtype.bytes(), this->eviction_policy}));
-  } else {
-    tma_copy = Evaluate(
-        Call(DataType::Handle(), tma_store(),
-             {global_addr, shared_addr, elements * shared_tensor->dtype.bytes(),
-              this->eviction_policy}));
-  }
-  tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy);
-  return tma_copy;
+  PrimExpr elements = analyzer->Simplify(shared_elements);
+  if (analyzer->CanProveEqual(elements, 0)) {
+    return IfThenElse(EQ(T.thread_var, T.thread_bounds->min), Evaluate(0));
+  }
+
+  // Build a minimal 1D TMA descriptor over the flattened global tensor.
+  TMADesc desc;
+  desc.rank = 1;
+  desc.data_type = to_CUtensorMapDataType(global_tensor->dtype);
+  desc.global_addr = global_tensor->data;
+  PrimExpr total_global_elems = 1;
+  for (auto s : global_tensor->shape) total_global_elems *= s;
+  desc.global_shape = { total_global_elems };
+  desc.global_stride = { cast(DataType::Int(64), global_tensor->dtype.bytes()) };
+  desc.smem_box = { elements };
+  desc.smem_stride = { PrimExpr(1) };
+  desc.interleave = static_cast<int>(CU_TENSOR_MAP_INTERLEAVE_NONE);
+  desc.swizzle    = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE);
+  desc.l2_promotion = static_cast<int>(CU_TENSOR_MAP_L2_PROMOTION_L2_128B);
+  desc.oob_fill     = static_cast<int>(CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
+
+  Call create_descriptor =
+      Call(DataType::Handle(), create_tma_descriptor(), desc.EncodeCallArgs());
+
+  // Addresses and 1D coordinate
+  PrimExpr shared_addr = shared_tensor.access_ptr(
+      is_load ? 2 : 1, DataType::Handle(), 1, shared_offset, elements);
+  PrimExpr crd0 = cast(DataType::Int(32), global_offset);
+
+  Array<PrimExpr> args;
+  args.push_back(create_descriptor);
+  if (is_load) args.push_back(0); // mbarrier placeholder
+  args.push_back(shared_addr);
+  args.push_back(crd0);
+  args.push_back(this->eviction_policy);
+  auto op = is_load ? tma_load() : tma_store();
+  Stmt tma_copy = Evaluate(Call(DataType::Handle(), op, args));
+  tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy);
+  return tma_copy;
 }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
Stmt CopyNode::LowerBulkCopy1D(const LowerArgs &T, arith::Analyzer *analyzer,
CopyInst copy_inst) const {
ICHECK(copy_inst == CopyInst::kBulkLoad1D ||
copy_inst == CopyInst::kBulkStore1D);
bool is_load = copy_inst == CopyInst::kBulkLoad1D;
auto shared_range = is_load ? dst_range : src_range;
auto global_range = is_load ? src_range : dst_range;
Buffer shared_tensor = is_load ? dst : src;
Buffer global_tensor = is_load ? src : dst;
// Honor any buffer remappings
if (T.buffer_remap.count(shared_tensor)) {
shared_tensor = T.buffer_remap[shared_tensor];
}
if (T.buffer_remap.count(global_tensor)) {
global_tensor = T.buffer_remap[global_tensor];
}
// Compute flattened element count
PrimExpr shared_elements = 1;
for (size_t i = 0; i < shared_range.size(); i++) {
shared_elements *= shared_range[i]->extent;
}
// Compute shared strides & base offset
std::vector<PrimExpr> shared_strides;
PrimExpr shared_stride = 1;
for (size_t i = 0; i < shared_tensor->shape.size(); i++) {
auto s = shared_tensor->shape[shared_tensor->shape.size() - i - 1];
shared_strides.insert(shared_strides.begin(), shared_stride);
shared_stride *= s;
}
Array<PrimExpr> shared_indices;
for (auto r : shared_range) {
shared_indices.push_back(r->min);
}
PrimExpr shared_offset = 0;
for (size_t i = 0; i < shared_indices.size(); i++) {
shared_offset += shared_indices[i] * shared_strides[i];
}
// Compute global strides & base offset
Array<PrimExpr> global_indices;
for (auto r : global_range) {
global_indices.push_back(r->min);
}
std::vector<PrimExpr> global_strides;
PrimExpr global_stride = 1;
for (size_t i = 0; i < global_tensor->shape.size(); i++) {
auto s = global_tensor->shape[global_tensor->shape.size() - i - 1];
global_strides.insert(global_strides.begin(), global_stride);
global_stride *= s;
}
PrimExpr global_offset = 0;
for (size_t i = 0; i < global_indices.size(); i++) {
global_offset += global_indices[i] * global_strides[i];
}
// Simplify and early-exit on zero elements
PrimExpr elements = analyzer->Simplify(shared_elements);
if (analyzer->CanProveEqual(elements, 0)) {
return IfThenElse(EQ(T.thread_var, T.thread_bounds->min),
Evaluate(0));
}
// Build a minimal 1D TMA descriptor
TMADesc desc;
desc.rank = 1;
desc.data_type = to_CUtensorMapDataType(global_tensor->dtype);
desc.global_addr = global_tensor->data;
PrimExpr total_elems = 1;
for (auto s : global_tensor->shape) {
total_elems *= s;
}
desc.global_shape = { total_elems };
desc.global_stride = { cast(DataType::Int(64),
global_tensor->dtype.bytes()) };
desc.smem_box = { elements };
desc.smem_stride = { PrimExpr(1) };
desc.interleave = static_cast<int>(CU_TENSOR_MAP_INTERLEAVE_NONE);
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE);
desc.l2_promotion = static_cast<int>(CU_TENSOR_MAP_L2_PROMOTION_L2_128B);
desc.oob_fill = static_cast<int>(CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
Call create_descriptor =
Call(DataType::Handle(),
create_tma_descriptor(),
desc.EncodeCallArgs());
// Prepare the final call arguments
PrimExpr shared_addr = shared_tensor.access_ptr(
is_load ? 2 : 1, DataType::Handle(), 1, shared_offset, elements);
PrimExpr crd0 = cast(DataType::Int(32), global_offset);
Array<PrimExpr> args;
args.push_back(create_descriptor);
if (is_load) {
args.push_back(0); // mbarrier placeholder
}
args.push_back(shared_addr);
args.push_back(crd0);
args.push_back(this->eviction_policy);
auto op = is_load ? tma_load() : tma_store();
Stmt tma_copy = Evaluate(Call(DataType::Handle(), op, args));
tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy);
return tma_copy;
}
🤖 Prompt for AI Agents
In src/op/copy.cc around lines 1247 to 1320 the 1D TMA lowering is incorrect: it
calls tma_load/tma_store with (shared_addr, global_addr, mbar, bytes, eviction)
but the intrinsic expects a descriptor as the first arg, optional mbar id, then
smem_ptr, crd0, and cache hint; it also ignores T.buffer_remap and doesn't guard
zero-length copies. Fix by (1) early-out: wrap the TMA call in an IfThenElse
checking Simplified(elements) > 0 so zero-length does nothing; (2) build a TMA
descriptor from the global address/shape/strides/element count (use the proper
helper/intrinsic to create a descriptor from global_addr + elements/strides) and
pass that descriptor as the first argument to tma_load/tma_store; (3) include
the mbar placeholder only where required, then pass smem_ptr (shared_addr) and
the coordinate/crd0 derived from global_offset/indices, and finally pass the
eviction/cache hint; (4) respect T.buffer_remap: if the shared/global tensor is
remapped (check T.buffer_remap for the tensor name) either fall back to the
non-1D path or apply the remap when computing addresses so the descriptor and
pointers reflect the remapped layout. Ensure the final Call uses the corrected
argument order and types and keep the IfThenElse guard around the whole TMA
Evaluate.

- Commented out the call to `tilelang.testing.main()` in `test_example_gdn_compilation.py` and replaced it with a direct call to `test_example_chunk_delta_bwd_compilation()`. This change simplifies the test execution flow and focuses on the specific test case.
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/op/copy.cc (1)

1018-1050: Compute shared_offset after remap using layout->Forward/OffsetOf

shared_offset is derived before the potential buffer remap and without applying the shared layout transform, which can yield a wrong SMEM base address for swizzled layouts.

Apply:

-  Array<PrimExpr> shared_indices;
-  for (auto r : shared_range)
-    shared_indices.push_back(r->min);
-  std::vector<PrimExpr> shared_strides;
-  PrimExpr shared_stride = 1;
-  for (size_t i = 0; i < shared_tensor->shape.size(); i++) {
-    auto s = shared_tensor->shape[shared_tensor->shape.size() - i - 1];
-    shared_strides.insert(shared_strides.begin(), shared_stride);
-    shared_stride *= s;
-  }
+  Array<PrimExpr> shared_indices;
+  for (auto r : shared_range) shared_indices.push_back(r->min);
@@
-  ICHECK(shared_strides.size() == shared_indices.size())
-      << "shared_strides.size() != shared_indices.size()"
-      << shared_strides.size() << " " << shared_indices.size();
-  PrimExpr shared_offset = 0;
-  for (size_t i = 0; i < shared_indices.size(); i++) {
-    shared_offset += shared_indices[i] * shared_strides[i];
-  }
+  // shared_offset is computed after possible remap below
♻️ Duplicate comments (5)
src/op/copy.h (1)

191-194: Make buffer_oob non-defaulted across declaration/definition

Header correctly requires an explicit buffer_oob; the definition in copy.cc still supplies a default. Please remove the default from the definition for consistency and to prevent unsafe callsites.

src/op/copy.cc (4)

646-650: Remove default argument for buffer_oob in definition

Keep defaults only in declarations (and here we want no default).

Apply:

-CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower,
-                               const LayoutMap &layout_map,
-                               arith::Analyzer *analyzer,
-                               bool buffer_oob = false) const {
+CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower,
+                               const LayoutMap &layout_map,
+                               arith::Analyzer *analyzer,
+                               bool buffer_oob) const {

486-531: Strengthen 1D contiguity: add symmetric shared-range check

Currently only global contiguity is validated by ranges; shared contiguity is only inferred from absence in layout_map. Add a shared-range walk mirroring the global logic to avoid false positives.

Apply:

   // Step 1: check shared is contiguous
   bool shared_is_contiguous = true;
   if (layout_map.count(shared_tensor)) {
     shared_is_contiguous = false;
   }
   // Step 2: check global is contiguous
   bool global_is_contiguous = true;
   bool global_not_full_dim_encounter = false;
@@
   }
 
+  // Step 2b: check shared is contiguous by ranges (same rule as global)
+  bool shared_not_full_dim_encounter = false;
+  for (int i = shared_range.size() - 1; i >= 0; i--) {
+    if (!shared_not_full_dim_encounter) {
+      if (!analyzer->CanProve(shared_range[i]->extent == shared_tensor->shape[i] &&
+                                   shared_range[i]->min == 0,
+                               arith::ProofStrength::kSymbolicBound)) {
+        shared_not_full_dim_encounter = true;
+      }
+    } else {
+      if (!analyzer->CanProve(shared_range[i]->extent == 1,
+                              arith::ProofStrength::kSymbolicBound)) {
+        shared_is_contiguous = false;
+        break;
+      }
+    }
+  }
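
For clarity, a small integer-only sketch of the "tail-ones" contiguity rule assumed above (hypothetical helper, constant shapes only): scanning from the innermost dimension, dimensions may be fully covered; once a partially covered dimension appears, every remaining outer dimension must have extent 1.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

struct Range { int64_t min, extent; };

// True if the selected region is a single contiguous row-major run.
bool IsContiguous(const std::vector<int64_t> &shape,
                  const std::vector<Range> &range) {
  bool partial_seen = false;
  for (std::size_t i = shape.size(); i-- > 0;) {
    bool full = (range[i].min == 0 && range[i].extent == shape[i]);
    if (!partial_seen) {
      if (!full) partial_seen = true; // the first non-full dim is allowed
    } else if (range[i].extent != 1) {
      return false;                   // every outer dim must be degenerate
    }
  }
  return true;
}

int main() {
  // Sixteen full rows of a [64, 128] buffer: one contiguous run.
  assert(IsContiguous({64, 128}, {{8, 16}, {0, 128}}));
  // A 16x64 sub-tile of the same buffer: strided, not contiguous.
  assert(!IsContiguous({64, 128}, {{8, 16}, {0, 64}}));
  return 0;
}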

692-699: Pass buffer_oob to GetCopyInst in Lower()

Lower() re-selects the copy inst but drops OOB gating, potentially choosing 1D TMA incorrectly.

Apply:

-  auto copy_inst =
-      GetCopyInst(target, disable_tma_lower, T.layout_map, analyzer);
+  auto copy_inst =
+      GetCopyInst(target, disable_tma_lower, T.layout_map, analyzer, T.buffer_oob);

1281-1354: Fix 1D TMA lowering: wrong intrinsic shape, missing descriptor/remap/zero-guard

The current path calls tma_load/tma_store with raw pointers and byte counts instead of the descriptor+coordinates form, ignores buffer_remap, and doesn't guard zero-length copies. Rewrite it to the descriptor-based 1D tensor form.

Apply:

 Stmt CopyNode::LowerBulkCopy1D(const LowerArgs &T, arith::Analyzer *analyzer,
                                CopyInst copy_inst) const {
   ICHECK(copy_inst == CopyInst::kBulkLoad1D ||
          copy_inst == CopyInst::kBulkStore1D);
 
-  // Add 1D TMA copy when the global and shared memory is contiguous
-  // Check if shared_tensor->name is present in T.buffer_var_gemm
-  // (Array<PrimExpr>) to avoid use 1D TMA copy for swizzled layout
   bool is_load = copy_inst == CopyInst::kBulkLoad1D;
-  auto shared_range = is_load ? dst_range : src_range;
-  auto global_range = is_load ? src_range : dst_range;
-  auto shared_tensor = is_load ? dst : src;
-  auto global_tensor = is_load ? src : dst;
+  auto shared_range = is_load ? dst_range : src_range;
+  auto global_range = is_load ? src_range : dst_range;
+  Buffer shared_tensor = is_load ? dst : src;
+  Buffer global_tensor = is_load ? src : dst;
+  if (T.buffer_remap.count(shared_tensor)) shared_tensor = T.buffer_remap.at(shared_tensor);
+  if (T.buffer_remap.count(global_tensor)) global_tensor = T.buffer_remap.at(global_tensor);
 
   PrimExpr shared_elements = 1;
   for (size_t i = 0; i < shared_range.size(); i++) {
     shared_elements *= shared_range[i]->extent;
   }
@@
-  PrimExpr elements = analyzer->Simplify(shared_elements);
-  PrimExpr shared_addr = shared_tensor.access_ptr(
-      is_load ? 2 : 1, DataType::Handle(), 1, shared_offset, elements);
-  PrimExpr global_addr = global_tensor.access_ptr(
-      is_load ? 1 : 2, DataType::Handle(), 1, global_offset, elements);
-  Stmt tma_copy;
-  if (is_load) {
-    // the zero is a placeholder for mbarrier ids
-    tma_copy = Evaluate(
-        Call(DataType::Handle(), tma_load(),
-             {shared_addr, global_addr, 0,
-              elements * shared_tensor->dtype.bytes(), this->eviction_policy}));
-  } else {
-    tma_copy = Evaluate(
-        Call(DataType::Handle(), tma_store(),
-             {global_addr, shared_addr, elements * shared_tensor->dtype.bytes(),
-              this->eviction_policy}));
-  }
-  tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy);
-  return tma_copy;
+  PrimExpr elements = analyzer->Simplify(shared_elements);
+  if (analyzer->CanProveEqual(elements, 0)) {
+    return IfThenElse(EQ(T.thread_var, T.thread_bounds->min), Evaluate(0));
+  }
+
+  // Build a minimal 1D TMA descriptor over the flattened global tensor
+  TMADesc desc;
+  desc.rank = 1;
+  desc.data_type = to_CUtensorMapDataType(global_tensor->dtype);
+  desc.global_addr = global_tensor->data;
+  PrimExpr total_gmem_elems = 1;
+  for (auto s : global_tensor->shape) total_gmem_elems *= s;
+  desc.global_shape = { total_gmem_elems };
+  desc.global_stride = { cast(DataType::Int(64), global_tensor->dtype.bytes()) };
+  desc.smem_box = { elements };
+  desc.smem_stride = { PrimExpr(1) };
+  desc.interleave = static_cast<int>(CU_TENSOR_MAP_INTERLEAVE_NONE);
+  desc.swizzle    = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE);
+  desc.l2_promotion = static_cast<int>(CU_TENSOR_MAP_L2_PROMOTION_L2_128B);
+  desc.oob_fill     = static_cast<int>(CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
+
+  Call create_descriptor =
+      Call(DataType::Handle(), create_tma_descriptor(), desc.EncodeCallArgs());
+
+  // Addresses and 1D coordinate
+  PrimExpr shared_addr = shared_tensor.access_ptr(
+      is_load ? 2 : 1, DataType::Handle(), 1, shared_offset, elements);
+  PrimExpr crd0 = cast(DataType::Int(32), global_offset);
+
+  Array<PrimExpr> args;
+  args.push_back(create_descriptor);
+  if (is_load) args.push_back(0); // mbarrier placeholder
+  args.push_back(shared_addr);
+  args.push_back(crd0);
+  args.push_back(this->eviction_policy);
+  auto op = is_load ? tma_load() : tma_store();
+  Stmt tma_copy = Evaluate(Call(DataType::Handle(), op, args));
+  tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy);
+  return tma_copy;
 }
🧹 Nitpick comments (3)
src/op/copy.cc (3)

198-206: Fix doc typo: “Create s” → “Create indices”

Minor but visible in docs.

Apply:

- * \brief Create s for the copy operation.
+ * \brief Create indices for the copy operation.

751-754: Propagate buffer_oob during loop layout inference (consistency)

Parallel layout inference likely ignores buffer_oob, but passing the real value avoids surprises if it’s used later.

Apply:

-      par_op->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer,
-                           false, T.buffer_remap},
+      par_op->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer,
+                           T.buffer_oob, T.buffer_remap},
                           level);

1281-1354: Optional: unit tests for 1D/2D TMA selection

Add tests to assert that:

  • 1D TMA is selected only when buffer_oob == false and contiguity holds,
  • Lower() and InferLayout agree on the selected CopyInst,
  • 1D lowering emits descriptor form (argument count/shape).

I can draft minimal TIR unit-tests that exercise these paths if helpful.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 1545fb1 and c146849.

📒 Files selected for processing (2)
  • src/op/copy.cc (17 hunks)
  • src/op/copy.h (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/op/copy.cc (2)
src/tl_templates/cuda/copy_sm90.h (12)
  • tma_load (18-27)
  • tma_load (43-60)
  • tma_load (64-82)
  • tma_load (86-104)
  • tma_load (107-126)
  • tma_load (130-150)
  • tma_store (175-182)
  • tma_store (185-195)
  • tma_store (198-209)
  • tma_store (212-223)
  • tma_store (226-238)
  • tma_store (241-253)
tilelang/language/builtin.py (1)
  • tma_load (67-76)
src/op/copy.h (1)
src/op/copy.cc (20)
  • InferLayout (398-433)
  • InferLayout (398-399)
  • InferLayout (1645-1648)
  • InferLayout (1645-1646)
  • CheckBulkLoad (453-484)
  • CheckBulkLoad (453-453)
  • CheckBulkStore (569-599)
  • CheckBulkStore (569-569)
  • CheckBulkLoad1D (533-543)
  • CheckBulkLoad1D (533-534)
  • CheckBulkStore1D (545-555)
  • CheckBulkStore1D (545-546)
  • CheckBulkCopy1D (486-531)
  • CheckBulkCopy1D (486-491)
  • GetCopyInst (646-672)
  • GetCopyInst (646-649)
  • LowerBulkCopy (1001-1279)
  • LowerBulkCopy (1001-1002)
  • LowerBulkCopy1D (1281-1354)
  • LowerBulkCopy1D (1281-1282)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: format-check
🔇 Additional comments (8)
src/op/copy.h (5)

18-26: Enum additions for 1D TMA look good

kBulkLoad1D/kBulkStore1D separation is appropriate.


144-146: InferLayout override signature OK

Plumbs analyzer/layout_map/buffer_oob through inference.


150-156: Bulk load/store check declarations OK

Signatures align with usage in copy.cc.


157-176: 1D bulk-copy check API looks right

Parameters cover the necessary context for contiguity checks.


203-206: LowerBulkCopy1D hook: interface LGTM

Signature is consistent with other lowerers.

src/op/copy.cc (3)

406-408: Good: InferLayout gates 1D TMA on buffer_oob

The selection now respects OOB status during inference.


1257-1270: No change required here after above fixes

shared_addr will use the corrected shared_offset introduced after remap.

To be safe, run a smoke-test that exercises the bulk 2D path with a swizzled shared layout to confirm correct tile placement post-remap.


398-433: Nit: ParOp caching/inference flow looks solid

Nothing to change; noting this to acknowledge the improvement.

Comment on lines +1149 to +1155
if (T.layout_map.count(shared_tensor)) {
shared_layout = T.layout_map.at(shared_tensor);
ICHECK(T.buffer_remap.count(shared_tensor))
<< "shared_tensor: " << shared_tensor->name
<< " not found in buffer_remap";
shared_tensor = T.buffer_remap.at(shared_tensor);
}
Contributor

🛠️ Refactor suggestion

Derive shared_offset with remap applied

After determining shared_layout and remapping the buffer, transform indices and use OffsetOf to compute the correct base.

Apply:

   if (!shared_layout.defined()) {
     desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE);
   } else if (StructuralEqual()(shared_layout, linear_layout)) {
     desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE);
   } else {
@@
   }
+
+  // Compute SMEM base offset in the (possibly) remapped buffer
+  Array<PrimExpr> shared_coords = shared_indices;
+  if (shared_layout.defined()) {
+    shared_coords = shared_layout->Forward(shared_indices);
+  }
+  PrimExpr shared_offset = shared_tensor.OffsetOf(shared_coords).back();
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if (!shared_layout.defined()) {
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE);
} else if (StructuralEqual()(shared_layout, linear_layout)) {
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE);
} else {
// existing swizzle-from-shared_layout logic
desc.swizzle = /* ... */;
}
// Compute SMEM base offset in the (possibly) remapped buffer
Array<PrimExpr> shared_coords = shared_indices;
if (shared_layout.defined()) {
shared_coords = shared_layout->Forward(shared_indices);
}
PrimExpr shared_offset = shared_tensor.OffsetOf(shared_coords).back();
🤖 Prompt for AI Agents
In src/op/copy.cc around lines 1149 to 1155, after you determine shared_layout
and then remap shared_tensor via T.buffer_remap, you must recompute the
shared_offset using the remapped tensor and its layout: transform the original
indices by the shared_layout (taking into account any axis permutations/strides)
and call OffsetOf on the remapped shared_tensor to obtain the correct base
offset. Update the code so that OffsetOf is invoked with the transformed indices
and the remapped/shared_tensor (not the pre-remap tensor), ensuring
shared_offset reflects the remapped buffer layout.
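
To illustrate why the refactor above matters, here is a tiny standalone sketch (hypothetical swap layout, plain integers, not the TileLang Layout API): an offset computed from untransformed coordinates and pre-remap strides lands on a different element than transforming the coordinates first and using the remapped buffer's strides.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical layout transform: swap the two coordinates, standing in for a
// swizzle/remap (the real Layout::Forward can be an arbitrary bijection).
std::vector<int64_t> Forward(const std::vector<int64_t> &coords) {
  return {coords[1], coords[0]};
}

int64_t OffsetOf(const std::vector<int64_t> &coords,
                 const std::vector<int64_t> &strides) {
  int64_t offset = 0;
  for (std::size_t i = 0; i < coords.size(); i++) offset += coords[i] * strides[i];
  return offset;
}

int main() {
  // Logical buffer is [64, 128]; the remapped buffer is [128, 64] row-major.
  std::vector<int64_t> logical_coords = {8, 16};
  int64_t wrong = OffsetOf(logical_coords, {128, 1});         // pre-remap strides
  int64_t right = OffsetOf(Forward(logical_coords), {64, 1}); // remapped strides
  assert(wrong != right); // 8*128+16 = 1040 vs 16*64+8 = 1032
  return 0;
}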

…tion

- Updated `CheckBulkLoad` and `CheckBulkStore` methods in `CopyNode` to include an optional parameter for validating the last dimension during bulk copy operations.
- Adjusted related methods `CheckBulkLoad1D` and `CheckBulkStore1D` to pass the new parameter, improving the accuracy of bulk copy checks.
- This change enhances the robustness of memory operations in the TileLang framework by ensuring compliance with dimensional requirements.
…ility

- Reformatted the parameter lists of `CheckBulkLoad` and `CheckBulkStore` methods in `CopyNode` to enhance code clarity by aligning parameters across multiple lines.
- This change improves the maintainability of the code and adheres to coding standards.
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/op/copy.cc (1)

650-676: Include buffer_oob argument in GetCopyInst call
At src/op/copy.cc:697, update the 4-arg call

GetCopyInst(target, disable_tma_lower, T.layout_map, analyzer)

to pass T.buffer_oob as the fifth parameter, since the default was removed.

♻️ Duplicate comments (4)
src/op/copy.cc (4)

488-533: Complete shared-side contiguity and fix potential misuse

  • Only global is range-checked for contiguity; shared needs the same tail-ones rule.
  • You correctly use shared_tensor for layout_map (fixes the old dst misuse), but please add the shared range check.

Apply:

 bool CopyNode::CheckBulkCopy1D(const Buffer &global_tensor,
                                const Buffer &shared_tensor,
                                const Array<Range> &global_range,
                                const Array<Range> &shared_range,
                                const LayoutMap &layout_map,
                                arith::Analyzer *analyzer) const {

   // Step 1: check shared is contiguous
   bool shared_is_contiguous = true;
   if (layout_map.count(shared_tensor)) {
     shared_is_contiguous = false;
   }
   // Step 2: check global is contiguous
   bool global_is_contiguous = true;
   bool global_not_full_dim_encounter = false;
   for (int i = global_range.size() - 1; i >= 0; i--) {
@@
   }
+
+  // Step 2b: check shared is contiguous via ranges (mirror global rule)
+  bool shared_not_full_dim_encounter = false;
+  for (int i = shared_range.size() - 1; i >= 0; i--) {
+    if (!shared_not_full_dim_encounter) {
+      if (!analyzer->CanProve(shared_range[i]->extent == shared_tensor->shape[i] &&
+                                   shared_range[i]->min == 0,
+                               arith::ProofStrength::kSymbolicBound)) {
+        shared_not_full_dim_encounter = true;
+      }
+    } else {
+      if (!analyzer->CanProve(shared_range[i]->extent == 1,
+                              arith::ProofStrength::kSymbolicBound)) {
+        shared_is_contiguous = false;
+        break;
+      }
+    }
+  }
@@
-  return (shared_is_contiguous && global_is_contiguous && element_match);
+  return (shared_is_contiguous && global_is_contiguous && element_match);
 }

650-653: Remove default argument in definition (must match header)

Default arguments belong in the declaration, not the out-of-line definition; the header has already dropped the default, so remove it from the definition to match.

Apply:

-CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower,
-                               const LayoutMap &layout_map,
-                               arith::Analyzer *analyzer,
-                               bool buffer_oob = false) const {
+CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower,
+                               const LayoutMap &layout_map,
+                               arith::Analyzer *analyzer,
+                               bool buffer_oob) const {

696-703: Lower() must pass buffer_oob to GetCopyInst

Omitting it re-enables 1D TMA during lowering even when OOB exists.

Apply:

-  auto copy_inst =
-      GetCopyInst(target, disable_tma_lower, T.layout_map, analyzer);
+  auto copy_inst =
+      GetCopyInst(target, disable_tma_lower, T.layout_map, analyzer, T.buffer_oob);

1285-1358: Fix 1D TMA lowering: wrong intrinsic calling convention, missing descriptor, no remap, no zero-length guard

The current path passes (shared_addr, global_addr, mbar, bytes, eviction), while the intrinsics expect (descriptor, [mbar], smem_ptr, crd0, cache_hint). It also ignores the buffer remap and does not guard zero-length copies.

Apply:

 Stmt CopyNode::LowerBulkCopy1D(const LowerArgs &T, arith::Analyzer *analyzer,
                                CopyInst copy_inst) const {
   ICHECK(copy_inst == CopyInst::kBulkLoad1D ||
          copy_inst == CopyInst::kBulkStore1D);

-  // Add 1D TMA copy when the global and shared memory is contiguous
-  // Check if shared_tensor->name is present in T.buffer_var_gemm
-  // (Array<PrimExpr>) to avoid use 1D TMA copy for swizzled layout
   bool is_load = copy_inst == CopyInst::kBulkLoad1D;
-  auto shared_range = is_load ? dst_range : src_range;
-  auto global_range = is_load ? src_range : dst_range;
-  auto shared_tensor = is_load ? dst : src;
-  auto global_tensor = is_load ? src : dst;
+  auto shared_range = is_load ? dst_range : src_range;
+  auto global_range = is_load ? src_range : dst_range;
+  Buffer shared_tensor = is_load ? dst : src;
+  Buffer global_tensor = is_load ? src : dst;
+  if (T.buffer_remap.count(shared_tensor)) shared_tensor = T.buffer_remap.at(shared_tensor);
+  if (T.buffer_remap.count(global_tensor)) global_tensor = T.buffer_remap.at(global_tensor);

   PrimExpr shared_elements = 1;
   for (size_t i = 0; i < shared_range.size(); i++) {
     shared_elements *= shared_range[i]->extent;
   }
@@
-  PrimExpr global_offset = 0;
+  PrimExpr global_offset = 0;
   for (size_t i = 0; i < global_indices.size(); i++) {
     global_offset += global_indices[i] * global_strides[i];
   }
@@
-  PrimExpr shared_offset = 0;
+  PrimExpr shared_offset = 0;
   for (size_t i = 0; i < shared_indices.size(); i++) {
     shared_offset += shared_indices[i] * shared_strides[i];
   }
 
-  PrimExpr elements = analyzer->Simplify(shared_elements);
-  PrimExpr shared_addr = shared_tensor.access_ptr(
-      is_load ? 2 : 1, DataType::Handle(), 1, shared_offset, elements);
-  PrimExpr global_addr = global_tensor.access_ptr(
-      is_load ? 1 : 2, DataType::Handle(), 1, global_offset, elements);
-  Stmt tma_copy;
-  if (is_load) {
-    // the zero is a placeholder for mbarrier ids
-    tma_copy = Evaluate(
-        Call(DataType::Handle(), tma_load(),
-             {shared_addr, global_addr, 0,
-              elements * shared_tensor->dtype.bytes(), this->eviction_policy}));
-  } else {
-    tma_copy = Evaluate(
-        Call(DataType::Handle(), tma_store(),
-             {global_addr, shared_addr, elements * shared_tensor->dtype.bytes(),
-              this->eviction_policy}));
-  }
-  tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy);
-  return tma_copy;
+  PrimExpr elements = analyzer->Simplify(shared_elements);
+  // Early out on zero-length copies
+  if (analyzer->CanProveEqual(elements, 0)) {
+    return IfThenElse(EQ(T.thread_var, T.thread_bounds->min), Evaluate(0));
+  }
+
+  // Build 1D TMA descriptor over flattened global tensor
+  TMADesc desc;
+  desc.rank = 1;
+  desc.data_type = to_CUtensorMapDataType(global_tensor->dtype);
+  desc.global_addr = global_tensor->data;
+  PrimExpr total_global_elems = 1;
+  for (auto s : global_tensor->shape) total_global_elems *= s;
+  desc.global_shape = { total_global_elems };
+  desc.global_stride = { cast(DataType::Int(64), global_tensor->dtype.bytes()) };
+  desc.smem_box = { elements };
+  desc.smem_stride = { PrimExpr(1) };
+  desc.interleave = static_cast<int>(CU_TENSOR_MAP_INTERLEAVE_NONE);
+  desc.swizzle    = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE);
+  desc.l2_promotion = static_cast<int>(CU_TENSOR_MAP_L2_PROMOTION_L2_128B);
+  desc.oob_fill     = static_cast<int>(CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
+  Call create_descriptor =
+      Call(DataType::Handle(), create_tma_descriptor(), desc.EncodeCallArgs());
+
+  // Addresses and coordinate
+  PrimExpr shared_addr = shared_tensor.access_ptr(
+      is_load ? 2 : 1, DataType::Handle(), 1, shared_offset, elements);
+  PrimExpr crd0 = cast(DataType::Int(32), global_offset);
+
+  Array<PrimExpr> args;
+  args.push_back(create_descriptor);
+  if (is_load) args.push_back(0); // mbarrier placeholder
+  args.push_back(shared_addr);
+  args.push_back(crd0);
+  args.push_back(this->eviction_policy);
+  auto op = is_load ? tma_load() : tma_store();
+  Stmt tma_copy = Evaluate(Call(DataType::Handle(), op, args));
+  tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy);
+  return tma_copy;
 }
🧹 Nitpick comments (2)
src/op/copy.h (1)

159-170: Clarify 1D checks contract (doc-only)

Document that 1D path requires:

  • shared tensor has no layout_map entry (no swizzle/remap),
  • both global and shared ranges are contiguous (tail-ones rule),
  • dtypes match and no OOB at call site.
src/op/copy.cc (1)

199-206: Nit: fix comment typo

“Create s” → “Create indices”.

Apply:

- * \brief Create s for the copy operation.
+ * \brief Create indices for the copy operation.
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between c146849 and 0fafc15.

📒 Files selected for processing (2)
  • src/op/copy.cc (17 hunks)
  • src/op/copy.h (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/op/copy.h (1)
src/op/copy.cc (20)
  • InferLayout (398-433)
  • InferLayout (398-399)
  • InferLayout (1649-1652)
  • InferLayout (1649-1650)
  • CheckBulkLoad (453-486)
  • CheckBulkLoad (453-454)
  • CheckBulkStore (571-603)
  • CheckBulkStore (571-572)
  • CheckBulkLoad1D (535-545)
  • CheckBulkLoad1D (535-536)
  • CheckBulkStore1D (547-557)
  • CheckBulkStore1D (547-548)
  • CheckBulkCopy1D (488-533)
  • CheckBulkCopy1D (488-493)
  • GetCopyInst (650-676)
  • GetCopyInst (650-653)
  • LowerBulkCopy (1005-1283)
  • LowerBulkCopy (1005-1006)
  • LowerBulkCopy1D (1285-1358)
  • LowerBulkCopy1D (1285-1286)
src/op/copy.cc (2)
src/tl_templates/cuda/copy_sm90.h (12)
  • tma_load (18-27)
  • tma_load (43-60)
  • tma_load (64-82)
  • tma_load (86-104)
  • tma_load (107-126)
  • tma_load (130-150)
  • tma_store (175-182)
  • tma_store (185-195)
  • tma_store (198-209)
  • tma_store (212-223)
  • tma_store (226-238)
  • tma_store (241-253)
tilelang/language/builtin.py (1)
  • tma_load (67-76)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: bot-task
🔇 Additional comments (7)
src/op/copy.h (5)

18-26: Enum extension for 1D TMA looks good

New kBulkLoad1D/kBulkStore1D entries are clear and non-breaking.


144-145: InferLayout override is appropriate

Const-correct override is consistent with usage.


150-157: Bulk load/store signatures OK

Keeping check_last_dim with a default is fine for 2D; 1D path disables it at call sites.


193-196: GetCopyInst signature (no default) is correct

Header removes the default for buffer_oob. Ensure the .cc definition matches (it currently doesn't).


204-209: 1D lowering hook addition is good

Separation keeps 1D path self-contained.

src/op/copy.cc (2)

406-408: Good: InferLayout now gates 1D TMA on buffer_oob

Passing T.buffer_oob into GetCopyInst avoids unsafe 1D selections during inference.


755-758: OK: pass explicit buffer_oob=false to layout inference inside normal path

Intentional and safe since normal copy tolerates predicates and masking.

Comment on lines +1025 to 1034
   Array<PrimExpr> shared_indices;
   for (auto r : shared_range)
-    indices.push_back(r->min);
-  std::vector<PrimExpr> strides;
-  PrimExpr stride = 1;
+    shared_indices.push_back(r->min);
+  std::vector<PrimExpr> shared_strides;
+  PrimExpr shared_stride = 1;
   for (size_t i = 0; i < shared_tensor->shape.size(); i++) {
     auto s = shared_tensor->shape[shared_tensor->shape.size() - i - 1];
-    strides.insert(strides.begin(), stride);
-    stride *= s;
+    shared_strides.insert(shared_strides.begin(), shared_stride);
+    shared_stride *= s;
   }
Contributor

⚠️ Potential issue

Compute shared_offset after layout remap

shared_offset is derived before applying layout_map/buffer_remap to shared_tensor, leading to wrong base addresses with swizzled layouts.

Apply:

-  Array<PrimExpr> shared_indices;
-  for (auto r : shared_range)
-    shared_indices.push_back(r->min);
+  Array<PrimExpr> shared_indices;
+  for (auto r : shared_range) shared_indices.push_back(r->min);
@@
-  PrimExpr shared_offset = 0;
-  for (size_t i = 0; i < shared_indices.size(); i++) {
-    shared_offset += shared_indices[i] * shared_strides[i];
-  }
+  PrimExpr shared_offset = 0;  // recomputed after possible remap
@@
   Layout shared_layout;
   if (T.layout_map.count(shared_tensor)) {
     shared_layout = T.layout_map.at(shared_tensor);
     ICHECK(T.buffer_remap.count(shared_tensor))
         << "shared_tensor: " << shared_tensor->name
         << " not found in buffer_remap";
     shared_tensor = T.buffer_remap.at(shared_tensor);
   }
+  // Derive offset in the (possibly) remapped buffer
+  {
+    Array<PrimExpr> shared_coords = shared_indices;
+    if (shared_layout.defined()) {
+      shared_coords = shared_layout->Forward(shared_indices);
+    }
+    shared_offset = shared_tensor.OffsetOf(shared_coords).back();
+  }
@@
-    PrimExpr shared_addr = shared_tensor.access_ptr(
-        is_load ? 2 : 1, DataType::Handle(), 1,
-        shared_offset + total_elements * loop_var, total_elements);
+    PrimExpr shared_addr = shared_tensor.access_ptr(
+        is_load ? 2 : 1, DataType::Handle(), 1,
+        shared_offset + total_elements * loop_var, total_elements);
@@
-    PrimExpr shared_addr = shared_tensor.access_ptr(
-        is_load ? 2 : 1, DataType::Handle(), 1, shared_offset, total_elements);
+    PrimExpr shared_addr = shared_tensor.access_ptr(
+        is_load ? 2 : 1, DataType::Handle(), 1, shared_offset, total_elements);

Also applies to: 1048-1054, 1152-1159, 1261-1274

🤖 Prompt for AI Agents
In src/op/copy.cc around lines 1025-1034 (and similarly at 1048-1054, 1152-1159,
1261-1274): shared_offset is computed from shared_tensor shape/strides before
applying any layout_map or buffer_remap, so the base address is wrong for
swizzled/remapped layouts; recompute shared_strides and shared_offset after
performing layout/buffer remap on shared_tensor (i.e., use the
remapped/shared_tensor layout to derive strides and then compute shared_offset),
ensuring you apply the same layout_map/buffer_remap transformations used
elsewhere before calculating offsets so base addresses reflect the final memory
layout.

@LeiWang1999 LeiWang1999 merged commit 9d7d45b into tile-ai:main Sep 6, 2025
5 of 7 checks passed