[TMA] Automatically lower 1d tma in appropriate cases #788
- Updated `CopyNode` to introduce separate handling for 1D bulk load/store operations, including new methods for checking and lowering these operations.
- Modified `InferLayout` and `GetCopyInst` to accommodate additional parameters for layout maps and analyzers.
- Enhanced `AtomicAddNode` and `FillNode` to utilize the updated layout inference logic.
- Improved buffer out-of-bounds checks during layout inference to ensure safe memory access.

This update improves the efficiency and correctness of memory operations in the TileLang framework.
Walkthrough

Adds an arith::Analyzer pointer and a buffer_oob flag to LayoutInferArgs and propagates them through layout-inference and lowering call sites; implements explicit 1D bulk-copy support and checks in CopyNode, refactors 2D bulk-copy addressing to use layout_map, introduces per-op OOB tracking in layout inference, and removes WgMMA gating from warp-specialized rewriting.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant LI as LayoutInferencer
    participant AZ as Analyzer
    participant Op as OpNode
    participant LM as LayoutMap
    LI->>AZ: Analyze bounds for current op
    LI->>LI: compute per-op buffer_oob
    LI->>Op: InferLayout({target, thread_bounds, LM, AZ, buffer_oob})
    Op-->>LI: return LayoutMap
    LI->>LI: store per-op thread/OOB context
```

```mermaid
sequenceDiagram
    autonumber
    participant Copy as CopyNode
    participant LA as LowerArgs
    participant AZ as Analyzer
    participant Inst as CopyInstSelector
    Copy->>Inst: GetCopyInst(target, disable_tma_lower, layout_map, AZ, buffer_oob)
    alt kBulkLoad1D / kBulkStore1D
        Copy->>Copy: LowerBulkCopy1D(LA, AZ, inst)
    else kBulkLoad / kBulkStore
        Copy->>Copy: LowerBulkCopy(LA, AZ, inst)
    else kLDSM / kSTSM
        Copy->>Copy: LowerLDSMCopy(LA, inst)
    else kNormal
        Copy->>Copy: LowerNormalCopy(LA, AZ)
    end
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the project's formatting checks before review. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Summary of Changes
Hello @LeiWang1999, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces specialized handling for 1D Tensor Memory Access (TMA) operations within the TileLang framework, aiming to optimize memory transfers for contiguous data. It refactors the CopyNode to better manage different types of bulk copies and enhances the layout inference system with more detailed analytical capabilities, including explicit buffer out-of-bounds checks. Additionally, it includes a cleanup of the warp specialization rewriter.
Highlights
- Introduction of 1D TMA Operations: Added specific support for 1D Tensor Memory Access (TMA) bulk load and store operations (kBulkLoad1D, kBulkStore1D) to improve efficiency for contiguous memory transfers.
- Refactored CopyNode Logic: The CopyNode now includes dedicated methods (CheckBulkLoad1D, CheckBulkStore1D, LowerBulkCopy1D) to check for and lower 1D bulk copies, separating this logic from general bulk copy handling.
- Enhanced Layout Inference: The InferLayout and GetCopyInst functions, along with the LayoutInferArgs structure, have been extended to include arith::Analyzer and buffer_oob parameters, allowing for more precise layout inference and out-of-bounds checking.
- Improved Buffer OOB Checks: Out-of-bounds checks for buffers are now explicitly incorporated into the layout inference process, particularly for CopyNode operations, ensuring safer memory access.
- Updated Operator Inference: AtomicAddNode and FillNode have been updated to utilize the new InferLayout parameters, ensuring they benefit from the enhanced layout inference capabilities.
- Warp Specialization Refactoring: Simplified the WSCodeEmitter and WarpSpecializedRewriter by removing the only_has_wgmma_ flag and related logic, streamlining the warp specialization pass.
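To make the interface changes listed above concrete, the sketch below shows roughly what the extended CopyInst enum and LayoutInferArgs struct look like. Enumerator and field names are taken from the review discussion that follows; the exact declarations in src/op/copy.h and src/op/operator.h may differ.

```cpp
// Rough sketch only (assumes the usual TVM/TileLang headers); see src/op/copy.h
// and src/op/operator.h for the real declarations.
enum class CopyInst : int {
  kNormal = 0,   // fallback per-thread copy
  kLDSM,         // ldmatrix-based shared -> register copy
  kSTSM,         // stmatrix-based register -> shared copy
  kBulkLoad,     // multi-dim TMA global -> shared
  kBulkStore,    // multi-dim TMA shared -> global
  kBulkLoad1D,   // new: contiguous 1D TMA global -> shared
  kBulkStore1D,  // new: contiguous 1D TMA shared -> global
};

struct LayoutInferArgs {
  Target target;
  Range thread_bounds;
  LayoutMap layout_map;
  arith::Analyzer *analyzer = nullptr;  // new: analysis context for bound proofs
  bool buffer_oob = false;              // new: this op may access out of bounds
  Map<Buffer, Buffer> buffer_remap;
};
```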
Code Review
This pull request introduces support for 1D TMA (Tensor Memory Access) bulk copy operations, which can improve performance in appropriate cases. The changes include new functions to check for and lower 1D TMA, updates to the CopyInst enum, and modifications to layout inference to handle these new cases. While the changes are generally good and include some nice refactoring, I've identified a critical issue with potential for a crash due to unsafe map access, and a high-severity bug in the logic for checking 1D bulk store conditions. I've also noted a couple of places where documentation was removed, which could impact future maintainability.
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
src/op/atomic_add.cc (1)
289-307: address_of applied to a conditional rvalue can generate invalid IR

address_of expects a BufferLoad (lvalue). Wrapping dst_value in if_then_else before address_of makes it a non-lvalue expression. Guard the atomic call with a combined predicate instead of taking address_of on a conditional.
```diff
-  PrimExpr dst_value = BufferLoad(dst, dst_indices);
-  if (dst_predicate.defined())
-    dst_value = if_then_else(dst_predicate, dst_value, make_zero(dst->dtype));
-
-  Call address_of_value =
-      tvm::tir::Call(DataType::Handle(), builtin::address_of(), {dst_value});
+  PrimExpr dst_load = BufferLoad(dst, dst_indices);
+  Call address_of_value =
+      tvm::tir::Call(DataType::Handle(), builtin::address_of(), {dst_load});
   new_args.push_back(address_of_value);
   new_args.push_back(src_value);
-  Call atomicadd_call =
-      tvm::tir::Call(dst->dtype, builtin::call_extern(), new_args);
-
-  Stmt body = tvm::tir::Evaluate(atomicadd_call);
+  Call atomicadd_call =
+      tvm::tir::Call(dst->dtype, builtin::call_extern(), new_args);
+  Stmt body = tvm::tir::Evaluate(atomicadd_call);
+  // Combine source/destination predicates to guard the atomic op
+  PrimExpr call_pred = Bool(true);
+  if (src_predicate.defined()) call_pred = And(call_pred, src_predicate);
+  if (dst_predicate.defined()) call_pred = And(call_pred, dst_predicate);
+  if (!(is_one(call_pred))) {
+    body = IfThenElse(call_pred, body);
+  }
```

src/transform/layout_inference.cc (1)
88-110: Include buffer_remap in LayoutInferArgs initializer
In transform/layout_inference.cc (lines 88–110), the call to

```cpp
next->InferLayout(LayoutInferArgs{target_, thread_bounds, layout_map, &analyzer_, buffer_oob}, level);
```

omits the `buffer_remap` map, defaulting it to empty. Many operator-level `InferLayout` implementations (e.g. fill, atomic_add, reduce, copy, finalize_reducer, gemm_sp) read `T.buffer_remap` and will misbehave. Change the call to pass the current `buffer_remap`, e.g.:

```cpp
auto updates = next->InferLayout(
    LayoutInferArgs{target_, thread_bounds, layout_map, &analyzer_, buffer_oob,
                    buffer_remap},
    level);
```

src/op/copy.cc (1)
1121-1161: Avoid .at() on layout_map; fall back cleanly when no layout is annotated

`.at()` will CHECK-fail if the shared buffer has no layout entry.

```diff
-  auto shared_layout = T.layout_map.at(shared_tensor);
-  if (!shared_layout.defined()) {
+  Layout shared_layout;
+  if (T.layout_map.count(shared_tensor)) {
+    shared_layout = T.layout_map[shared_tensor];
+  }
+  if (!shared_layout.defined()) {
     desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE);
```
🧹 Nitpick comments (7)
src/transform/warp_specialized_rewriter.cc (1)
1226-1229: Nit: drop the explicit `false` (it's the default).

The extra boolean obscures intent and adds churn. Prefer the defaulted ctor.

Apply:

```diff
-WSCodeEmitter consumer(false, thread_iv_, buffer_data_to_buffer_, marker,
-                       false);
+WSCodeEmitter consumer(false, thread_iv_, buffer_data_to_buffer_, marker);
```

src/op/operator.h (1)
48-51: Make analyzer pointer explicitly nullable and document buffer_oob semantics

Avoid accidental uninitialized use and clarify the meaning/expectations of buffer_oob (true = may be OOB). Recommend default-init and a brief comment.

```diff
-  arith::Analyzer *analyzer;
-  bool buffer_oob = false;
+  // Optional analysis context; nullptr means no analyzer available
+  arith::Analyzer* analyzer = nullptr;
+  // True if this op may access out-of-bounds (used to steer conservative lowering)
+  bool buffer_oob = false;
```

src/op/copy.h (2)
156-175: Public 1D bulk-copy checks: add brief doc on preconditions

A short comment on the contiguity/stride assumptions (e.g., unit stride in either src or dst) will help callers.
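For instance, a documentation comment along the following lines would make the expectations explicit. This is a sketch only; the wording and exact preconditions would need to be confirmed against the implementation in copy.cc, and the declaration shown is a simplified stand-in for the real one.

```cpp
/// Sketch of a possible doc comment for the 1D bulk-copy checks (hypothetical
/// wording; verify the preconditions against CheckBulkCopy1D in copy.cc).
///
/// Returns true only when a 1D TMA bulk copy is viable:
///  - the shared buffer has no remapped/swizzled layout in `layout_map`,
///  - the accessed global region is contiguous (trailing full dimensions,
///    leading dimensions of extent 1),
///  - the shared and global regions cover the same number of elements,
///  - and the access is provably in bounds (1D TMA cannot mask OOB accesses).
bool CheckBulkLoad1D(Target target, const LayoutMap &layout_map,
                     arith::Analyzer *analyzer) const;
```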
189-193: Add a `= false` default for `buffer_oob` in the GetCopyInst declaration

In src/op/copy.h (line 189), change the last parameter from `bool buffer_oob` to `bool buffer_oob = false`. All call sites now pass 4 or 5 arguments (no legacy 2-arg uses), so this default reduces churn.
src/op/copy.cc (3)
199-206: Fix doc typo: "Create s" → "Create indices".

Small nit in the brief line.

```diff
- * \brief Create s for the copy operation.
+ * \brief Create indices for the copy operation.
```
996-1005: Address computation refactor looks good, but add a small robustness guard.

The stride/offset derivation is fine. Add an early return if `shared_indices.size() == 0` to avoid accidental 0-d edge cases.

```diff
-  PrimExpr shared_offset = 0;
-  for (size_t i = 0; i < shared_indices.size(); i++) {
+  PrimExpr shared_offset = 0;
+  if (shared_indices.empty()) return LowerNormalCopy(T, analyzer);
+  for (size_t i = 0; i < shared_indices.size(); i++) {
     shared_offset += shared_indices[i] * shared_strides[i];
   }
```

Also applies to: 1019-1025
727-729: Pass actual buffer_oob to InferLayout

Replace the hard-coded `false` with `T.buffer_oob` so the OOB flag isn't dropped:

```diff
-      {T.target, T.thread_bounds, T.layout_map, analyzer, false, T.buffer_remap},
+      {T.target, T.thread_bounds, T.layout_map, analyzer, T.buffer_oob, T.buffer_remap},
```
📜 Review details
📒 Files selected for processing (7)
- src/op/atomic_add.cc (1 hunks)
- src/op/copy.cc (12 hunks)
- src/op/copy.h (5 hunks)
- src/op/fill.cc (2 hunks)
- src/op/operator.h (1 hunks)
- src/transform/layout_inference.cc (10 hunks)
- src/transform/warp_specialized_rewriter.cc (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/transform/layout_inference.cc (1)
tilelang/language/copy.py (1)
copy(84-152)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: bot-task
- GitHub Check: format-check
- GitHub Check: format-check
🔇 Additional comments (16)
src/transform/warp_specialized_rewriter.cc (2)
770-778: Barrier-arrival now unconditional on consumer path: validate counts against init.

With WgMMA gating removed, the consumer `makeArriveBarrier(release_barrier_id)` executes whenever `release_after` is true. Ensure the corresponding barrier's `arrive_thread_count` (computed below) matches this arrival source to avoid hangs or early releases.

You can sanity-check by instrumenting a small IR containing both producer/consumer releases and verifying the number of barriers equals the number of `arrive` insertions per pattern. If helpful, I can draft a small IR snippet test.
1251-1258: Thread-count policy for barriers: double-check producer-vs-consumer arrival math.

Logic reads: producer-released ⇒ `producer_thread_extent` if SIMT copy else `1`; otherwise ⇒ `consumer_thread_extent`. This matches TMA (single-arrive) vs SIMT semantics, but there are edge cases to verify:

- The bulk-copy-only path forces `producer_thread_extent = 128`; ensure kernels actually launch ≥ consumer + 128 threads.
- Mixed patterns where a barrier can be released by the consumer in one stage and the producer in another should still be consistent with `released_barrier_bookkeeping`.

If desired, add a brief comment here to document the rationale.
I can add a unit-style IR test that asserts no mismatch between initialized counts and total arrives per barrier across producer/consumer paths.
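For illustration, here is a minimal sketch of the arrival-count policy described above, with an explanatory comment. The names `released_by_producer` and `is_simt_copy` are placeholders for whatever the surrounding rewriter code actually uses; this is not the real implementation.

```cpp
// Sketch only; placeholder names, not the actual rewriter code.
// Arrive-count policy: the count initialized for a barrier must equal the
// number of arrives issued per phase, or waits hang / release early.
int arrive_thread_count;
if (released_by_producer) {
  // SIMT copies: every producer thread arrives.
  // TMA bulk copies: a single elected thread issues the arrive.
  arrive_thread_count = is_simt_copy ? producer_thread_extent : 1;
} else {
  // Consumer-released barriers: all consumer threads arrive.
  arrive_thread_count = consumer_thread_extent;
}
```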
src/op/atomic_add.cc (1)
364-365: Pass-through of analyzer and buffer_remap looks correct

Extending InferLayout args here is consistent with the new signature. No issues.

src/op/fill.cc (2)

173-175: LGTM: InferLayout arg expansion (fragment path)

Correctly threads analyzer, buffer_oob=false, and buffer_remap.

190-192: LGTM: InferLayout arg expansion (shared/global path)

Matches the new LayoutInferArgs layout.

src/transform/layout_inference.cc (5)

19-20: Include copy.h is necessary for CopyNode OOB logic

Good addition to enable type checks.

69-70: Using-declarations to inherit constructors are fine

Keeps analyzer-aware visitors/mutators consistent.

207-210: Alignment check for buffer_oob_vec_ with infer_list_

Good guard against index skew between vectors.

408-409: Parallel For bookkeeping includes buffer_oob false

Keeps vectors aligned. Good.

601-602: Using-declaration for IRMutatorWithAnalyzer

Looks good; consistent with visitor.

src/op/copy.h (4)

18-26: Enum extension for 1D bulk-copy looks sensible

Clear separation between 1D and multi-D TMA paths.

144-145: Marking InferLayout as override is correct

Matches TileOperatorNode vtable.

201-206: Adding LowerBulkCopy1D hook is good; ensure parity with 2D path on swizzle/oob

Confirm swizzle, l2_promotion, and oob_fill behaviors match LowerBulkCopy for consistency.

349-350: Conv2DIm2Col InferLayout override

Signature matches base. Good.

src/op/copy.cc (2)

405-407: Good: GetCopyInst now considers analyzer, layout_map, and buffer_oob during inference.

This keeps 1D TMA selection aligned with OOB knowledge at infer time.

520-542: 1D checks look consistent with 2D gating; OK after fixing CheckBulkCopy1D.

Call routing and parameterization are correct.
```cpp
bool CopyNode::CheckBulkCopy1D(const Buffer &global_tensor,
                               const Buffer &shared_tensor,
                               const Array<Range> &global_range,
                               const Array<Range> &shared_range,
                               const LayoutMap &layout_map,
                               arith::Analyzer *analyzer) const {
  // Step 1: check shared is contiguous
  bool shared_is_contiguous = true;
  if (layout_map.count(dst)) {
    shared_is_contiguous = false;
  }
  // Step 2: check global is contiguous
  bool global_is_contiguous = true;
  bool global_not_full_dim_encounter = false;
  for (int i = global_range.size() - 1; i >= 0; i--) {
    if (!global_not_full_dim_encounter) {
      if (!analyzer->CanProve(global_range[i]->extent ==
                                      global_tensor->shape[i] &&
                                  global_range[i]->min == 0,
                              arith::ProofStrength::kSymbolicBound)) {
        global_not_full_dim_encounter = true;
      }
    } else {
      if (!analyzer->CanProve(global_range[i]->extent == 1,
                              arith::ProofStrength::kSymbolicBound)) {
        global_is_contiguous = false;
        break;
      }
    }
  }
  // Step 3: check element match and no OOB
  PrimExpr shared_elements = 1;
  for (size_t i = 0; i < shared_range.size(); i++) {
    shared_elements *= shared_range[i]->extent;
  }
  PrimExpr global_elements = 1;
  for (size_t i = 0; i < global_range.size(); i++) {
    global_elements *= global_range[i]->extent;
  }
  bool element_match =
      analyzer->CanProveEqual(shared_elements, global_elements);
  return (shared_is_contiguous && global_is_contiguous && element_match);
}
```
Fix variable misuse and strengthen contiguity checks in CheckBulkCopy1D.
- Uses the `dst` field instead of the `shared_tensor` parameter when consulting `layout_map`.
- Only checks contiguity on the global side; the shared side should be checked too.
```diff
-  // Step 1: check shared is contiguous
-  bool shared_is_contiguous = true;
-  if (layout_map.count(dst)) {
-    shared_is_contiguous = false;
-  }
+  // Step 1: check shared is contiguous (no remapped/swizzled layout)
+  bool shared_is_contiguous = true;
+  if (layout_map.count(shared_tensor)) {
+    shared_is_contiguous = false;
+  }
@@
-  // Step 2: check global is contiguous
+  // Step 2: check global is contiguous
   bool global_is_contiguous = true;
   bool global_not_full_dim_encounter = false;
@@
   }
+
+  // Step 2b: check shared is contiguous by ranges (same rule as global)
+  bool shared_not_full_dim_encounter = false;
+  for (int i = shared_range.size() - 1; i >= 0; i--) {
+    if (!shared_not_full_dim_encounter) {
+      if (!analyzer->CanProve(shared_range[i]->extent == shared_tensor->shape[i] &&
+                                  shared_range[i]->min == 0,
+                              arith::ProofStrength::kSymbolicBound)) {
+        shared_not_full_dim_encounter = true;
+      }
+    } else {
+      if (!analyzer->CanProve(shared_range[i]->extent == 1,
+                              arith::ProofStrength::kSymbolicBound)) {
+        shared_is_contiguous = false;
+        break;
+      }
+    }
+  }
@@
-  return (shared_is_contiguous && global_is_contiguous && element_match);
+  return (shared_is_contiguous && global_is_contiguous && element_match);
```

📝 Committable suggestion
```cpp
bool CopyNode::CheckBulkCopy1D(const Buffer &global_tensor,
                               const Buffer &shared_tensor,
                               const Array<Range> &global_range,
                               const Array<Range> &shared_range,
                               const LayoutMap &layout_map,
                               arith::Analyzer *analyzer) const {
  // Step 1: check shared is contiguous (no remapped/swizzled layout)
  bool shared_is_contiguous = true;
  if (layout_map.count(shared_tensor)) {
    shared_is_contiguous = false;
  }
  // Step 2: check global is contiguous
  bool global_is_contiguous = true;
  bool global_not_full_dim_encounter = false;
  for (int i = global_range.size() - 1; i >= 0; i--) {
    if (!global_not_full_dim_encounter) {
      if (!analyzer->CanProve(global_range[i]->extent == global_tensor->shape[i] &&
                                  global_range[i]->min == 0,
                              arith::ProofStrength::kSymbolicBound)) {
        global_not_full_dim_encounter = true;
      }
    } else {
      if (!analyzer->CanProve(global_range[i]->extent == 1,
                              arith::ProofStrength::kSymbolicBound)) {
        global_is_contiguous = false;
        break;
      }
    }
  }
  // Step 2b: check shared is contiguous by ranges (same rule as global)
  bool shared_not_full_dim_encounter = false;
  for (int i = shared_range.size() - 1; i >= 0; i--) {
    if (!shared_not_full_dim_encounter) {
      if (!analyzer->CanProve(shared_range[i]->extent == shared_tensor->shape[i] &&
                                  shared_range[i]->min == 0,
                              arith::ProofStrength::kSymbolicBound)) {
        shared_not_full_dim_encounter = true;
      }
    } else {
      if (!analyzer->CanProve(shared_range[i]->extent == 1,
                              arith::ProofStrength::kSymbolicBound)) {
        shared_is_contiguous = false;
        break;
      }
    }
  }
  // Step 3: check element match and no OOB
  PrimExpr shared_elements = 1;
  for (size_t i = 0; i < shared_range.size(); i++) {
    shared_elements *= shared_range[i]->extent;
  }
  PrimExpr global_elements = 1;
  for (size_t i = 0; i < global_range.size(); i++) {
    global_elements *= global_range[i]->extent;
  }
  bool element_match =
      analyzer->CanProveEqual(shared_elements, global_elements);
  return (shared_is_contiguous && global_is_contiguous && element_match);
}
```
🤖 Prompt for AI Agents
In src/op/copy.cc around lines 473-519, the code mistakenly uses dst when querying layout_map and only checks contiguity for the global side; replace the layout_map.count(dst) call with layout_map.count(shared_tensor) and add a symmetric contiguity check for shared_tensor (mirror the global_is_contiguous logic: walk shared_range from the last dimension to the first, use analyzer->CanProve to detect a full dimension, then require trailing extents == 1, and set shared_is_contiguous accordingly), keeping the existing element_match check and return expression.
src/op/copy.cc
Outdated
```diff
 CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower,
                                const LayoutMap &layout_map,
                                arith::Analyzer *analyzer,
                                bool buffer_oob = false) const {
   // disable_tma_lower is from pass_configs
   // when tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER is True,
   // we will not use tma for bulk load/store
-  if (!disable_tma_lower && CheckBulkLoad(target)) {
+  // 1d tma access can not support out of bound access
+  if (!disable_tma_lower && !buffer_oob &&
+      CheckBulkLoad1D(target, layout_map, analyzer)) {
+    return CopyInst::kBulkLoad1D;
+  } else if (!disable_tma_lower && !buffer_oob &&
+             CheckBulkStore1D(target, layout_map, analyzer)) {
+    return CopyInst::kBulkStore1D;
+  } else if (!disable_tma_lower && CheckBulkLoad(target)) {
```
🛠️ Refactor suggestion
Make buffer_oob non-optional in GetCopyInst to prevent unsafe defaults.
Defaulting buffer_oob to false risks call sites forgetting to pass it and enabling 1D TMA when OOB exists.
```diff
-CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower,
-                               const LayoutMap &layout_map,
-                               arith::Analyzer *analyzer,
-                               bool buffer_oob = false) const {
+CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower,
+                               const LayoutMap &layout_map,
+                               arith::Analyzer *analyzer,
+                               bool buffer_oob) const {
```

Follow-up: remove the default from the declaration in the header as well.
📝 Committable suggestion
```cpp
CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower,
                               const LayoutMap &layout_map,
                               arith::Analyzer *analyzer,
                               bool buffer_oob) const {
  // disable_tma_lower is from pass_configs
  // when tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER is True,
  // we will not use tma for bulk load/store
  // 1d tma access can not support out of bound access
  if (!disable_tma_lower && !buffer_oob &&
      CheckBulkLoad1D(target, layout_map, analyzer)) {
    return CopyInst::kBulkLoad1D;
  } else if (!disable_tma_lower && !buffer_oob &&
             CheckBulkStore1D(target, layout_map, analyzer)) {
    return CopyInst::kBulkStore1D;
  } else if (!disable_tma_lower && CheckBulkLoad(target)) {
    // ...
```
🤖 Prompt for AI Agents
In src/op/copy.cc around lines 621 to 636, the GetCopyInst signature currently defaults buffer_oob to false, which can lead callers to unintentionally enable 1D TMA when out-of-bounds access is possible; remove the default (= false) from the definition and also remove the default in the corresponding declaration in the header file, then audit and update all call sites to pass an explicit buffer_oob boolean value (true/false) appropriate for their context so no caller relies on the implicit false.
```diff
 WSCodeEmitter(bool is_emitting_producer, const IterVar &thread_iv,
               Map<Var, Buffer> buffer_data_to_buffer,
               const WarpSpecializedRoleMarker &marker,
-              bool mbarrier_only = false, bool only_has_wgmma = false)
+              bool mbarrier_only = false)
     : is_emitting_producer_(is_emitting_producer),
       buffer_data_to_buffer_(std::move(buffer_data_to_buffer)),
       marker_(marker), thread_var_(thread_iv->var),
-      mbarrier_only_(mbarrier_only), only_has_wgmma_(only_has_wgmma) {}
+      mbarrier_only_(mbarrier_only) {}
```
💡 Verification agent
🧩 Analysis chain
Constructor signature change looks good; confirm all call sites and remove dead WgMMA remnants.
Dropping the WgMMA flag simplifies the API. Please verify there are no lingering callers still passing the old parameter or referencing the removed accessor.
Run to confirm no stale uses remain:
🏁 Script executed:
```bash
#!/bin/bash
# Constructor/call sites
rg -nP --glob '!**/build/**' '\bWSCodeEmitter\s*\(' -C2
# Old API remnants
rg -nP --glob '!**/build/**' 'onlyHasWgMMA|has_wgmma|WgMMA' -C2
```

Length of output: 2330
Delete dead WgMMACollector and associated HasWgMMA code
The WSCodeEmitter constructor change is confirmed—no external calls use the old signature—but the WgMMACollector class and its HasWgMMA members (src/transform/warp_specialized_rewriter.cc:572–604) remain unreferenced and should be removed.
🤖 Prompt for AI Agents
In src/transform/warp_specialized_rewriter.cc around lines 572 to 604 (and the constructor change at 608-616), remove the dead WgMMACollector class and its HasWgMMA members, which are no longer referenced: delete the entire class definition and any HasWgMMA helper methods/fields, remove or update any forward declarations and includes that exist solely to support them, and run a quick search to eliminate any lingering references; ensure the WSCodeEmitter constructor and its usage remain unchanged and compile cleanly after the removals.
- Updated `InferLayout` calls in `AtomicAddNode`, `CopyNode`, and `FillNode` to enhance code clarity by formatting parameters across multiple lines.
- Cleaned up whitespace and formatting in `copy.h` and `layout_inference.cc` to adhere to coding standards and improve maintainability.

This refactor aims to streamline the layout inference logic and improve overall code organization.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
src/transform/layout_inference.cc (1)
319-332: Bug: const_int_bound queried with IterVar instead of Var

Use thread_var_->var when querying const_int_bound; the current code won't compile or will pick the wrong overload.

Apply:

```diff
-    if (analyzer_.const_int_bound.IsBound(thread_var_->var)) {
-      auto const_int_bound = analyzer_.const_int_bound(thread_var_);
+    if (analyzer_.const_int_bound.IsBound(thread_var_->var)) {
+      auto const_int_bound = analyzer_.const_int_bound(thread_var_->var);
@@
-      thread_bounds_vec_.push_back(Range::FromMinExtent(
-          IntImm(dtype, min_value), IntImm(dtype, extent)));
+      thread_bounds_vec_.push_back(Range::FromMinExtent(
+          IntImm(dtype, min_value), IntImm(dtype, extent)));
```

And similarly below:

```diff
-    if (thread_var_.defined() &&
-        analyzer_.const_int_bound.IsBound(thread_var_->var)) {
-      auto const_int_bound = analyzer_.const_int_bound(thread_var_);
+    if (thread_var_.defined() &&
+        analyzer_.const_int_bound.IsBound(thread_var_->var)) {
+      auto const_int_bound = analyzer_.const_int_bound(thread_var_->var);
```

src/op/copy.cc (2)
1089-1093: Potential OOB read in while-condition

Check index bounds before indexing shared_range.

Apply:

```diff
-  while (is_one(shared_range[s_range_idx]->extent) &&
-         s_range_idx < shared_range.size()) {
+  while (s_range_idx < shared_range.size() &&
+         is_one(shared_range[s_range_idx]->extent)) {
     s_range_idx++;
   }
```

1121-1141: Guard layout_map access; .at() throws when absent

Use the count()/operator[] pattern before dereference.

Apply:

```diff
-  auto shared_layout = T.layout_map.at(shared_tensor);
-  if (!shared_layout.defined()) {
+  Layout shared_layout;
+  if (T.layout_map.count(shared_tensor)) {
+    shared_layout = T.layout_map[shared_tensor];
+  }
+  if (!shared_layout.defined()) {
```
♻️ Duplicate comments (5)
src/transform/layout_inference.cc (1)
333-361: OOB check misses lower bound; add min >= 0 proofs for src/dst

Current logic only proves min + extent <= shape. Also prove min >= 0 to avoid enabling 1D TMA on negative starts.

Apply:

```diff
-  for (size_t i = 0; i < src_range.size(); i++) {
-    if (!analyzer_.CanProve(src_range[i]->min + src_range[i]->extent <=
-                                src_tensor->shape[i],
-                            arith::ProofStrength::kSymbolicBound)) {
-      src_oob = true;
-      break;
-    }
-  }
-  for (size_t i = 0; i < dst_range.size(); i++) {
-    if (!analyzer_.CanProve(dst_range[i]->min + dst_range[i]->extent <=
-                                dst_tensor->shape[i],
-                            arith::ProofStrength::kSymbolicBound)) {
-      dst_oob = true;
-      break;
-    }
-  }
+  for (size_t i = 0; i < src_range.size(); i++) {
+    bool upper_ok = analyzer_.CanProve(
+        src_range[i]->min + src_range[i]->extent <= src_tensor->shape[i],
+        arith::ProofStrength::kSymbolicBound);
+    bool lower_ok = analyzer_.CanProve(
+        src_range[i]->min >= 0, arith::ProofStrength::kSymbolicBound);
+    if (!(upper_ok && lower_ok)) {
+      src_oob = true;
+      break;
+    }
+  }
+  for (size_t i = 0; i < dst_range.size(); i++) {
+    bool upper_ok = analyzer_.CanProve(
+        dst_range[i]->min + dst_range[i]->extent <= dst_tensor->shape[i],
+        arith::ProofStrength::kSymbolicBound);
+    bool lower_ok = analyzer_.CanProve(
+        dst_range[i]->min >= 0, arith::ProofStrength::kSymbolicBound);
+    if (!(upper_ok && lower_ok)) {
+      dst_oob = true;
+      break;
+    }
+  }
```

src/op/copy.cc (4)
621-647: Remove default for buffer_oob in GetCopyInst definition

The header requires explicit buffer_oob; keep the definition consistent to avoid unsafe defaults and mismatched redeclarations.

Apply:

```diff
-CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower,
-                               const LayoutMap &layout_map,
-                               arith::Analyzer *analyzer,
-                               bool buffer_oob = false) const {
+CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower,
+                               const LayoutMap &layout_map,
+                               arith::Analyzer *analyzer,
+                               bool buffer_oob) const {
```
661-674: Lower(): pass buffer_oob into GetCopyInst

Without this, 1D TMA may be selected despite OOB.

Apply:

```diff
-  auto copy_inst =
-      GetCopyInst(target, disable_tma_lower, T.layout_map, analyzer);
+  auto copy_inst =
+      GetCopyInst(target, disable_tma_lower, T.layout_map, analyzer, T.buffer_oob);
```
473-519: Fix shared tensor lookup and add shared contiguity scan
- Uses dst instead of shared_tensor when checking layout_map.
- Lacks symmetric contiguity check on the shared side.
Apply:
```diff
-  // Step 1: check shared is contiguous
-  bool shared_is_contiguous = true;
-  if (layout_map.count(dst)) {
-    shared_is_contiguous = false;
-  }
+  // Step 1: check shared is contiguous (no remapped/swizzled layout)
+  bool shared_is_contiguous = true;
+  if (layout_map.count(shared_tensor)) {
+    shared_is_contiguous = false;
+  }
@@
   // Step 2: check global is contiguous
   bool global_is_contiguous = true;
   bool global_not_full_dim_encounter = false;
@@
   }
+
+  // Step 2b: check shared is contiguous by ranges (same rule as global)
+  bool shared_not_full_dim_encounter = false;
+  for (int i = shared_range.size() - 1; i >= 0; i--) {
+    if (!shared_not_full_dim_encounter) {
+      if (!analyzer->CanProve(shared_range[i]->extent == shared_tensor->shape[i] &&
+                                  shared_range[i]->min == 0,
+                              arith::ProofStrength::kSymbolicBound)) {
+        shared_not_full_dim_encounter = true;
+      }
+    } else {
+      if (!analyzer->CanProve(shared_range[i]->extent == 1,
+                              arith::ProofStrength::kSymbolicBound)) {
+        shared_is_contiguous = false;
+        break;
+      }
+    }
+  }
```
1247-1320: LowerBulkCopy1D: honor buffer_remap and no-op on zero elements
- Use T.buffer_remap for both tensors (aligns with LDSM path).
- Early return if elements == 0.
Apply:
```diff
   bool is_load = copy_inst == CopyInst::kBulkLoad1D;
@@
-  auto shared_tensor = is_load ? dst : src;
-  auto global_tensor = is_load ? src : dst;
+  auto shared_tensor = is_load ? dst : src;
+  auto global_tensor = is_load ? src : dst;
+  if (T.buffer_remap.count(shared_tensor)) shared_tensor = T.buffer_remap[shared_tensor];
+  if (T.buffer_remap.count(global_tensor)) global_tensor = T.buffer_remap[global_tensor];
@@
-  PrimExpr elements = analyzer->Simplify(shared_elements);
+  PrimExpr elements = analyzer->Simplify(shared_elements);
+  if (analyzer->CanProveEqual(elements, 0)) {
+    return IfThenElse(EQ(T.thread_var, T.thread_bounds->min), Evaluate(0));
+  }
```
🧹 Nitpick comments (1)
src/op/copy.cc (1)
198-206: Fix doc typo

"Create s" → "Create indices".

Apply:

```diff
- * \brief Create s for the copy operation.
+ * \brief Create indices for the copy operation.
```
📜 Review details
📒 Files selected for processing (5)
- src/op/atomic_add.cc (1 hunks)
- src/op/copy.cc (12 hunks)
- src/op/copy.h (5 hunks)
- src/op/fill.cc (3 hunks)
- src/transform/layout_inference.cc (10 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- src/op/atomic_add.cc
- src/op/fill.cc
🧰 Additional context used
🧬 Code graph analysis (3)
src/op/copy.h (1)
src/op/copy.cc (16)
InferLayout (398-433), InferLayout (398-399), InferLayout (1611-1614), InferLayout (1611-1612), CheckBulkLoad1D (520-530), CheckBulkLoad1D (520-521), CheckBulkStore1D (532-542), CheckBulkStore1D (532-533), CheckBulkCopy1D (473-518), CheckBulkCopy1D (473-478), GetCopyInst (621-647), GetCopyInst (621-624), LowerBulkCopy (976-1245), LowerBulkCopy (976-977), LowerBulkCopy1D (1247-1320), LowerBulkCopy1D (1247-1248)
src/transform/layout_inference.cc (1)
tilelang/language/copy.py (1)
copy(84-152)
src/op/copy.cc (1)
src/tl_templates/cuda/copy_sm90.h (12)
tma_load (18-27), tma_load (43-60), tma_load (64-82), tma_load (86-104), tma_load (107-126), tma_load (130-150), tma_store (175-182), tma_store (185-195), tma_store (198-209), tma_store (212-223), tma_store (226-238), tma_store (241-253)
🔇 Additional comments (14)
src/transform/layout_inference.cc (6)
17-20: Include dependency looks correct

Including copy.h here is appropriate for CopyNode checks.

69-70: Constructor inheritance: OK

Inheriting IRVisitorWithAnalyzer constructors is fine.

209-212: Size consistency check: good

Ensuring buffer_oob_vec_ aligns with infer_list_ is necessary and correct.

410-410: Parallel For default OOB flag = false: OK

Initializing per-op OOB to false for kParallel nodes makes sense.

603-604: Constructor inheritance: OK

IRMutatorWithAnalyzer constructor inheritance is fine.

107-112: LayoutInferArgs initializer order is correct

The brace-init values (target_, thread_bounds, layout_map, &analyzer_, buffer_oob) exactly match the first five fields of LayoutInferArgs (in declaration order), so there's no misbinding risk.

src/op/copy.h (6)
18-26: New 1D TMA enums: OK

Separating 1D from multi-dim variants is a good call.

144-146: Override annotation: OK

Marking InferLayout as override is correct.

157-177: 1D viability checks: OK (ensure consistent impl in .cc)

Interfaces look good; see the .cc comments for impl fixes.

191-194: Signature requires explicit buffer_oob: good

The header removes any default. Make sure the .cc definition also has no default (see the comment there).

203-207: LowerBulkCopy1D hook: OK

API surface looks good; see the .cc for remap/zero-length refinements.

350-352: Conv2DIm2Col override: OK

Signature matches usage.

src/op/copy.cc (2)

398-408: InferLayout now gates by buffer_oob: OK

Passing analyzer and buffer_oob into GetCopyInst is correct.

726-729: LayoutInferArgs initializer arguments match struct field order

The call's arguments (analyzer, false, T.buffer_remap) align with the struct's fields (arith::Analyzer* analyzer, bool buffer_oob, Map<Buffer, Buffer> buffer_remap).
- Updated the condition in `CheckBulkCopy1D` to verify contiguity of `shared_tensor` instead of `dst`, ensuring correct handling of shared memory layouts during bulk copy operations.
- This change enhances the accuracy of memory operations in the TileLang framework.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/op/copy.cc (1)
621-647: Make buffer_oob explicit in GetCopyInst (remove default).

Defaulting to false is unsafe; callers will silently enable 1D TMA under OOB.

```diff
-CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower,
-                               const LayoutMap &layout_map,
-                               arith::Analyzer *analyzer,
-                               bool buffer_oob = false) const {
+CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower,
+                               const LayoutMap &layout_map,
+                               arith::Analyzer *analyzer,
+                               bool buffer_oob) const {
```

Also remove the default in the header declaration and update all call sites to pass an explicit boolean.
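For reference, the matching header-side change would look roughly like the sketch below. The parameter list is taken from the diff above; the surrounding class definition in src/op/copy.h is omitted, so treat this as an illustration rather than the exact declaration.

```cpp
// Sketch of the copy.h declaration without the default argument; the
// enclosing CopyNode class body is omitted here.
CopyInst GetCopyInst(Target target, bool disable_tma_lower,
                     const LayoutMap &layout_map, arith::Analyzer *analyzer,
                     bool buffer_oob) const;  // no "= false" default
```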
♻️ Duplicate comments (2)
src/op/copy.cc (2)
473-519: Complete shared-side contiguity validation (and nice fix on the layout_map var).

You correctly switched to layout_map.count(shared_tensor). However, contiguity is only proven for the global side; 1D TMA can still be mis-selected if shared_range isn't contiguous-by-ranges (even when not swizzled). Mirror the global check for shared_range.
Apply this diff within CheckBulkCopy1D:
```diff
   // Step 1: check shared is contiguous
   bool shared_is_contiguous = true;
   if (layout_map.count(shared_tensor)) {
     shared_is_contiguous = false;
   }
   // Step 2: check global is contiguous
   bool global_is_contiguous = true;
   bool global_not_full_dim_encounter = false;
   for (int i = global_range.size() - 1; i >= 0; i--) {
     if (!global_not_full_dim_encounter) {
       if (!analyzer->CanProve(global_range[i]->extent == global_tensor->shape[i] &&
                                   global_range[i]->min == 0,
                               arith::ProofStrength::kSymbolicBound)) {
         global_not_full_dim_encounter = true;
       }
     } else {
       if (!analyzer->CanProve(global_range[i]->extent == 1,
                               arith::ProofStrength::kSymbolicBound)) {
         global_is_contiguous = false;
         break;
       }
     }
   }
+
+  // Step 2b: check shared is contiguous by ranges (symmetric rule)
+  bool shared_not_full_dim_encounter = false;
+  for (int i = shared_range.size() - 1; i >= 0; i--) {
+    if (!shared_not_full_dim_encounter) {
+      if (!analyzer->CanProve(shared_range[i]->extent ==
+                                  shared_tensor->shape[i] &&
+                                  shared_range[i]->min == 0,
+                              arith::ProofStrength::kSymbolicBound)) {
+        shared_not_full_dim_encounter = true;
+      }
+    } else {
+      if (!analyzer->CanProve(shared_range[i]->extent == 1,
+                              arith::ProofStrength::kSymbolicBound)) {
+        shared_is_contiguous = false;
+        break;
+      }
+    }
+  }
```
667-676: Include `buffer_oob` in the GetCopyInst call in Lower()

In src/op/copy.cc:667 the call to `GetCopyInst` omits `T.buffer_oob`, so OOB checks aren't considered and 1D TMA may be mis-picked. Change to:

```diff
-  auto copy_inst = GetCopyInst(target, disable_tma_lower, T.layout_map, analyzer);
+  auto copy_inst = GetCopyInst(target, disable_tma_lower, T.layout_map, analyzer, T.buffer_oob);
```
🧹 Nitpick comments (3)
src/op/copy.cc (3)
199-206: Fix typo in MakeIndices doc.

"Create s" -> "Create indices".

```diff
- * \brief Create s for the copy operation.
+ * \brief Create indices for the copy operation.
```

726-729: Propagate real buffer_oob into parallel layout inference.

Minor, but keeps the pipeline state consistent.

```diff
-  par_op->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer,
-                       false, T.buffer_remap},
+  par_op->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer,
+                       T.buffer_oob, T.buffer_remap},
                       level);
```

1116-1127: Avoid layout_map.at() without existence check.

Map::at may throw; use count() then operator[] to allow "undefined" layouts.

```diff
-  auto shared_layout = T.layout_map.at(shared_tensor);
+  Layout shared_layout;
+  if (T.layout_map.count(shared_tensor)) {
+    shared_layout = T.layout_map[shared_tensor];
+  }
```
📜 Review details
📒 Files selected for processing (1)
- src/op/copy.cc (12 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/op/copy.cc (2)
src/tl_templates/cuda/copy_sm90.h (12)
tma_load (18-27), tma_load (43-60), tma_load (64-82), tma_load (86-104), tma_load (107-126), tma_load (130-150), tma_store (175-182), tma_store (185-195), tma_store (198-209), tma_store (212-223), tma_store (226-238), tma_store (241-253)

tilelang/language/builtin.py (1)

tma_load (67-76)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: format-check
- GitHub Check: bot-task
🔇 Additional comments (3)
src/op/copy.cc (3)
406-408: Good: GetCopyInst now uses analyzer and buffer_oob during layout inference.

This keeps 1D TMA gating consistent between inference and lowering.

520-543: LGTM: 1D load/store viability wrappers.

The scaffolding reuses base checks and correctly threads layout_map/analyzer.

996-1005: LGTM: explicit shared/global offsets for TMA addressing.

The stride/offset derivation is correct and simplifies subsequent address math.
Also applies to: 1019-1025
```cpp
Stmt CopyNode::LowerBulkCopy1D(const LowerArgs &T, arith::Analyzer *analyzer,
                               CopyInst copy_inst) const {
  ICHECK(copy_inst == CopyInst::kBulkLoad1D ||
         copy_inst == CopyInst::kBulkStore1D);

  // Add 1D TMA copy when the global and shared memory is contiguous
  // Check if shared_tensor->name is present in T.buffer_var_gemm
  // (Array<PrimExpr>) to avoid use 1D TMA copy for swizzled layout
  bool is_load = copy_inst == CopyInst::kBulkLoad1D;
  auto shared_range = is_load ? dst_range : src_range;
  auto global_range = is_load ? src_range : dst_range;
  auto shared_tensor = is_load ? dst : src;
  auto global_tensor = is_load ? src : dst;

  PrimExpr shared_elements = 1;
  for (size_t i = 0; i < shared_range.size(); i++) {
    shared_elements *= shared_range[i]->extent;
  }

  std::vector<PrimExpr> shared_strides;
  PrimExpr shared_stride = 1;
  for (size_t i = 0; i < shared_tensor->shape.size(); i++) {
    auto s = shared_tensor->shape[shared_tensor->shape.size() - i - 1];
    shared_strides.insert(shared_strides.begin(), shared_stride);
    shared_stride *= s;
  }

  Array<PrimExpr> shared_indices;
  for (auto r : shared_range)
    shared_indices.push_back(r->min);

  Array<PrimExpr> global_indices;
  for (auto r : global_range) {
    global_indices.push_back(r->min);
  }
  std::vector<PrimExpr> global_strides;
  PrimExpr global_stride = 1;
  for (size_t i = 0; i < global_tensor->shape.size(); i++) {
    auto s = global_tensor->shape[global_tensor->shape.size() - i - 1];
    global_strides.insert(global_strides.begin(), global_stride);
    global_stride *= s;
  }

  PrimExpr global_offset = 0;
  for (size_t i = 0; i < global_indices.size(); i++) {
    global_offset += global_indices[i] * global_strides[i];
  }

  PrimExpr shared_offset = 0;
  for (size_t i = 0; i < shared_indices.size(); i++) {
    shared_offset += shared_indices[i] * shared_strides[i];
  }

  PrimExpr elements = analyzer->Simplify(shared_elements);
  PrimExpr shared_addr = shared_tensor.access_ptr(
      is_load ? 2 : 1, DataType::Handle(), 1, shared_offset, elements);
  PrimExpr global_addr = global_tensor.access_ptr(
      is_load ? 1 : 2, DataType::Handle(), 1, global_offset, elements);
  Stmt tma_copy;
  if (is_load) {
    // the zero is a placeholder for mbarrier ids
    tma_copy = Evaluate(
        Call(DataType::Handle(), tma_load(),
             {shared_addr, global_addr, 0,
              elements * shared_tensor->dtype.bytes(), this->eviction_policy}));
  } else {
    tma_copy = Evaluate(
        Call(DataType::Handle(), tma_store(),
             {global_addr, shared_addr, elements * shared_tensor->dtype.bytes(),
              this->eviction_policy}));
  }
  tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy);
  return tma_copy;
}
```
Fix 1D TMA lowering: wrong intrinsic signature, missing descriptor, no remap, no zero-length guard.
Current 1D path passes (shared_addr, global_addr, mbar, bytes, eviction) to tma_load/tma_store, but intrinsics expect (descriptor, [mbar], smem_ptr, crd0, cache_hint). Also ignores T.buffer_remap and doesn’t early-out for zero elements.
Apply this rewrite (minimal, descriptor-based, remap-aware, 0-eval guard):
```diff
 Stmt CopyNode::LowerBulkCopy1D(const LowerArgs &T, arith::Analyzer *analyzer,
                                CopyInst copy_inst) const {
   ICHECK(copy_inst == CopyInst::kBulkLoad1D ||
          copy_inst == CopyInst::kBulkStore1D);
-  // Add 1D TMA copy when the global and shared memory is contiguous
-  // Check if shared_tensor->name is present in T.buffer_var_gemm
-  // (Array<PrimExpr>) to avoid use 1D TMA copy for swizzled layout
   bool is_load = copy_inst == CopyInst::kBulkLoad1D;
-  auto shared_range = is_load ? dst_range : src_range;
-  auto global_range = is_load ? src_range : dst_range;
-  auto shared_tensor = is_load ? dst : src;
-  auto global_tensor = is_load ? src : dst;
+  auto shared_range = is_load ? dst_range : src_range;
+  auto global_range = is_load ? src_range : dst_range;
+  Buffer shared_tensor = is_load ? dst : src;
+  Buffer global_tensor = is_load ? src : dst;
+  if (T.buffer_remap.count(shared_tensor)) shared_tensor = T.buffer_remap[shared_tensor];
+  if (T.buffer_remap.count(global_tensor)) global_tensor = T.buffer_remap[global_tensor];
-  PrimExpr shared_elements = 1;
+  PrimExpr shared_elements = 1;
   for (size_t i = 0; i < shared_range.size(); i++) {
     shared_elements *= shared_range[i]->extent;
   }
-  std::vector<PrimExpr> shared_strides;
-  PrimExpr shared_stride = 1;
+  std::vector<PrimExpr> shared_strides;
+  PrimExpr shared_stride = 1;
   for (size_t i = 0; i < shared_tensor->shape.size(); i++) {
     auto s = shared_tensor->shape[shared_tensor->shape.size() - i - 1];
     shared_strides.insert(shared_strides.begin(), shared_stride);
     shared_stride *= s;
   }
-  Array<PrimExpr> shared_indices;
+  Array<PrimExpr> shared_indices;
   for (auto r : shared_range)
     shared_indices.push_back(r->min);
-  Array<PrimExpr> global_indices;
+  Array<PrimExpr> global_indices;
   for (auto r : global_range) {
     global_indices.push_back(r->min);
   }
-  std::vector<PrimExpr> global_strides;
-  PrimExpr global_stride = 1;
+  std::vector<PrimExpr> global_strides;
+  PrimExpr global_stride = 1;
   for (size_t i = 0; i < global_tensor->shape.size(); i++) {
     auto s = global_tensor->shape[global_tensor->shape.size() - i - 1];
     global_strides.insert(global_strides.begin(), global_stride);
     global_stride *= s;
   }
-  PrimExpr global_offset = 0;
+  PrimExpr global_offset = 0;
   for (size_t i = 0; i < global_indices.size(); i++) {
     global_offset += global_indices[i] * global_strides[i];
   }
-  PrimExpr shared_offset = 0;
+  PrimExpr shared_offset = 0;
   for (size_t i = 0; i < shared_indices.size(); i++) {
     shared_offset += shared_indices[i] * shared_strides[i];
   }
-  PrimExpr elements = analyzer->Simplify(shared_elements);
-  PrimExpr shared_addr = shared_tensor.access_ptr(
-      is_load ? 2 : 1, DataType::Handle(), 1, shared_offset, elements);
-  PrimExpr global_addr = global_tensor.access_ptr(
-      is_load ? 1 : 2, DataType::Handle(), 1, global_offset, elements);
-  Stmt tma_copy;
-  if (is_load) {
-    // the zero is a placeholder for mbarrier ids
-    tma_copy = Evaluate(
-        Call(DataType::Handle(), tma_load(),
-             {shared_addr, global_addr, 0,
-              elements * shared_tensor->dtype.bytes(), this->eviction_policy}));
-  } else {
-    tma_copy = Evaluate(
-        Call(DataType::Handle(), tma_store(),
-             {global_addr, shared_addr, elements * shared_tensor->dtype.bytes(),
-              this->eviction_policy}));
-  }
-  tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy);
-  return tma_copy;
+  PrimExpr elements = analyzer->Simplify(shared_elements);
+  if (analyzer->CanProveEqual(elements, 0)) {
+    return IfThenElse(EQ(T.thread_var, T.thread_bounds->min), Evaluate(0));
+  }
+
+  // Build a minimal 1D TMA descriptor over the flattened global tensor.
+  TMADesc desc;
+  desc.rank = 1;
+  desc.data_type = to_CUtensorMapDataType(global_tensor->dtype);
+  desc.global_addr = global_tensor->data;
+  PrimExpr total_global_elems = 1;
+  for (auto s : global_tensor->shape) total_global_elems *= s;
+  desc.global_shape = { total_global_elems };
+  desc.global_stride = { cast(DataType::Int(64), global_tensor->dtype.bytes()) };
+  desc.smem_box = { elements };
+  desc.smem_stride = { PrimExpr(1) };
+  desc.interleave = static_cast<int>(CU_TENSOR_MAP_INTERLEAVE_NONE);
+  desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE);
+  desc.l2_promotion = static_cast<int>(CU_TENSOR_MAP_L2_PROMOTION_L2_128B);
+  desc.oob_fill = static_cast<int>(CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
+
+  Call create_descriptor =
+      Call(DataType::Handle(), create_tma_descriptor(), desc.EncodeCallArgs());
+
+  // Addresses and 1D coordinate
+  PrimExpr shared_addr = shared_tensor.access_ptr(
+      is_load ? 2 : 1, DataType::Handle(), 1, shared_offset, elements);
+  PrimExpr crd0 = cast(DataType::Int(32), global_offset);
+
+  Array<PrimExpr> args;
+  args.push_back(create_descriptor);
+  if (is_load) args.push_back(0);  // mbarrier placeholder
+  args.push_back(shared_addr);
+  args.push_back(crd0);
+  args.push_back(this->eviction_policy);
+  auto op = is_load ? tma_load() : tma_store();
+  Stmt tma_copy = Evaluate(Call(DataType::Handle(), op, args));
+  tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy);
+  return tma_copy;
 }
```
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
Original:
Stmt CopyNode::LowerBulkCopy1D(const LowerArgs &T, arith::Analyzer *analyzer,
                               CopyInst copy_inst) const {
  ICHECK(copy_inst == CopyInst::kBulkLoad1D ||
         copy_inst == CopyInst::kBulkStore1D);
  // Add 1D TMA copy when the global and shared memory is contiguous
  // Check if shared_tensor->name is present in T.buffer_var_gemm
  // (Array<PrimExpr>) to avoid use 1D TMA copy for swizzled layout
  bool is_load = copy_inst == CopyInst::kBulkLoad1D;
  auto shared_range = is_load ? dst_range : src_range;
  auto global_range = is_load ? src_range : dst_range;
  auto shared_tensor = is_load ? dst : src;
  auto global_tensor = is_load ? src : dst;
  PrimExpr shared_elements = 1;
  for (size_t i = 0; i < shared_range.size(); i++) {
    shared_elements *= shared_range[i]->extent;
  }
  std::vector<PrimExpr> shared_strides;
  PrimExpr shared_stride = 1;
  for (size_t i = 0; i < shared_tensor->shape.size(); i++) {
    auto s = shared_tensor->shape[shared_tensor->shape.size() - i - 1];
    shared_strides.insert(shared_strides.begin(), shared_stride);
    shared_stride *= s;
  }
  Array<PrimExpr> shared_indices;
  for (auto r : shared_range)
    shared_indices.push_back(r->min);
  Array<PrimExpr> global_indices;
  for (auto r : global_range) {
    global_indices.push_back(r->min);
  }
  std::vector<PrimExpr> global_strides;
  PrimExpr global_stride = 1;
  for (size_t i = 0; i < global_tensor->shape.size(); i++) {
    auto s = global_tensor->shape[global_tensor->shape.size() - i - 1];
    global_strides.insert(global_strides.begin(), global_stride);
    global_stride *= s;
  }
  PrimExpr global_offset = 0;
  for (size_t i = 0; i < global_indices.size(); i++) {
    global_offset += global_indices[i] * global_strides[i];
  }
  PrimExpr shared_offset = 0;
  for (size_t i = 0; i < shared_indices.size(); i++) {
    shared_offset += shared_indices[i] * shared_strides[i];
  }
  PrimExpr elements = analyzer->Simplify(shared_elements);
  PrimExpr shared_addr = shared_tensor.access_ptr(
      is_load ? 2 : 1, DataType::Handle(), 1, shared_offset, elements);
  PrimExpr global_addr = global_tensor.access_ptr(
      is_load ? 1 : 2, DataType::Handle(), 1, global_offset, elements);
  Stmt tma_copy;
  if (is_load) {
    // the zero is a placeholder for mbarrier ids
    tma_copy = Evaluate(
        Call(DataType::Handle(), tma_load(),
             {shared_addr, global_addr, 0,
              elements * shared_tensor->dtype.bytes(), this->eviction_policy}));
  } else {
    tma_copy = Evaluate(
        Call(DataType::Handle(), tma_store(),
             {global_addr, shared_addr, elements * shared_tensor->dtype.bytes(),
              this->eviction_policy}));
  }
  tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy);
  return tma_copy;
}

Suggested replacement:
Stmt CopyNode::LowerBulkCopy1D(const LowerArgs &T, arith::Analyzer *analyzer,
                               CopyInst copy_inst) const {
  ICHECK(copy_inst == CopyInst::kBulkLoad1D ||
         copy_inst == CopyInst::kBulkStore1D);
  bool is_load = copy_inst == CopyInst::kBulkLoad1D;
  auto shared_range = is_load ? dst_range : src_range;
  auto global_range = is_load ? src_range : dst_range;
  Buffer shared_tensor = is_load ? dst : src;
  Buffer global_tensor = is_load ? src : dst;
  // Honor any buffer remappings
  if (T.buffer_remap.count(shared_tensor)) {
    shared_tensor = T.buffer_remap[shared_tensor];
  }
  if (T.buffer_remap.count(global_tensor)) {
    global_tensor = T.buffer_remap[global_tensor];
  }
  // Compute flattened element count
  PrimExpr shared_elements = 1;
  for (size_t i = 0; i < shared_range.size(); i++) {
    shared_elements *= shared_range[i]->extent;
  }
  // Compute shared strides & base offset
  std::vector<PrimExpr> shared_strides;
  PrimExpr shared_stride = 1;
  for (size_t i = 0; i < shared_tensor->shape.size(); i++) {
    auto s = shared_tensor->shape[shared_tensor->shape.size() - i - 1];
    shared_strides.insert(shared_strides.begin(), shared_stride);
    shared_stride *= s;
  }
  Array<PrimExpr> shared_indices;
  for (auto r : shared_range) {
    shared_indices.push_back(r->min);
  }
  PrimExpr shared_offset = 0;
  for (size_t i = 0; i < shared_indices.size(); i++) {
    shared_offset += shared_indices[i] * shared_strides[i];
  }
  // Compute global strides & base offset
  Array<PrimExpr> global_indices;
  for (auto r : global_range) {
    global_indices.push_back(r->min);
  }
  std::vector<PrimExpr> global_strides;
  PrimExpr global_stride = 1;
  for (size_t i = 0; i < global_tensor->shape.size(); i++) {
    auto s = global_tensor->shape[global_tensor->shape.size() - i - 1];
    global_strides.insert(global_strides.begin(), global_stride);
    global_stride *= s;
  }
  PrimExpr global_offset = 0;
  for (size_t i = 0; i < global_indices.size(); i++) {
    global_offset += global_indices[i] * global_strides[i];
  }
  // Simplify and early-exit on zero elements
  PrimExpr elements = analyzer->Simplify(shared_elements);
  if (analyzer->CanProveEqual(elements, 0)) {
    return IfThenElse(EQ(T.thread_var, T.thread_bounds->min),
                      Evaluate(0));
  }
  // Build a minimal 1D TMA descriptor
  TMADesc desc;
  desc.rank = 1;
  desc.data_type = to_CUtensorMapDataType(global_tensor->dtype);
  desc.global_addr = global_tensor->data;
  PrimExpr total_elems = 1;
  for (auto s : global_tensor->shape) {
    total_elems *= s;
  }
  desc.global_shape = { total_elems };
  desc.global_stride = { cast(DataType::Int(64),
                              global_tensor->dtype.bytes()) };
  desc.smem_box = { elements };
  desc.smem_stride = { PrimExpr(1) };
  desc.interleave = static_cast<int>(CU_TENSOR_MAP_INTERLEAVE_NONE);
  desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE);
  desc.l2_promotion = static_cast<int>(CU_TENSOR_MAP_L2_PROMOTION_L2_128B);
  desc.oob_fill = static_cast<int>(CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
  Call create_descriptor =
      Call(DataType::Handle(),
           create_tma_descriptor(),
           desc.EncodeCallArgs());
  // Prepare the final call arguments
  PrimExpr shared_addr = shared_tensor.access_ptr(
      is_load ? 2 : 1, DataType::Handle(), 1, shared_offset, elements);
  PrimExpr crd0 = cast(DataType::Int(32), global_offset);
  Array<PrimExpr> args;
  args.push_back(create_descriptor);
  if (is_load) {
    args.push_back(0); // mbarrier placeholder
  }
  args.push_back(shared_addr);
  args.push_back(crd0);
  args.push_back(this->eviction_policy);
  auto op = is_load ? tma_load() : tma_store();
  Stmt tma_copy = Evaluate(Call(DataType::Handle(), op, args));
  tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy);
  return tma_copy;
}
🤖 Prompt for AI Agents
In src/op/copy.cc around lines 1247 to 1320 the 1D TMA lowering is incorrect: it
calls tma_load/tma_store with (shared_addr, global_addr, mbar, bytes, eviction)
but the intrinsic expects a descriptor as the first arg, optional mbar id, then
smem_ptr, crd0, and cache hint; it also ignores T.buffer_remap and doesn't guard
zero-length copies. Fix by (1) early-out: wrap the TMA call in an IfThenElse
checking Simplified(elements) > 0 so zero-length does nothing; (2) build a TMA
descriptor from the global address/shape/strides/element count (use the proper
helper/intrinsic to create a descriptor from global_addr + elements/strides) and
pass that descriptor as the first argument to tma_load/tma_store; (3) include
the mbar placeholder only where required, then pass smem_ptr (shared_addr) and
the coordinate/crd0 derived from global_offset/indices, and finally pass the
eviction/cache hint; (4) respect T.buffer_remap: if the shared/global tensor is
remapped (check T.buffer_remap for the tensor name) either fall back to the
non-1D path or apply the remap when computing addresses so the descriptor and
pointers reflect the remapped layout. Ensure the final Call uses the corrected
argument order and types and keep the IfThenElse guard around the whole TMA
Evaluate.
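As a quick illustration of the argument ordering the prompt describes — descriptor first, an mbarrier slot only for loads, then the shared-memory pointer, the single rank-1 coordinate, and the cache hint — here is a tiny plain-C++ model. The names are illustrative stand-ins, not the TileLang intrinsics:

#include <iostream>
#include <string>
#include <vector>

// Illustrative stand-in for the real handles; not the TileLang types.
struct TmaArg { std::string what; };

std::vector<TmaArg> build_1d_tma_args(bool is_load) {
  std::vector<TmaArg> args;
  args.push_back({"descriptor"});            // CUtensorMap descriptor comes first
  if (is_load) args.push_back({"mbarrier"}); // only loads carry a barrier slot
  args.push_back({"smem_ptr"});              // shared-memory destination/source
  args.push_back({"crd0"});                  // single int32 coordinate for rank 1
  args.push_back({"cache_hint"});            // eviction policy / L2 hint
  return args;
}

int main() {
  for (bool is_load : {true, false}) {
    std::cout << (is_load ? "load:  " : "store: ");
    for (const auto &a : build_1d_tma_args(is_load)) std::cout << a.what << ' ';
    std::cout << '\n';
  }
  return 0;
}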
- Commented out the call to `tilelang.testing.main()` in `test_example_gdn_compilation.py` and replaced it with a direct call to `test_example_chunk_delta_bwd_compilation()`. This change simplifies the test execution flow and focuses on the specific test case.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/op/copy.cc (1)
1018-1050: Compute shared_offset after remap using layout->Forward/OffsetOf
shared_offset is derived before potential buffer remap and without applying the shared layout transform. This can yield wrong SMEM base for swizzled layouts.
Apply:
- Array<PrimExpr> shared_indices;
- for (auto r : shared_range)
-   shared_indices.push_back(r->min);
- std::vector<PrimExpr> shared_strides;
- PrimExpr shared_stride = 1;
- for (size_t i = 0; i < shared_tensor->shape.size(); i++) {
-   auto s = shared_tensor->shape[shared_tensor->shape.size() - i - 1];
-   shared_strides.insert(shared_strides.begin(), shared_stride);
-   shared_stride *= s;
- }
+ Array<PrimExpr> shared_indices;
+ for (auto r : shared_range) shared_indices.push_back(r->min);
@@
- ICHECK(shared_strides.size() == shared_indices.size())
-     << "shared_strides.size() != shared_indices.size()"
-     << shared_strides.size() << " " << shared_indices.size();
- PrimExpr shared_offset = 0;
- for (size_t i = 0; i < shared_indices.size(); i++) {
-   shared_offset += shared_indices[i] * shared_strides[i];
- }
+ // shared_offset is computed after possible remap below
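The point of the suggested change is the order of operations: transform the logical indices through the layout first, then linearize against the remapped buffer. A self-contained plain-C++ sketch of that ordering follows; a permutation stands in for Layout::Forward and row-major linearization for Buffer::OffsetOf, so this is a simplified model, not the TVM API.

#include <iostream>
#include <vector>

// Stand-in for Layout::Forward: here just a permutation of the indices.
std::vector<long> forward(const std::vector<long>& idx,
                          const std::vector<int>& perm) {
  std::vector<long> out(idx.size());
  for (size_t i = 0; i < perm.size(); ++i) out[i] = idx[perm[i]];
  return out;
}

// Stand-in for Buffer::OffsetOf: row-major linearization over `shape`.
long offset_of(const std::vector<long>& idx, const std::vector<long>& shape) {
  long off = 0;
  for (size_t i = 0; i < idx.size(); ++i) off = off * shape[i] + idx[i];
  return off;
}

int main() {
  std::vector<long> logical_idx = {2, 3};
  std::vector<int> perm = {1, 0};             // remapped buffer swaps the axes
  std::vector<long> remapped_shape = {8, 4};
  // Transform first, then linearize against the remapped shape.
  long base = offset_of(forward(logical_idx, perm), remapped_shape);
  std::cout << "base offset = " << base << "\n";  // 3*4 + 2 = 14
  return 0;
}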
♻️ Duplicate comments (5)
src/op/copy.h (1)
191-194: Make buffer_oob non-defaulted across declaration/definition
Header correctly requires an explicit buffer_oob; the definition in copy.cc still supplies a default. Please remove the default from the definition for consistency and to prevent unsafe callsites.
src/op/copy.cc (4)
646-650: Remove default argument for buffer_oob in definition
Keep defaults only in declarations (and here we want no default).
Apply:
-CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower,
-                               const LayoutMap &layout_map,
-                               arith::Analyzer *analyzer,
-                               bool buffer_oob = false) const {
+CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower,
+                               const LayoutMap &layout_map,
+                               arith::Analyzer *analyzer,
+                               bool buffer_oob) const {
486-531: Strengthen 1D contiguity: add symmetric shared-range check
Currently only global contiguity is validated by ranges; shared contiguity is only inferred from absence in layout_map. Add a shared-range walk mirroring the global logic to avoid false positives.
Apply:
  // Step 1: check shared is contiguous
  bool shared_is_contiguous = true;
  if (layout_map.count(shared_tensor)) {
    shared_is_contiguous = false;
  }
  // Step 2: check global is contiguous
  bool global_is_contiguous = true;
  bool global_not_full_dim_encounter = false;
@@
  }
+ // Step 2b: check shared is contiguous by ranges (same rule as global)
+ bool shared_not_full_dim_encounter = false;
+ for (int i = shared_range.size() - 1; i >= 0; i--) {
+   if (!shared_not_full_dim_encounter) {
+     if (!analyzer->CanProve(shared_range[i]->extent == shared_tensor->shape[i] &&
+                             shared_range[i]->min == 0,
+                             arith::ProofStrength::kSymbolicBound)) {
+       shared_not_full_dim_encounter = true;
+     }
+   } else {
+     if (!analyzer->CanProve(shared_range[i]->extent == 1,
+                             arith::ProofStrength::kSymbolicBound)) {
+       shared_is_contiguous = false;
+       break;
+     }
+   }
+ }
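For intuition, the "full innermost dimensions, extent-1 outer dimensions" rule that both the global and the proposed shared check implement can be modelled with static integers. The real code proves these facts symbolically through the analyzer; this sketch is only a model:

#include <cassert>
#include <vector>

// A sliced region [min, min+extent) per dimension is contiguous in a row-major
// buffer iff, scanning from the innermost dimension outward, every dimension
// before the first partially covered one is fully covered, and every dimension
// after it has extent 1.
bool is_contiguous_slice(const std::vector<long>& shape,
                         const std::vector<long>& mins,
                         const std::vector<long>& extents) {
  bool partial_seen = false;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    if (!partial_seen) {
      if (!(mins[i] == 0 && extents[i] == shape[i])) partial_seen = true;
    } else if (extents[i] != 1) {
      return false;
    }
  }
  return true;
}

int main() {
  assert(is_contiguous_slice({4, 8}, {2, 0}, {1, 8}));   // one full row
  assert(is_contiguous_slice({4, 8}, {0, 0}, {4, 8}));   // whole buffer
  assert(!is_contiguous_slice({4, 8}, {0, 0}, {2, 4}));  // 2x4 tile: strided
  return 0;
}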
692-699: Pass buffer_oob to GetCopyInst in Lower()
Lower() re-selects the copy inst but drops OOB gating, potentially choosing 1D TMA incorrectly.
Apply:
- auto copy_inst =
-     GetCopyInst(target, disable_tma_lower, T.layout_map, analyzer);
+ auto copy_inst =
+     GetCopyInst(target, disable_tma_lower, T.layout_map, analyzer, T.buffer_oob);
1281-1354: Fix 1D TMA lowering: wrong intrinsic shape, missing descriptor/remap/zero-guard
Current path calls tma_load/tma_store with raw pointers/byte counts, ignores descriptor+coords form, ignores buffer_remap, and doesn't guard zero-length copies. Rewrite to the descriptor-based 1D tensor form.
Apply:
Stmt CopyNode::LowerBulkCopy1D(const LowerArgs &T, arith::Analyzer *analyzer, CopyInst copy_inst) const { ICHECK(copy_inst == CopyInst::kBulkLoad1D || copy_inst == CopyInst::kBulkStore1D); - // Add 1D TMA copy when the global and shared memory is contiguous - // Check if shared_tensor->name is present in T.buffer_var_gemm - // (Array<PrimExpr>) to avoid use 1D TMA copy for swizzled layout bool is_load = copy_inst == CopyInst::kBulkLoad1D; - auto shared_range = is_load ? dst_range : src_range; - auto global_range = is_load ? src_range : dst_range; - auto shared_tensor = is_load ? dst : src; - auto global_tensor = is_load ? src : dst; + auto shared_range = is_load ? dst_range : src_range; + auto global_range = is_load ? src_range : dst_range; + Buffer shared_tensor = is_load ? dst : src; + Buffer global_tensor = is_load ? src : dst; + if (T.buffer_remap.count(shared_tensor)) shared_tensor = T.buffer_remap.at(shared_tensor); + if (T.buffer_remap.count(global_tensor)) global_tensor = T.buffer_remap.at(global_tensor); PrimExpr shared_elements = 1; for (size_t i = 0; i < shared_range.size(); i++) { shared_elements *= shared_range[i]->extent; } @@ - PrimExpr elements = analyzer->Simplify(shared_elements); - PrimExpr shared_addr = shared_tensor.access_ptr( - is_load ? 2 : 1, DataType::Handle(), 1, shared_offset, elements); - PrimExpr global_addr = global_tensor.access_ptr( - is_load ? 1 : 2, DataType::Handle(), 1, global_offset, elements); - Stmt tma_copy; - if (is_load) { - // the zero is a placeholder for mbarrier ids - tma_copy = Evaluate( - Call(DataType::Handle(), tma_load(), - {shared_addr, global_addr, 0, - elements * shared_tensor->dtype.bytes(), this->eviction_policy})); - } else { - tma_copy = Evaluate( - Call(DataType::Handle(), tma_store(), - {global_addr, shared_addr, elements * shared_tensor->dtype.bytes(), - this->eviction_policy})); - } - tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy); - return tma_copy; + PrimExpr elements = analyzer->Simplify(shared_elements); + if (analyzer->CanProveEqual(elements, 0)) { + return IfThenElse(EQ(T.thread_var, T.thread_bounds->min), Evaluate(0)); + } + + // Build a minimal 1D TMA descriptor over the flattened global tensor + TMADesc desc; + desc.rank = 1; + desc.data_type = to_CUtensorMapDataType(global_tensor->dtype); + desc.global_addr = global_tensor->data; + PrimExpr total_gmem_elems = 1; + for (auto s : global_tensor->shape) total_gmem_elems *= s; + desc.global_shape = { total_gmem_elems }; + desc.global_stride = { cast(DataType::Int(64), global_tensor->dtype.bytes()) }; + desc.smem_box = { elements }; + desc.smem_stride = { PrimExpr(1) }; + desc.interleave = static_cast<int>(CU_TENSOR_MAP_INTERLEAVE_NONE); + desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE); + desc.l2_promotion = static_cast<int>(CU_TENSOR_MAP_L2_PROMOTION_L2_128B); + desc.oob_fill = static_cast<int>(CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); + + Call create_descriptor = + Call(DataType::Handle(), create_tma_descriptor(), desc.EncodeCallArgs()); + + // Addresses and 1D coordinate + PrimExpr shared_addr = shared_tensor.access_ptr( + is_load ? 2 : 1, DataType::Handle(), 1, shared_offset, elements); + PrimExpr crd0 = cast(DataType::Int(32), global_offset); + + Array<PrimExpr> args; + args.push_back(create_descriptor); + if (is_load) args.push_back(0); // mbarrier placeholder + args.push_back(shared_addr); + args.push_back(crd0); + args.push_back(this->eviction_policy); + auto op = is_load ? 
tma_load() : tma_store(); + Stmt tma_copy = Evaluate(Call(DataType::Handle(), op, args)); + tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy); + return tma_copy; }
🧹 Nitpick comments (3)
src/op/copy.cc (3)
198-206: Fix doc typo: "Create s" → "Create indices"
Minor but visible in docs.
Apply:
- * \brief Create s for the copy operation.
+ * \brief Create indices for the copy operation.
751-754: Propagate buffer_oob during loop layout inference (consistency)
Parallel layout inference likely ignores buffer_oob, but passing the real value avoids surprises if it's used later.
Apply:
- par_op->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer,
-                      false, T.buffer_remap},
+ par_op->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer,
+                      T.buffer_oob, T.buffer_remap},
                      level);
1281-1354: Optional: unit tests for 1D/2D TMA selection
Add tests to assert that:
- 1D TMA is selected only when buffer_oob == false and contiguity holds,
- Lower() and InferLayout agree on the selected CopyInst,
- 1D lowering emits descriptor form (argument count/shape).
I can draft minimal TIR unit-tests that exercise these paths if helpful.
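A minimal model of the invariant such tests would pin down, written as a plain-C++ decision table with asserts. The enum and the predicate flags are illustrative, not the actual CopyInst values or API:

#include <cassert>

enum class CopyKind { Normal, Bulk2D, Bulk1D };

// Illustrative model: 1D TMA only when TMA is enabled, the access is provably
// in bounds, and both sides are contiguous; otherwise fall back to 2D TMA or
// a normal copy.
CopyKind select_copy(bool tma_enabled, bool buffer_oob, bool contiguous_1d,
                     bool bulk_2d_ok) {
  if (!tma_enabled) return CopyKind::Normal;
  if (!buffer_oob && contiguous_1d) return CopyKind::Bulk1D;
  if (bulk_2d_ok) return CopyKind::Bulk2D;
  return CopyKind::Normal;
}

int main() {
  // OOB must disable the 1D path even when the region is contiguous.
  assert(select_copy(true, /*oob=*/true, /*contig=*/true, /*2d=*/true) ==
         CopyKind::Bulk2D);
  // Contiguous, in-bounds, TMA enabled -> 1D path.
  assert(select_copy(true, false, true, true) == CopyKind::Bulk1D);
  // TMA disabled -> always normal.
  assert(select_copy(false, false, true, true) == CopyKind::Normal);
  return 0;
}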
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (2)
src/op/copy.cc (17 hunks)
src/op/copy.h (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/op/copy.cc (2)
src/tl_templates/cuda/copy_sm90.h (12)
tma_load(18-27), tma_load(43-60), tma_load(64-82), tma_load(86-104), tma_load(107-126), tma_load(130-150), tma_store(175-182), tma_store(185-195), tma_store(198-209), tma_store(212-223), tma_store(226-238), tma_store(241-253)
tilelang/language/builtin.py (1)
tma_load(67-76)
src/op/copy.h (1)
src/op/copy.cc (20)
InferLayout(398-433), InferLayout(398-399), InferLayout(1645-1648), InferLayout(1645-1646), CheckBulkLoad(453-484), CheckBulkLoad(453-453), CheckBulkStore(569-599), CheckBulkStore(569-569), CheckBulkLoad1D(533-543), CheckBulkLoad1D(533-534), CheckBulkStore1D(545-555), CheckBulkStore1D(545-546), CheckBulkCopy1D(486-531), CheckBulkCopy1D(486-491), GetCopyInst(646-672), GetCopyInst(646-649), LowerBulkCopy(1001-1279), LowerBulkCopy(1001-1002), LowerBulkCopy1D(1281-1354), LowerBulkCopy1D(1281-1282)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: format-check
🔇 Additional comments (8)
src/op/copy.h (5)
18-26: Enum additions for 1D TMA look good
kBulkLoad1D/kBulkStore1D separation is appropriate.
144-146: InferLayout override signature OK
Plumbs analyzer/layout_map/buffer_oob through inference.
150-156: Bulk load/store check declarations OK
Signatures align with usage in copy.cc.
157-176: 1D bulk-copy check API looks right
Parameters cover the necessary context for contiguity checks.
203-206: LowerBulkCopy1D hook: interface LGTM
Signature is consistent with other lowerers.
src/op/copy.cc (3)
406-408: Good: InferLayout gates 1D TMA on buffer_oob
The selection now respects OOB status during inference.
1257-1270: No change required here after above fixes
shared_addr will use the corrected shared_offset introduced after remap.
To be safe, run a smoke-test that exercises the bulk 2D path with a swizzled shared layout to confirm correct tile placement post-remap.
398-433: Nit: ParOp caching/inference flow looks solid
Nothing to change; noting this to acknowledge the improvement.
if (T.layout_map.count(shared_tensor)) {
  shared_layout = T.layout_map.at(shared_tensor);
  ICHECK(T.buffer_remap.count(shared_tensor))
      << "shared_tensor: " << shared_tensor->name
      << " not found in buffer_remap";
  shared_tensor = T.buffer_remap.at(shared_tensor);
}
🛠️ Refactor suggestion
Derive shared_offset with remap applied
After determining shared_layout and remapping the buffer, transform indices and use OffsetOf to compute the correct base.
Apply:
if (!shared_layout.defined()) {
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE);
} else if (StructuralEqual()(shared_layout, linear_layout)) {
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE);
} else {
@@
}
+
+ // Compute SMEM base offset in the (possibly) remapped buffer
+ Array<PrimExpr> shared_coords = shared_indices;
+ if (shared_layout.defined()) {
+ shared_coords = shared_layout->Forward(shared_indices);
+ }
+  PrimExpr shared_offset = shared_tensor.OffsetOf(shared_coords).back();
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
if (T.layout_map.count(shared_tensor)) {
  shared_layout = T.layout_map.at(shared_tensor);
  ICHECK(T.buffer_remap.count(shared_tensor))
      << "shared_tensor: " << shared_tensor->name
      << " not found in buffer_remap";
  shared_tensor = T.buffer_remap.at(shared_tensor);
}
if (!shared_layout.defined()) {
  desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE);
} else if (StructuralEqual()(shared_layout, linear_layout)) {
  desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE);
} else {
  // existing swizzle-from-shared_layout logic
  desc.swizzle = /* ... */;
}
// Compute SMEM base offset in the (possibly) remapped buffer
Array<PrimExpr> shared_coords = shared_indices;
if (shared_layout.defined()) {
  shared_coords = shared_layout->Forward(shared_indices);
}
PrimExpr shared_offset = shared_tensor.OffsetOf(shared_coords).back();
🤖 Prompt for AI Agents
In src/op/copy.cc around lines 1149 to 1155, after you determine shared_layout
and then remap shared_tensor via T.buffer_remap, you must recompute the
shared_offset using the remapped tensor and its layout: transform the original
indices by the shared_layout (taking into account any axis permutations/strides)
and call OffsetOf on the remapped shared_tensor to obtain the correct base
offset. Update the code so that OffsetOf is invoked with the transformed indices
and the remapped/shared_tensor (not the pre-remap tensor), ensuring
shared_offset reflects the remapped buffer layout.
…tion - Updated `CheckBulkLoad` and `CheckBulkStore` methods in `CopyNode` to include an optional parameter for validating the last dimension during bulk copy operations. - Adjusted related methods `CheckBulkLoad1D` and `CheckBulkStore1D` to pass the new parameter, improving the accuracy of bulk copy checks. - This change enhances the robustness of memory operations in the TileLang framework by ensuring compliance with dimensional requirements.
…ility - Reformatted the parameter lists of `CheckBulkLoad` and `CheckBulkStore` methods in `CopyNode` to enhance code clarity by aligning parameters across multiple lines. - This change improves the maintainability of the code and adheres to coding standards.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/op/copy.cc (1)
650-676: Include buffer_oob argument in GetCopyInst call
At src/op/copy.cc:697, update the 4-arg call GetCopyInst(target, disable_tma_lower, T.layout_map, analyzer) to pass T.buffer_oob as the fifth parameter, since the default was removed.
♻️ Duplicate comments (4)
src/op/copy.cc (4)
488-533: Complete shared-side contiguity and fix potential misuse
- Only global is range-checked for contiguity; shared needs the same tail-ones rule.
- You correctly use shared_tensor for layout_map (fixes the old dst misuse), but please add the shared range check.
Apply:
bool CopyNode::CheckBulkCopy1D(const Buffer &global_tensor,
                               const Buffer &shared_tensor,
                               const Array<Range> &global_range,
                               const Array<Range> &shared_range,
                               const LayoutMap &layout_map,
                               arith::Analyzer *analyzer) const {
  // Step 1: check shared is contiguous
  bool shared_is_contiguous = true;
  if (layout_map.count(shared_tensor)) {
    shared_is_contiguous = false;
  }
  // Step 2: check global is contiguous
  bool global_is_contiguous = true;
  bool global_not_full_dim_encounter = false;
  for (int i = global_range.size() - 1; i >= 0; i--) {
@@
  }
+
+ // Step 2b: check shared is contiguous via ranges (mirror global rule)
+ bool shared_not_full_dim_encounter = false;
+ for (int i = shared_range.size() - 1; i >= 0; i--) {
+   if (!shared_not_full_dim_encounter) {
+     if (!analyzer->CanProve(shared_range[i]->extent == shared_tensor->shape[i] &&
+                             shared_range[i]->min == 0,
+                             arith::ProofStrength::kSymbolicBound)) {
+       shared_not_full_dim_encounter = true;
+     }
+   } else {
+     if (!analyzer->CanProve(shared_range[i]->extent == 1,
+                             arith::ProofStrength::kSymbolicBound)) {
+       shared_is_contiguous = false;
+       break;
+     }
+   }
+ }
@@
- return (shared_is_contiguous && global_is_contiguous && element_match);
+ return (shared_is_contiguous && global_is_contiguous && element_match);
}
650-653: Remove default argument in definition (must match header)
Repeating the default argument in the out-of-line definition does not match the declaration; keep defaults out of the definition — the header already removed it.
Apply:
-CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower,
-                               const LayoutMap &layout_map,
-                               arith::Analyzer *analyzer,
-                               bool buffer_oob = false) const {
+CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower,
+                               const LayoutMap &layout_map,
+                               arith::Analyzer *analyzer,
+                               bool buffer_oob) const {
696-703: Lower() must pass buffer_oob to GetCopyInst
Omitting it re-enables 1D TMA during lowering even when OOB exists.
Apply:
- auto copy_inst =
-     GetCopyInst(target, disable_tma_lower, T.layout_map, analyzer);
+ auto copy_inst =
+     GetCopyInst(target, disable_tma_lower, T.layout_map, analyzer, T.buffer_oob);
1285-1358: Fix 1D TMA lowering: wrong intrinsic calling convention, missing descriptor, no remap, no zero-length guard
Current path passes (shared_addr, global_addr, mbar, bytes, eviction). Intrinsics expect (descriptor, [mbar], smem_ptr, crd0, cache_hint). Also ignores buffer remap and zero elements.
Apply:
Stmt CopyNode::LowerBulkCopy1D(const LowerArgs &T, arith::Analyzer *analyzer, CopyInst copy_inst) const { ICHECK(copy_inst == CopyInst::kBulkLoad1D || copy_inst == CopyInst::kBulkStore1D); - // Add 1D TMA copy when the global and shared memory is contiguous - // Check if shared_tensor->name is present in T.buffer_var_gemm - // (Array<PrimExpr>) to avoid use 1D TMA copy for swizzled layout bool is_load = copy_inst == CopyInst::kBulkLoad1D; - auto shared_range = is_load ? dst_range : src_range; - auto global_range = is_load ? src_range : dst_range; - auto shared_tensor = is_load ? dst : src; - auto global_tensor = is_load ? src : dst; + auto shared_range = is_load ? dst_range : src_range; + auto global_range = is_load ? src_range : dst_range; + Buffer shared_tensor = is_load ? dst : src; + Buffer global_tensor = is_load ? src : dst; + if (T.buffer_remap.count(shared_tensor)) shared_tensor = T.buffer_remap.at(shared_tensor); + if (T.buffer_remap.count(global_tensor)) global_tensor = T.buffer_remap.at(global_tensor); PrimExpr shared_elements = 1; for (size_t i = 0; i < shared_range.size(); i++) { shared_elements *= shared_range[i]->extent; } @@ - PrimExpr global_offset = 0; + PrimExpr global_offset = 0; for (size_t i = 0; i < global_indices.size(); i++) { global_offset += global_indices[i] * global_strides[i]; } @@ - PrimExpr shared_offset = 0; + PrimExpr shared_offset = 0; for (size_t i = 0; i < shared_indices.size(); i++) { shared_offset += shared_indices[i] * shared_strides[i]; } - PrimExpr elements = analyzer->Simplify(shared_elements); - PrimExpr shared_addr = shared_tensor.access_ptr( - is_load ? 2 : 1, DataType::Handle(), 1, shared_offset, elements); - PrimExpr global_addr = global_tensor.access_ptr( - is_load ? 1 : 2, DataType::Handle(), 1, global_offset, elements); - Stmt tma_copy; - if (is_load) { - // the zero is a placeholder for mbarrier ids - tma_copy = Evaluate( - Call(DataType::Handle(), tma_load(), - {shared_addr, global_addr, 0, - elements * shared_tensor->dtype.bytes(), this->eviction_policy})); - } else { - tma_copy = Evaluate( - Call(DataType::Handle(), tma_store(), - {global_addr, shared_addr, elements * shared_tensor->dtype.bytes(), - this->eviction_policy})); - } - tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy); - return tma_copy; + PrimExpr elements = analyzer->Simplify(shared_elements); + // Early out on zero-length copies + if (analyzer->CanProveEqual(elements, 0)) { + return IfThenElse(EQ(T.thread_var, T.thread_bounds->min), Evaluate(0)); + } + + // Build 1D TMA descriptor over flattened global tensor + TMADesc desc; + desc.rank = 1; + desc.data_type = to_CUtensorMapDataType(global_tensor->dtype); + desc.global_addr = global_tensor->data; + PrimExpr total_global_elems = 1; + for (auto s : global_tensor->shape) total_global_elems *= s; + desc.global_shape = { total_global_elems }; + desc.global_stride = { cast(DataType::Int(64), global_tensor->dtype.bytes()) }; + desc.smem_box = { elements }; + desc.smem_stride = { PrimExpr(1) }; + desc.interleave = static_cast<int>(CU_TENSOR_MAP_INTERLEAVE_NONE); + desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE); + desc.l2_promotion = static_cast<int>(CU_TENSOR_MAP_L2_PROMOTION_L2_128B); + desc.oob_fill = static_cast<int>(CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); + Call create_descriptor = + Call(DataType::Handle(), create_tma_descriptor(), desc.EncodeCallArgs()); + + // Addresses and coordinate + PrimExpr shared_addr = shared_tensor.access_ptr( + is_load ? 
2 : 1, DataType::Handle(), 1, shared_offset, elements); + PrimExpr crd0 = cast(DataType::Int(32), global_offset); + + Array<PrimExpr> args; + args.push_back(create_descriptor); + if (is_load) args.push_back(0); // mbarrier placeholder + args.push_back(shared_addr); + args.push_back(crd0); + args.push_back(this->eviction_policy); + auto op = is_load ? tma_load() : tma_store(); + Stmt tma_copy = Evaluate(Call(DataType::Handle(), op, args)); + tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy); + return tma_copy; }
🧹 Nitpick comments (2)
src/op/copy.h (1)
159-170: Clarify 1D checks contract (doc-only)
Document that 1D path requires:
- shared tensor has no layout_map entry (no swizzle/remap),
- both global and shared ranges are contiguous (tail-ones rule),
- dtypes match and no OOB at call site.
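One possible wording for that contract, shown as a Doxygen comment over the existing declaration; the comment text is a suggestion only, while the signature is the one already in copy.h:

/*!
 * \brief Check whether a copy qualifies for the 1D bulk (TMA) path.
 *
 * Requirements (all must hold):
 *  - the shared buffer has no entry in the layout map (no swizzle/remap);
 *  - both the global and the shared access regions are contiguous
 *    (full innermost dimensions, extent-1 outer dimensions);
 *  - source and destination element types match;
 *  - the access is provably in bounds at the call site (no OOB masking).
 */
bool CheckBulkCopy1D(const Buffer &global_tensor, const Buffer &shared_tensor,
                     const Array<Range> &global_range,
                     const Array<Range> &shared_range,
                     const LayoutMap &layout_map,
                     arith::Analyzer *analyzer) const;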
src/op/copy.cc (1)
199-206: Nit: fix comment typo "Create s" → "Create indices".
Apply:
- * \brief Create s for the copy operation.
+ * \brief Create indices for the copy operation.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (2)
src/op/copy.cc (17 hunks)
src/op/copy.h (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/op/copy.h (1)
src/op/copy.cc (20)
InferLayout(398-433), InferLayout(398-399), InferLayout(1649-1652), InferLayout(1649-1650), CheckBulkLoad(453-486), CheckBulkLoad(453-454), CheckBulkStore(571-603), CheckBulkStore(571-572), CheckBulkLoad1D(535-545), CheckBulkLoad1D(535-536), CheckBulkStore1D(547-557), CheckBulkStore1D(547-548), CheckBulkCopy1D(488-533), CheckBulkCopy1D(488-493), GetCopyInst(650-676), GetCopyInst(650-653), LowerBulkCopy(1005-1283), LowerBulkCopy(1005-1006), LowerBulkCopy1D(1285-1358), LowerBulkCopy1D(1285-1286)
src/op/copy.cc (2)
src/tl_templates/cuda/copy_sm90.h (12)
tma_load(18-27), tma_load(43-60), tma_load(64-82), tma_load(86-104), tma_load(107-126), tma_load(130-150), tma_store(175-182), tma_store(185-195), tma_store(198-209), tma_store(212-223), tma_store(226-238), tma_store(241-253)
tilelang/language/builtin.py (1)
tma_load(67-76)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: bot-task
🔇 Additional comments (7)
src/op/copy.h (5)
18-26: Enum extension for 1D TMA looks good
New kBulkLoad1D/kBulkStore1D entries are clear and non-breaking.
144-145: InferLayout override is appropriate
Const-correct override is consistent with usage.
150-157: Bulk load/store signatures OK
Keeping check_last_dim with a default is fine for 2D; 1D path disables it at call sites.
193-196: GetCopyInst signature (no default) is correct
Header removes the default for buffer_oob. Ensure the .cc definition matches (it currently doesn't).
204-209: 1D lowering hook addition is good
Separation keeps 1D path self-contained.
src/op/copy.cc (2)
406-408: Good: InferLayout now gates 1D TMA on buffer_oob
Passing T.buffer_oob into GetCopyInst avoids unsafe 1D selections during inference.
755-758: OK: pass explicit buffer_oob=false to layout inference inside normal path
Intentional and safe since normal copy tolerates predicates and masking.
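For context on why the normal path can tolerate buffer_oob: every element move can carry its own bounds predicate, so an over-approximated region degrades into no-ops instead of faults. A scalar C++ model of that guard follows — illustrative only, the real lowering emits predicated TIR, not this loop:

#include <vector>

// Scalar stand-in for a predicated element-wise copy: indices that fall
// outside either buffer are simply skipped.
void guarded_copy(const std::vector<float>& src, std::vector<float>& dst,
                  long src_base, long dst_base, long extent) {
  for (long i = 0; i < extent; ++i) {
    long s = src_base + i;
    long d = dst_base + i;
    if (s >= 0 && s < static_cast<long>(src.size()) &&
        d >= 0 && d < static_cast<long>(dst.size())) {
      dst[d] = src[s];  // out-of-range tail elements become no-ops
    }
  }
}

int main() {
  std::vector<float> src(10, 1.0f), dst(8, 0.0f);
  guarded_copy(src, dst, /*src_base=*/4, /*dst_base=*/4, /*extent=*/8);  // tail is clipped
  return 0;
}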
  Array<PrimExpr> shared_indices;
  for (auto r : shared_range)
-   indices.push_back(r->min);
- std::vector<PrimExpr> strides;
- PrimExpr stride = 1;
+   shared_indices.push_back(r->min);
+ std::vector<PrimExpr> shared_strides;
+ PrimExpr shared_stride = 1;
  for (size_t i = 0; i < shared_tensor->shape.size(); i++) {
    auto s = shared_tensor->shape[shared_tensor->shape.size() - i - 1];
-   strides.insert(strides.begin(), stride);
-   stride *= s;
+   shared_strides.insert(shared_strides.begin(), shared_stride);
+   shared_stride *= s;
  }
Compute shared_offset after layout remap
shared_offset is derived before applying layout_map/buffer_remap to shared_tensor, leading to wrong base addresses with swizzled layouts.
Apply:
- Array<PrimExpr> shared_indices;
- for (auto r : shared_range)
- shared_indices.push_back(r->min);
+ Array<PrimExpr> shared_indices;
+ for (auto r : shared_range) shared_indices.push_back(r->min);
@@
- PrimExpr shared_offset = 0;
- for (size_t i = 0; i < shared_indices.size(); i++) {
- shared_offset += shared_indices[i] * shared_strides[i];
- }
+ PrimExpr shared_offset = 0; // recomputed after possible remap
@@
Layout shared_layout;
if (T.layout_map.count(shared_tensor)) {
shared_layout = T.layout_map.at(shared_tensor);
ICHECK(T.buffer_remap.count(shared_tensor))
<< "shared_tensor: " << shared_tensor->name
<< " not found in buffer_remap";
shared_tensor = T.buffer_remap.at(shared_tensor);
}
+ // Derive offset in the (possibly) remapped buffer
+ {
+ Array<PrimExpr> shared_coords = shared_indices;
+ if (shared_layout.defined()) {
+ shared_coords = shared_layout->Forward(shared_indices);
+ }
+ shared_offset = shared_tensor.OffsetOf(shared_coords).back();
+ }
@@
- PrimExpr shared_addr = shared_tensor.access_ptr(
- is_load ? 2 : 1, DataType::Handle(), 1,
- shared_offset + total_elements * loop_var, total_elements);
+ PrimExpr shared_addr = shared_tensor.access_ptr(
+ is_load ? 2 : 1, DataType::Handle(), 1,
+ shared_offset + total_elements * loop_var, total_elements);
@@
- PrimExpr shared_addr = shared_tensor.access_ptr(
- is_load ? 2 : 1, DataType::Handle(), 1, shared_offset, total_elements);
+ PrimExpr shared_addr = shared_tensor.access_ptr(
+       is_load ? 2 : 1, DataType::Handle(), 1, shared_offset, total_elements);
Also applies to: 1048-1054, 1152-1159, 1261-1274
🤖 Prompt for AI Agents
In src/op/copy.cc around lines 1025-1034 (and similarly at 1048-1054, 1152-1159,
1261-1274): shared_offset is computed from shared_tensor shape/strides before
applying any layout_map or buffer_remap, so the base address is wrong for
swizzled/remapped layouts; recompute shared_strides and shared_offset after
performing layout/buffer remap on shared_tensor (i.e., use the
remapped/shared_tensor layout to derive strides and then compute shared_offset),
ensuring you apply the same layout_map/buffer_remap transformations used
elsewhere before calculating offsets so base addresses reflect the final memory
layout.
Updated CopyNode to introduce separate handling for 1D bulk load/store operations, including new methods for checking and lowering these operations.
Updated InferLayout and GetCopyInst to accommodate additional parameters for layout maps and analyzers.
Updated AtomicAddNode and FillNode to utilize the updated layout inference logic.
This update improves the efficiency and correctness of memory operations in the TileLang framework.
Summary by CodeRabbit
New Features
Improvements
Refactor
Documentation