-
Couldn't load subscription status.
- Fork 286
[Refactor] Refactor Operator into TileOperator and with tvm reflection
#763
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…out inference methods - Changed base class of several operator classes (AtomicAdd, Copy, Gemm, etc.) from Operator to TileOperator for better alignment with tile operations. - Updated InferLayout and Lower methods to use 'override' specifier for clarity and consistency. - Adjusted header inclusions to replace "op.h" with "operator.h" across multiple files for improved organization. - Added missing layout inference implementations for Fill and Conv2DIm2ColOp. - Removed deprecated op.cc and op.h files to streamline the codebase.
|
Note Other AI code review bot(s) detectedCodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review. WalkthroughRefactors the operator system to a TileOperatorNode/TileOperator object-ref design, adds RegionOp and parsing in Changes
Sequence Diagram(s)%%{init: {"theme":"neutral","themeVariables":{"primaryColor":"#2b7a78","secondaryColor":"#f6f7f9","actorBorder":"#4a5568"}}}%%
sequenceDiagram
autonumber
participant Transform as TransformPass
participant Parser as ParseOperator
participant TileRef as TileOperator
participant Node as TileOperatorNode
participant Region as RegionOp
Transform->>Parser: ParseOperator(call/stmt, vmap)
alt builder found
Parser-->>Transform: TileOperator (wrapper)
Transform->>TileRef: tile_op.InferLayout(args, level)
TileRef->>Node: dispatch to Node::InferLayout(...)
Node-->>TileRef: LayoutMap
Transform->>TileRef: tile_op.Lower(lower_args, analyzer)
TileRef->>Node: dispatch to Node::Lower(...)
Node-->>TileRef: Stmt
else not found
Parser-->>Transform: empty TileOperator
end
note right of Region: RegionOp registered as "region"\nused by parser & layout helpers
Estimated code review effort🎯 5 (Critical) | ⏱️ ~120 minutes Possibly related PRs
Poem
Tip 🔌 Remote MCP (Model Context Protocol) integration is now available!Pro plan users can now connect to remote MCP servers from the Integrations page. Connect with popular remote MCPs such as Notion and Linear to add more context to your reviews and chats. ✨ Finishing Touches
🧪 Generate unit tests
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
SupportNeed help? Create a ticket on our support page for assistance with any issues or questions. CodeRabbit Commands (Invoked using PR/Issue comments)Type Other keywords and placeholders
CodeRabbit Configuration File (
|
|
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Summary of Changes
Hello @LeiWang1999, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request refactors the core Operator class into TileOperator to improve modularity and maintainability within the Tile Library. This change involves renaming the base class, updating all derived operator classes to inherit from the new TileOperator, and enhancing const correctness across the operator hierarchy. Additionally, the RegionOp functionality has been extracted into its own dedicated files, further modularizing the codebase.
Highlights
- Core Class Refactoring: The core
Operatorclass has been renamed toTileOperator, and all derived classes (e.g.,AtomicAdd,Copy,Gemm,Fill,ReduceOp,ParallelOp) have been updated to inherit fromTileOperator. - Improved
constCorrectness: Many methods, particularlyInferLayout, have been markedconst, and associated member variables (par_op_,completed_,loop_layout_,predicate_) have been mademutableto allow state modification withinconstmethods where necessary. - Modularity Enhancement: The
RegionOpclass and its related functionalities have been extracted from the generalopfiles into new, dedicatedregion.handregion.ccfiles, enhancing code organization. - Expanded Operator Implementations: New
InferLayoutimplementations have been added forFillandConv2DIm2ColOp, and a basicLowerimplementation forParallelOp, enhancing the completeness of the operator definitions.
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | /gemini review |
Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary |
Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | /gemini help |
Displays a list of available commands. |
Customization
To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.
You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
-
Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution. ↩
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request is a significant and beneficial refactoring that renames the Operator class to TileOperator and modularizes the related code by splitting op.h/op.cc into separate files for operator and region. The changes are consistent across the codebase, with derived classes updated to use override and InferLayout methods correctly marked as const. My review focuses on a few minor points to improve code consistency and maintainability.
src/op/operator.h
Outdated
| // Lower 接口 | ||
| virtual Stmt Lower(const LowerArgs& T, arith::Analyzer* analyzer) const { | ||
| ICHECK(0) << "Not Implemented Lower method."; | ||
| return Evaluate(0); | ||
| } | ||
|
|
||
| const Buffer &GetBuffer() const { return buffer_; } | ||
| const Array<Range> &GetRanges() const { return ranges_; } | ||
| int GetAccessMask() const { return access_mask_; } | ||
| bool IsFullRegion() const; | ||
| // InferLayout 接口 | ||
| virtual LayoutMap InferLayout(const LayoutInferArgs& T, InferLevel level) const { | ||
| return {}; | ||
| } | ||
|
|
||
| private: | ||
| Buffer buffer_; | ||
| Array<Range> ranges_; | ||
| int access_mask_; | ||
| // Clone 接口 | ||
| virtual std::unique_ptr<TileOperator> Clone() const = 0; | ||
|
|
||
| // 虚析构函数 | ||
| virtual ~TileOperator() = default; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comments for the TileOperator class members are in Chinese. For consistency with the rest of the codebase, which is in English, these should be translated to English.
| // Lower 接口 | |
| virtual Stmt Lower(const LowerArgs& T, arith::Analyzer* analyzer) const { | |
| ICHECK(0) << "Not Implemented Lower method."; | |
| return Evaluate(0); | |
| } | |
| const Buffer &GetBuffer() const { return buffer_; } | |
| const Array<Range> &GetRanges() const { return ranges_; } | |
| int GetAccessMask() const { return access_mask_; } | |
| bool IsFullRegion() const; | |
| // InferLayout 接口 | |
| virtual LayoutMap InferLayout(const LayoutInferArgs& T, InferLevel level) const { | |
| return {}; | |
| } | |
| private: | |
| Buffer buffer_; | |
| Array<Range> ranges_; | |
| int access_mask_; | |
| // Clone 接口 | |
| virtual std::unique_ptr<TileOperator> Clone() const = 0; | |
| // 虚析构函数 | |
| virtual ~TileOperator() = default; | |
| // Lower interface | |
| virtual Stmt Lower(const LowerArgs& T, arith::Analyzer* analyzer) const { | |
| ICHECK(0) << "Not Implemented Lower method."; | |
| return Evaluate(0); | |
| } | |
| // InferLayout interface | |
| virtual LayoutMap InferLayout(const LayoutInferArgs& T, InferLevel level) const { | |
| return {}; | |
| } | |
| // Clone interface | |
| virtual std::unique_ptr<TileOperator> Clone() const = 0; | |
| // Virtual destructor | |
| virtual ~TileOperator() = default; |
src/op/region.h
Outdated
| Var GetVarFromAccessPtr(const PrimExpr &expr); | ||
|
|
||
| std::unique_ptr<TileOperator> ParseOperator(Call call, BufferMap vmap); | ||
| std::unique_ptr<TileOperator> ParseOperator(Stmt stmt, BufferMap vmap); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (7)
src/op/copy.cc (1)
1097-1104: Bug: conv stride parameter shadowed; desc.elem_stride uses the wrong “stride”In Conv2DIm2ColOp::Lower, a local variable named “stride” (PrimExpr accumulator for global strides) shadows the conv stride parameter. As a result,
desc.elem_stride = {1, stride, stride, 1};uses the accumulator, not the intended convolution stride. This silently produces incorrect window stepping.Fix by renaming the local accumulator and explicitly using the member conv stride.
- // Make global stride in bytes - desc.global_stride = desc.global_stride.Map([&](PrimExpr e) { - return cast(DataType::Int(64), e) * src->dtype.bytes(); - }); - desc.elem_stride = {1, stride, stride, 1}; + // Make global stride in bytes + desc.global_stride = desc.global_stride.Map([&](PrimExpr e) { + return cast(DataType::Int(64), e) * src->dtype.bytes(); + }); + // Use the convolution stride parameter, not the local accumulator. + desc.elem_stride = {Integer(1), Integer(this->stride), Integer(this->stride), Integer(1)};Optionally also rename the earlier local variable to avoid future shadowing:
- PrimExpr stride = 1; + PrimExpr stride_acc = 1; desc.global_stride.reserve(desc.rank); for (size_t i = 0; i < desc.rank; i++) { - desc.global_stride.push_back(stride); - stride *= desc.global_shape[i]; + desc.global_stride.push_back(stride_acc); + stride_acc *= desc.global_shape[i]; }src/op/gemm_sp.cc (1)
259-304: A/B layout continuity inconsistency: A path computes ‘continuity’ but doesn’t use itIn the Hopper branch, A’s layout computes
continuitybut passesmat_continuousinstead; B’s path correctly usescontinuity. This likely reduces performance or incorrect layout for transposed cases. Usecontinuityfor A as well.- const int64_t continuity = - trans_A ? 4 * mat_continuous / warp_m : mat_continuous; - results.Set(A, makeGemmABLayoutHopper(mat_stride, mat_continuous, - mat_continuous, A->dtype.bits(), - trans_A ? 1 : 2)); + const int64_t continuity = + trans_A ? 4 * mat_continuous / warp_m : mat_continuous; + results.Set(A, makeGemmABLayoutHopper(mat_stride, mat_continuous, + continuity, A->dtype.bits(), + trans_A ? 1 : 2));src/op/atomic_add.cc (1)
156-171: Fix invalid address_of() operand and guard the atomic call with the destination predicate.address_of expects a BufferLoad. After wrapping dst_value with if_then_else, it is no longer a BufferLoad when dst_predicate is defined, which can miscompile. Also, we must not perform the atomic add when dst is out-of-bounds; guard the Evaluate with dst_predicate instead of predicating the address.
Apply this diff:
- PrimExpr dst_value = BufferLoad(dst, dst_indices); - if (dst_predicate.defined()) - dst_value = if_then_else(dst_predicate, dst_value, make_zero(dst->dtype)); - - Call address_of_value = - tvm::tir::Call(DataType::Handle(), builtin::address_of(), {dst_value}); + // Compute pointer from a plain BufferLoad; do not predicate the address. + auto dst_load = BufferLoad(dst, dst_indices); + Call address_of_value = + tvm::tir::Call(DataType::Handle(), builtin::address_of(), {dst_load}); @@ - Stmt body = tvm::tir::Evaluate(atomicadd_call); + Stmt body = tvm::tir::Evaluate(atomicadd_call); + // Avoid OOB address by guarding the atomic itself. + if (dst_predicate.defined()) { + body = IfThenElse(dst_predicate, body); + }src/transform/layout_inference.cc (2)
557-561: Typo prevents fragment-layout validation from triggering."local.framgent" should be "local.fragment", otherwise the layout presence check is skipped silently.
- for (auto buffer : block->alloc_buffers) { - if (buffer.scope() == "local.framgent") { + for (auto buffer : block->alloc_buffers) { + if (buffer.scope() == "local.fragment") { ICHECK(result_.layout_map.count(buffer)) << "Cannot inference fragment layout for " << buffer; } }
101-104: Explicitly includebuffer_remapin allLayoutInferArgsinitializationsTo guard against future changes to the
LayoutInferArgsaggregate (e.g. field reordering or new fields), every braced‐initializer must specify thebuffer_remapmember—either propagatingT.buffer_remapwhen you’re forwarding an existingLayoutInferArgs, or passing an emptyMap<Buffer, Buffer>()when you’re constructing one from scratch.• In
src/transform/layout_inference.cc(lines 101–104):- auto updates = next->InferLayout( - LayoutInferArgs{target_, thread_bounds, layout_map}, level); + auto updates = next->InferLayout( + LayoutInferArgs{target_, thread_bounds, layout_map, Map<Buffer, Buffer>()}, level);• In
src/op/elem.cc(around lines 92–93 and 110–111), propagate the incomingbuffer_remap:- par_op->InferLayout({T.target, T.thread_bounds, T.layout_map}, InferLevel::kFree); + par_op->InferLayout({T.target, T.thread_bounds, T.layout_map, T.buffer_remap}, InferLevel::kFree);(All other call sites—such as in
atomic_add.ccandcopy.cc—already includebuffer_remapand require no change.)This ensures all aggregate‐init calls remain correct if the
LayoutInferArgsdefinition evolves.src/op/gemm.cc (1)
209-241: Ensure non-WGMMA warp partition uses all warps exactly.In kFullRow/kFullCol branches, m_warp*n_warp may not equal num_warps. This breaks the documented invariant and can mis-size fragments. Pick divisors that multiply to num_warps and respect max per-dim limits.
if (this->policy == GemmWarpPolicy::kFullRow) { - // Try to partition M first - m_warp = num_warps; - n_warp = 1; - - // If M cannot be evenly divided by m_warp*16, try to split remaining warps - // to N - if (this->M % (m_warp * kMPerWarp) != 0) { - // Calculate how many warps we can use for M - int max_m_warps = this->M / kMPerWarp; - m_warp = max_m_warps; - // Use remaining warps for N - n_warp = num_warps / m_warp; - if (n_warp == 0) - n_warp = 1; - } + // Prefer M but keep m*n == num_warps and respect per-dim limits. + int max_m_warps = std::min(this->M / kMPerWarp, num_warps); + for (int m = max_m_warps; m >= 1; --m) { + if (num_warps % m == 0) { + m_warp = m; + n_warp = num_warps / m_warp; + break; + } + } } else if (this->policy == GemmWarpPolicy::kFullCol) { - // Try to partition N first - m_warp = 1; - n_warp = num_warps; - - // If N cannot be evenly divided by n_warp*8, try to split remaining warps - // to M - if (this->N % (n_warp * kNPerWarp) != 0) { - // Calculate how many warps we can use for N - int max_n_warps = this->N / kNPerWarp; - n_warp = max_n_warps; - // Use remaining warps for M - m_warp = num_warps / n_warp; - if (m_warp == 0) - m_warp = 1; - } + // Prefer N but keep m*n == num_warps and respect per-dim limits. + int max_n_warps = std::min(this->N / kNPerWarp, num_warps); + for (int n = max_n_warps; n >= 1; --n) { + if (num_warps % n == 0) { + n_warp = n; + m_warp = num_warps / n_warp; + break; + } + } } else if (this->policy == GemmWarpPolicy::kSquare) { @@ - return {m_warp, n_warp}; + ICHECK(m_warp * n_warp == num_warps) << "m_warp*n_warp must equal num_warps"; + return {m_warp, n_warp};src/op/copy.h (1)
232-239: Initialize eviction_policy to a sane default to avoid UBeviction_policy is a plain int and is only conditionally set from args in the ctor. Default-initialize it to kEvictNormal to prevent uninitialized reads in alternative code paths.
Apply this diff:
- int eviction_policy; // Policy for cache eviction + int eviction_policy = static_cast<int>(EvictionPolicy::kEvictNormal); // Policy for cache eviction
🧹 Nitpick comments (17)
src/op/copy.cc (2)
18-18: Unify local include path style for region.hOther files in this PR include Region headers with a leading "./" (e.g., "./region.h"). For consistency, prefer the same style here.
-#include "region.h" +#include "./region.h"
1047-1052: Style: prefer Downcast(...) over chained Optional handle accessThese lines use
args[i].as<IntImm>().value()->value. Elsewhere in this file you useDowncast<IntImm>(args[i])->value. Consider unifying to the latter for readability and consistency.- kernel = args[4].as<IntImm>().value()->value; + kernel = Downcast<IntImm>(args[4])->value; - stride = args[5].as<IntImm>().value()->value; + stride = Downcast<IntImm>(args[5])->value; - dilation = args[6].as<IntImm>().value()->value; + dilation = Downcast<IntImm>(args[6])->value; - padding = args[7].as<IntImm>().value()->value; + padding = Downcast<IntImm>(args[7])->value; - eviction_policy = args[8].as<IntImm>().value()->value; + eviction_policy = Downcast<IntImm>(args[8])->value;src/op/elem.cc (1)
125-128: No-op Fill::InferLayout hook is fine for nowAcceptable as a placeholder; Fill doesn’t need to seed layouts today. If Fill targets shared/fragment in the future, consider adding a lightweight free-level inference to reduce downstream guesswork.
src/op/atomic_add.cc (1)
189-194: Consider using the standard inference order (Strict → Common → Free).The rest of the pipeline infers in Strict, then Common, then Free. Mirroring that order here will keep behavior consistent and avoids surprising loop_layout_ initialization side-effects.
- std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict, - InferLevel::kFree}; + std::vector<InferLevel> levels = {InferLevel::kStrict, InferLevel::kCommon, + InferLevel::kFree};src/op/gemm.cc (1)
300-367: Optional: compress CheckWGMMA combinatorics into a table-driven helper.The dtype/constraint matrix is verbose and duplication-prone. A compact table or predicate helpers would reduce maintenance overhead.
src/op/operator.cc (3)
7-12: Include tvm/ir/op.h to ensure Op/OpNode are available.
Relying on transitive includes is brittle.Apply this small addition:
#include "operator.h" +#include <tvm/ir/op.h> #include <tvm/tir/builtin.h> #include <tvm/tir/op.h> #include <tvm/tir/op_attr_types.h>
30-36: Minor: tighten Stmt parsing and avoid duplicate as<> calls.
Cleaner checks reduce chances of null deref on unusual Evaluate shapes.-std::unique_ptr<TileOperator> ParseOperator(Stmt stmt, BufferMap vmap) { - if (stmt.as<Evaluate>() && stmt.as<EvaluateNode>()->value.as<CallNode>()) { - auto call = stmt.as<EvaluateNode>()->value.as<CallNode>(); - return ParseOperator(GetRef<Call>(call), vmap); - } - return nullptr; -} +std::unique_ptr<TileOperator> ParseOperator(Stmt stmt, BufferMap vmap) { + if (const auto* eval = stmt.as<EvaluateNode>()) { + if (const auto* call = eval->value.as<CallNode>()) { + return ParseOperator(GetRef<Call>(call), vmap); + } + } + return nullptr; +}
38-45: GetVarFromAccessPtr: OK, but consider a clearer check for arg[1].
Currently ICHECK(var) is fine. Optionally add a message for easier debugging.src/op/gemm_sp.h (1)
48-48: Same note on mutable completed_.
Consider guarding for concurrent/duplicate invocations or documenting single-use guarantees per op instance.src/op/atomic_add.h (2)
27-35: Prefer copy-ctor over Clone()+static_cast for par_op_ deep copyUsing the concrete ParallelOp copy-ctor keeps this simpler and avoids a static_cast across the TileOperator interface.
Apply this diff:
- if (other.par_op_) - par_op_ = std::unique_ptr<ParallelOp>( - static_cast<ParallelOp *>(other.par_op_->Clone().release())); + if (other.par_op_) { + par_op_ = std::make_unique<ParallelOp>(*other.par_op_); + }Also applies to: 57-58
57-57: Be mindful of mutable state accessed from const methodspar_op_ is mutated from const methods (layout inference). That’s fine by design, but if AtomicAdd instances can be shared across threads, this introduces a data race. If multi-threaded access is possible in your pipeline, consider guarding par_op_ or caching at a higher level.
src/op/region.h (2)
1-3: Fix file header path in the doc blockThe file header says op.h; it should reference region.h.
Apply this diff:
-/*! - * \file tl/op/op.h +/*! + * \file tl/op/region.h
44-47: Avoid duplicating free-function declarations already in operator.hGetVarFromAccessPtr and ParseOperator are declared in operator.h. Keeping them here duplicates declarations and can lead to drift.
Apply this diff to remove the duplicates from region.h:
-Var GetVarFromAccessPtr(const PrimExpr &expr); - -std::unique_ptr<TileOperator> ParseOperator(Call call, BufferMap vmap); -std::unique_ptr<TileOperator> ParseOperator(Stmt stmt, BufferMap vmap);If you keep them, ensure both headers remain in sync.
src/op/copy.h (2)
1-9: Fix doc header: references elem.h but this is copy.hMinor doc mismatch.
Apply this diff:
-/*! - * \file tl/op/elem.h +/*! + * \file tl/op/copy.h
229-231: Note on const InferLayout + mutable par_op_The pattern mirrors ParallelOp usage elsewhere and is appropriate. Just ensure any concurrent passes don’t access the same Copy instance from multiple threads without synchronization.
Also applies to: 107-109
src/op/operator.h (1)
29-39: Macro registration aligns with typed builder; keep single source of truthTIR_REGISTER_TL_OP ties Entry::Get() and the builder attr. Make sure only operator.h declares ParseOperator and utility helpers to avoid header duplication (see region.h note).
src/op/parallel.h (1)
55-58: Clone/copy only partially preserves internal state; consider copying caches or documenting reset semanticsThe copy constructor delegates to ParallelOp(other.root_) and copies only loop_layout_ and predicate_. indice_map_, buffer_is_write_, and loop_vars_ are not copied, which may silently drop precomputed analysis if a previously used instance is cloned. That might be intentional (to force lazy recomputation), but it’s worth making explicit or copying the caches.
Apply this diff if you want clones to carry over cached analysis:
ParallelOp(const ParallelOp &other) : ParallelOp(other.root_) { loop_layout_ = other.loop_layout_; predicate_ = other.predicate_; + indice_map_ = other.indice_map_; + buffer_is_write_ = other.buffer_is_write_; + loop_vars_ = other.loop_vars_; }If intentional to reset caches, add a short comment to the copy ctor stating that analysis state is intentionally not copied.
Also applies to: 59-61
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (22)
src/op/atomic_add.cc(2 hunks)src/op/atomic_add.h(3 hunks)src/op/builtin.h(1 hunks)src/op/copy.cc(3 hunks)src/op/copy.h(8 hunks)src/op/elem.cc(1 hunks)src/op/elem.h(1 hunks)src/op/gemm.cc(4 hunks)src/op/gemm.h(2 hunks)src/op/gemm_sp.cc(1 hunks)src/op/gemm_sp.h(2 hunks)src/op/op.cc(0 hunks)src/op/operator.cc(1 hunks)src/op/operator.h(2 hunks)src/op/parallel.cc(3 hunks)src/op/parallel.h(4 hunks)src/op/reduce.cc(2 hunks)src/op/reduce.h(2 hunks)src/op/region.cc(1 hunks)src/op/region.h(1 hunks)src/transform/layout_inference.cc(5 hunks)src/transform/lower_tile_op.cc(1 hunks)
💤 Files with no reviewable changes (1)
- src/op/op.cc
🧰 Additional context used
🧬 Code graph analysis (18)
src/op/region.h (2)
src/op/operator.h (2)
tl(20-89)TileOperator(63-69)src/op/op.cc (1)
tir(14-86)
src/op/elem.cc (1)
src/op/region.cc (1)
InferLayout(49-52)
src/op/gemm_sp.cc (5)
src/op/region.cc (2)
InferLayout(49-52)InferLayout(49-50)src/op/atomic_add.cc (2)
InferLayout(213-233)InferLayout(213-214)src/op/copy.cc (4)
InferLayout(320-353)InferLayout(320-320)InferLayout(1232-1235)InferLayout(1232-1233)src/op/gemm.cc (2)
InferLayout(443-589)InferLayout(443-443)src/op/parallel.cc (2)
InferLayout(186-361)InferLayout(186-187)
src/op/operator.cc (2)
src/transform/layout_inference.cc (14)
op(39-45)op(39-39)op(294-322)op(294-294)op(346-369)op(346-346)op(371-388)op(371-371)op(390-399)op(390-390)op(553-565)op(553-553)expr(324-336)expr(324-324)tilelang/language/tir/op.py (1)
tvm_access_ptr(650-675)
src/op/region.cc (1)
src/op/region.h (1)
RegionOp(21-42)
src/op/reduce.cc (5)
src/op/region.cc (2)
InferLayout(49-52)InferLayout(49-50)src/op/atomic_add.cc (2)
InferLayout(213-233)InferLayout(213-214)src/op/copy.cc (4)
InferLayout(320-353)InferLayout(320-320)InferLayout(1232-1235)InferLayout(1232-1233)src/op/parallel.cc (2)
InferLayout(186-361)InferLayout(186-187)src/op/parallel.h (1)
InferLayout(52-58)
src/op/copy.cc (3)
src/op/region.cc (2)
InferLayout(49-52)InferLayout(49-50)src/op/atomic_add.cc (2)
InferLayout(213-233)InferLayout(213-214)src/op/parallel.h (1)
InferLayout(52-58)
src/op/atomic_add.cc (3)
src/op/region.cc (2)
InferLayout(49-52)InferLayout(49-50)src/op/copy.cc (4)
InferLayout(320-353)InferLayout(320-320)InferLayout(1232-1235)InferLayout(1232-1233)src/op/parallel.h (1)
InferLayout(52-58)
src/op/gemm.cc (3)
src/op/region.cc (2)
InferLayout(49-52)InferLayout(49-50)src/op/gemm_sp.cc (2)
InferLayout(259-311)InferLayout(259-260)src/op/parallel.h (1)
InferLayout(52-58)
src/op/elem.h (2)
src/op/operator.h (2)
tl(20-89)TileOperator(63-69)src/op/elem.cc (5)
Fill(25-69)Lower(88-123)Lower(88-88)InferLayout(125-127)InferLayout(125-125)
src/op/reduce.h (2)
src/op/region.h (3)
tvm(16-50)tl(17-49)- `` (29-31)
src/op/operator.h (2)
tl(20-89)TileOperator(63-69)
src/op/gemm.h (3)
src/op/gemm_sp.h (3)
tl(13-52)GemmWarpPolicy(24-49)- `` (30-32)
src/op/operator.h (2)
tl(20-89)TileOperator(63-69)src/op/gemm.cc (5)
Gemm(36-63)Lower(382-417)Lower(382-382)InferLayout(443-589)InferLayout(443-443)
src/op/atomic_add.h (4)
src/op/parallel.h (3)
InferLayout(52-58)- `` (59-61)
ParallelOp(48-88)src/op/copy.h (2)
- `` (167-169)
- `` (275-277)
src/op/operator.h (1)
TileOperator(63-69)src/op/atomic_add.cc (5)
AtomicAdd(37-53)Lower(183-211)Lower(183-183)InferLayout(213-233)InferLayout(213-214)
src/op/gemm_sp.h (3)
src/op/gemm.h (3)
tl(13-60)GemmWarpPolicy(24-57)- `` (30-32)
src/op/operator.h (2)
tl(20-89)TileOperator(63-69)src/op/gemm_sp.cc (5)
GemmSP(34-55)Lower(215-257)Lower(215-215)InferLayout(259-311)InferLayout(259-260)
src/op/parallel.cc (2)
src/op/region.cc (4)
Lower(45-47)Lower(45-45)InferLayout(49-52)InferLayout(49-50)src/op/atomic_add.cc (4)
Lower(183-211)Lower(183-183)InferLayout(213-233)InferLayout(213-214)
src/op/parallel.h (5)
src/op/operator.h (1)
TileOperator(63-69)src/op/region.cc (4)
Lower(45-47)Lower(45-45)InferLayout(49-52)InferLayout(49-50)src/op/atomic_add.cc (4)
Lower(183-211)Lower(183-183)InferLayout(213-233)InferLayout(213-214)src/op/copy.cc (4)
Lower(475-495)Lower(475-475)InferLayout(320-353)InferLayout(320-320)src/op/elem.cc (4)
Lower(88-123)Lower(88-88)InferLayout(125-127)InferLayout(125-125)
src/op/operator.h (2)
src/op/region.cc (4)
Lower(45-47)Lower(45-45)InferLayout(49-52)InferLayout(49-50)src/op/operator.cc (6)
GetVarFromAccessPtr(38-45)GetVarFromAccessPtr(38-38)ParseOperator(18-28)ParseOperator(18-18)ParseOperator(30-36)ParseOperator(30-30)
src/op/copy.h (4)
src/op/copy.cc (10)
Copy(115-140)Lower(475-495)Lower(475-475)Lower(1064-1186)Lower(1064-1065)InferLayout(320-353)InferLayout(320-320)InferLayout(1232-1235)InferLayout(1232-1233)Conv2DIm2ColOp(1042-1052)src/op/operator.h (1)
TileOperator(63-69)src/op/parallel.cc (5)
Lower(159-161)Lower(159-159)InferLayout(186-361)InferLayout(186-187)ParallelOp(157-157)src/op/parallel.h (2)
- `` (59-61)
ParallelOp(48-88)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: build-test-amd
- GitHub Check: build-test-nvidia
- GitHub Check: bot-task
🔇 Additional comments (48)
src/op/copy.cc (2)
320-353: Const-qualifying Copy::InferLayout is correct; flow matches TMA/LDSM/STSM expectationsThe const refactor is sound. The early-return for TMA linear layout at kFree and the fallback to ParallelOp inference look correct and align with AtomicAdd/Gemm patterns.
1232-1235: Conv2DIm2ColOp::InferLayout is a no-op; confirm the transform pipeline expects thisReturning an empty LayoutMap is fine if downstream layout inference (e.g., Copy/TMA) provides necessary annotations. Please verify no passes rely on im2col’s dst/shared layout being seeded here.
Would you like me to add a minimal shared layout inference (e.g., linear layout when free and unannotated) to reduce fallback paths?
src/op/builtin.h (1)
10-10: Include switch to operator.h aligns with TileOperator migrationThe header swap is correct given the refactor away from Operator/RegionOp. No further changes needed here.
src/transform/lower_tile_op.cc (1)
15-15: Header update to operator.h is appropriateMatches the move to TileOperator-based parsing (ParseOperator). Looks good.
src/op/gemm_sp.cc (1)
259-311: No action needed:completed_is already declaredmutable
The membercompleted_insrc/op/gemm_sp.his defined asmutable bool completed_ = false; // line 48which satisfies the requirement for a
constmethod to modify it.src/op/atomic_add.cc (2)
7-8: Include path normalization looks good.Switching to relative includes and pulling in region.h matches the refactor and avoids header search fragility.
213-214: Const-correct InferLayout override is aligned with TileOperator.Assuming par_op_ is declared mutable in the header, this is good and matches other ops.
src/op/reduce.cc (2)
287-288: Const-qualifying ReduceOp::InferLayout is correct.No functional change; matches the TileOperator interface.
406-408: Const-qualifying CumSumOp::InferLayout is fine.Returning an empty LayoutMap is expected for a no-op inference hook here.
src/transform/layout_inference.cc (4)
18-18: Including region.h is necessary for RegionOp::Get().Good catch; avoids ADL/forward-decl issues in getBufferFromAccessPtr.
257-276: Safe downcast after unique_ptr move; OK.Moving out of infer_list_ into base_infer and then dynamic_cast to ParallelOp is correct. infer_list_ isn’t used afterwards in Run(), so lifetime is fine.
404-404: Switch to std::unique_ptr vector is consistent with the refactor.Clone-based backups below also match the new polymorphic base.
416-423: Clone-based BackupInferList is correct and avoids aliasing.Returning deep clones prevents state leakage across attempts during free-mode search. LGTM.
src/op/gemm.cc (1)
443-443: Mutablecompleted_flag confirmed in Gemm
- In
src/op/gemm.h(line 56),completed_is already declared asmutable bool completed_ = false;, satisfying the need for state updates in theconstmethod.src/op/parallel.cc (3)
159-161: Lower passthrough is fine as a no-op hook.Returning root_ matches the new TileOperator lowering contract.
186-187: Const-correct InferLayout override is aligned with the interface.No functional concerns; the internal analyzer_ handles state.
371-396: Const-qualifying CompleteBufferFragment is correct; logic remains intact.Computation still respects replicate extents and inverse mapping. LGTM.
src/op/elem.h (4)
10-10: Switch to operator.h is correct and aligns with TileOperator migration.
Header include updated to the new TileOperator base; no issues spotted.
18-18: Good: Fill now derives from TileOperator.
This aligns with the refactor away from Operator. No ABI concerns apparent.
21-23: Const-correct overrides added.
Lower and InferLayout are now explicit overrides with const correctness; matches the new base interface.
26-28: Clone return type updated to std::unique_ptr.
Matches the base class and the rest of the refactor. LGTM.src/op/gemm.h (5)
10-10: Include switch to operator.h is consistent.
This keeps Gemm tied to the TileOperator API. No action needed.
17-17: Inheritance updated to TileOperator.
The change is straightforward and consistent with other ops.
20-22: Lower/InferLayout const overrides look good.
API compatibility with the new base is preserved. Ensure implementation keeps side effects out of these methods.
30-32: Clone override updated correctly.
Ownership semantics preserved; no regressions expected.
56-56: Ensure thread-safety forcompleted_guard in GemmOp::InferLayout
The mutable flag// src/op/gemm.h:56 mutable bool completed_ = false;can lead to data races or skipped inferences if
GemmOp::InferLayoutis ever invoked concurrently (e.g. across threads or repeated with different targets/bounds).Our quick scan shows
InferLayoutbeing called in a sequential loop within the layout-inference pass:
- src/transform/layout_inference.cc:102–105 –
// Run InferLayout auto updates = next->InferLayout( LayoutInferArgs{target_, thread_bounds, layout_map}, level); for (const auto& [buffer, layout] : updates) { … }However, we have not located any guarantees that no other code path invokes
InferLayoutin parallel.Recommendations:
- Cache inference results keyed by
(target, thread_bounds, buffers)instead of a singlecompleted_flag.- Or document and enforce that
InferLayoutis only ever called once per op in a single-threaded context.Please verify that there are no parallel or repeated callers of
InferLayoutelsewhere in the codebase.src/op/region.cc (4)
14-18: Op registration looks correct (builder + purity + varargs).
Registration chains cleanly from TIR_REGISTER_TL_OP; no problems.
35-43: IsFullRegion() logic is sound.
Using is_zero(min) and structural equality against buffer shape is appropriate here.
45-47: Lower() as a no-op is acceptable for RegionOp.
Given RegionOp is a descriptor, returning Evaluate(0) is a harmless placeholder.
49-52: InferLayout() no-op is fine for now.
No layout implications for region descriptors at present.src/op/operator.cc (1)
18-28: All Operator API refactor checks passed
No further action required.• No leftover
#include "op.h"entries
• No classes deriving from the oldOperatorbase or uses ofstd::unique_ptr<Operator>
• AllInferLayout(const LayoutInferArgs&, …)methods are already markedconst
• EveryClone()returnsstd::unique_ptr<TileOperator>rather thanOperatorsrc/op/gemm_sp.h (5)
10-10: Include switch to operator.h is correct.
Matches the new base and other files.
17-17: GemmSP now inherits TileOperator — good.
No issues with public API exposure.
20-23: Const-correct overrides for Lower/InferLayout.
Interfaces align with TileOperator; implementation looks consistent in gemm_sp.cc.
30-32: Clone override updated properly.
Ownership semantics OK.
34-37: Private ComputeWarpPartition helper: good encapsulation.
No external exposure; naming is consistent with Gemm.src/op/atomic_add.h (1)
18-23: TileOperator migration + const-correct overrides look goodThe base-class swap, const-qualified Lower/InferLayout, and Clone() return type align with the new TileOperator surface.
Also applies to: 36-38
src/op/region.h (1)
21-27: RegionOp skeleton and const-correct overrides are consistentInterface matches TileOperator with a no-op Lower/InferLayout in region.cc. This keeps RegionOp as a lightweight parsing/descriptor wrapper. LGTM.
Also applies to: 29-37
src/op/reduce.h (3)
17-23: ReduceOp: API refactor to TileOperator + const InferLayout looks goodThe migration and Clone signature are consistent with the rest of the PR.
Also applies to: 25-27
46-56: CumSumOp: consistent refactor with const-correctnessSame remarks as ReduceOp—looks good.
20-23: No downstream subclasses of ReduceOp/CumSumOp found
A code‐wide search for anyclassorstructpublicly inheriting fromReduceOporCumSumOpyielded no matches. There are no existing overrides ofLowerorInferLayoutin downstream code, so removing thefinalspecifier does not introduce any breaking changes.src/op/copy.h (1)
86-109: TileOperator migration, const InferLayout, and Clone return types are consistentThe Copy and Conv2DIm2ColOp changes match the new operator base and const-correctness strategy. Good use of mutable par_op_ for layout inference.
Also applies to: 167-169, 247-266, 275-277
src/op/operator.h (2)
63-81: New TileOperator base is coherent and minimally opinionatedInline “not implemented” Lower, empty InferLayout, and a virtual Clone() provide a clear, uniform surface for all ops.
85-87: All references to the old Operator API have been removedI searched the entire codebase and confirmed:
- No uses of
std::unique_ptr<Operator>- No
ParseOperatoroverloads returning the oldOperatortype- No classes inheriting from
OperatorThe remaining
operator.hincludes are either the newsrc/op/operator.hor external headers (<tvm/ir/op.h>,<tvm/tir/op.h>), and are expected. No further cleanup is needed here.src/op/parallel.h (4)
39-41: Visitor overrides are now const-correct and match StmtExprVisitorThe updated signatures using const ForNode*/BufferStoreNode*/BufferLoadNode* with override look correct and align with tvm::tir::StmtExprVisitor.
48-48: Migration to TileOperator + const-correct hooks looks goodDeriving ParallelOp from TileOperator and providing const Lower/InferLayout overrides matches the new operator API and will enable uniform reflection.
Please confirm all call sites now treat Lower/InferLayout as const (e.g., storing ParallelOp behind const TileOperator*). If you want, I can scan the repo for any remaining mutable usages.
Also applies to: 51-54
69-69: Good: method is const as it shouldn’t mutate stateCompleting a buffer fragment is a pure query; const-qualifying CompleteBufferFragment is the right call.
83-86: Mutable members enable lazy evaluation but can surprise; document thread-safety assumptionsloop_layout_, analyzer_, and predicate_ are mutable so they can be populated from const methods. That’s fine, but it implies ParallelOp instances are not thread-safe for concurrent use. Please document this in the class comment or guard mutations if cross-thread usage is possible.
Would you like me to add a brief comment stating “Not thread-safe; mutated lazily from const methods”?
src/op/operator.cc
Outdated
| std::unique_ptr<TileOperator> ParseOperator(Call call, BufferMap vmap) { | ||
| auto op_map = Op::GetAttrMap<OpBuilderFunc>("TLOpBuilder"); | ||
| Op op = call->op.as<Op>().value(); | ||
| if (op_map.count(op)) { | ||
| TileOperator *ptr = | ||
| static_cast<TileOperator *>(op_map[op](call->args, vmap)); | ||
| ICHECK(ptr != nullptr); | ||
| return std::unique_ptr<TileOperator>(ptr); | ||
| } | ||
| return nullptr; | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unsafe cast to Op; handle non-Op calls gracefully.
call->op.as<Op>().value() will crash if call->op is not an Op (e.g., builtin). Prefer as<OpNode>() + GetRef or early-return.
Apply this diff:
std::unique_ptr<TileOperator> ParseOperator(Call call, BufferMap vmap) {
auto op_map = Op::GetAttrMap<OpBuilderFunc>("TLOpBuilder");
- Op op = call->op.as<Op>().value();
- if (op_map.count(op)) {
+ const auto* op_node = call->op.as<OpNode>();
+ if (!op_node) {
+ return nullptr;
+ }
+ Op op = GetRef<Op>(op_node);
+ if (op_map.count(op)) {
TileOperator *ptr =
static_cast<TileOperator *>(op_map[op](call->args, vmap));
ICHECK(ptr != nullptr);
return std::unique_ptr<TileOperator>(ptr);
}
return nullptr;
}📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| std::unique_ptr<TileOperator> ParseOperator(Call call, BufferMap vmap) { | |
| auto op_map = Op::GetAttrMap<OpBuilderFunc>("TLOpBuilder"); | |
| Op op = call->op.as<Op>().value(); | |
| if (op_map.count(op)) { | |
| TileOperator *ptr = | |
| static_cast<TileOperator *>(op_map[op](call->args, vmap)); | |
| ICHECK(ptr != nullptr); | |
| return std::unique_ptr<TileOperator>(ptr); | |
| } | |
| return nullptr; | |
| } | |
| std::unique_ptr<TileOperator> ParseOperator(Call call, BufferMap vmap) { | |
| auto op_map = Op::GetAttrMap<OpBuilderFunc>("TLOpBuilder"); | |
| const auto* op_node = call->op.as<OpNode>(); | |
| if (!op_node) { | |
| return nullptr; | |
| } | |
| Op op = GetRef<Op>(op_node); | |
| if (op_map.count(op)) { | |
| TileOperator *ptr = | |
| static_cast<TileOperator *>(op_map[op](call->args, vmap)); | |
| ICHECK(ptr != nullptr); | |
| return std::unique_ptr<TileOperator>(ptr); | |
| } | |
| return nullptr; | |
| } |
🤖 Prompt for AI Agents
In src/op/operator.cc around lines 18 to 28, the code unsafely calls
call->op.as<Op>().value() which will crash when call->op is not an Op; change
this to a safe cast using call->op.as<OpNode>() (or call->op.template
as<OpNode>() if needed), check for null, and early-return nullptr when the cast
fails; then obtain a proper Op reference (e.g., GetRef<Op>(op_node)) before
using it to look up op_map and proceed with creating the TileOperator.
|
|
||
| #include "../layout/layout.h" | ||
| #include "op.h" | ||
| #include "operator.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing required headers; avoid relying on transitive includes
This header uses std::exception/std::string, std::unordered_set and tir::And. None of , , <unordered_set>, or <tvm/tir/op.h> are directly included here. Because this is a public header, please include what you use to prevent fragile builds across translation units.
Apply this diff to add the missing includes:
#include <tvm/target/target.h>
#include <tvm/tir/stmt_functor.h>
#include "../layout/layout.h"
-#include "operator.h"
+#include "operator.h"
+
+// Add missing standard and TIR op headers used in this file
+#include <exception>
+#include <string>
+#include <unordered_set>
+#include <tvm/tir/op.h>📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| #include "operator.h" | |
| #include <tvm/target/target.h> | |
| #include <tvm/tir/stmt_functor.h> | |
| #include "../layout/layout.h" | |
| #include "operator.h" | |
| // Add missing standard and TIR op headers used in this file | |
| #include <exception> | |
| #include <string> | |
| #include <unordered_set> | |
| #include <tvm/tir/op.h> |
🤖 Prompt for AI Agents
In src/op/parallel.h around line 13, the header currently only includes
"operator.h" but uses std::exception, std::string, std::unordered_set and
tir::And; update the includes to explicitly add <exception>, <string>,
<unordered_set> and <tvm/tir/op.h> so the header "includes what it uses" and
does not rely on transitive includes; keep these new includes at the top of the
file with the other standard/library includes.
| void AddPredicate(PrimExpr expr) const { | ||
| predicate_ = predicate_.defined() ? And(expr, predicate_.value()) : expr; | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Qualify And and ensure its header is included
And belongs to tvm::tir (builder for logical-and). Qualify it to avoid ADL surprises and make the dependency explicit. Also ensure <tvm/tir/op.h> is included (see earlier comment).
Apply this diff:
- void AddPredicate(PrimExpr expr) const {
- predicate_ = predicate_.defined() ? And(expr, predicate_.value()) : expr;
- }
+ void AddPredicate(PrimExpr expr) const {
+ predicate_ = predicate_.defined() ? tir::And(expr, predicate_.value()) : expr;
+ }📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| void AddPredicate(PrimExpr expr) const { | |
| predicate_ = predicate_.defined() ? And(expr, predicate_.value()) : expr; | |
| } | |
| void AddPredicate(PrimExpr expr) const { | |
| predicate_ = predicate_.defined() ? tir::And(expr, predicate_.value()) : expr; | |
| } |
🤖 Prompt for AI Agents
In src/op/parallel.h around lines 71 to 73, the call to And is unqualified and
the required header is missing; qualify the function as tvm::tir::And to avoid
ADL surprises and add #include <tvm/tir/op.h> (near other includes at the top of
the file) so the declaration is available.
| RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) { | ||
| size_t n = args.size(); | ||
| size_t ndim = n - 2; | ||
| auto load = args[0].as<BufferLoadNode>(); | ||
| ICHECK(load); | ||
| ICHECK(load->indices.size() == ndim) | ||
| << "load->indices.size() = " << load->indices << " ndim = " << ndim; | ||
| buffer_ = load->buffer; | ||
| access_mask_ = static_cast<int>(*as_const_int(args[1])); | ||
| for (size_t i = 0; i < ndim; i++) { | ||
| PrimExpr min = load->indices[i]; | ||
| PrimExpr extent = args[2 + i]; | ||
| ranges_.push_back(Range::FromMinExtent(min, extent)); | ||
| } | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Constructor needs stronger validation; error message bug.
- Potential crash if args.size() < 2 (ndim underflow).
- access_mask_ assumes args[1] is IntImm without checking.
- Error message prints load->indices (object) instead of its size.
Apply this diff to harden the constructor and fix the message:
RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) {
- size_t n = args.size();
- size_t ndim = n - 2;
- auto load = args[0].as<BufferLoadNode>();
- ICHECK(load);
- ICHECK(load->indices.size() == ndim)
- << "load->indices.size() = " << load->indices << " ndim = " << ndim;
+ const size_t n = args.size();
+ ICHECK_GE(n, 2U) << "RegionOp expects at least 2 args: BufferLoad and access_mask.";
+ const size_t ndim = n - 2;
+ const auto* load = args[0].as<BufferLoadNode>();
+ ICHECK(load) << "RegionOp expects the first argument to be a BufferLoad.";
+ ICHECK_EQ(load->indices.size(), ndim)
+ << "load->indices.size() = " << load->indices.size() << ", ndim = " << ndim;
buffer_ = load->buffer;
- access_mask_ = static_cast<int>(*as_const_int(args[1]));
+ const int64_t* mask = as_const_int(args[1]);
+ ICHECK(mask) << "RegionOp expects access_mask to be an IntImm.";
+ access_mask_ = static_cast<int>(*mask);
for (size_t i = 0; i < ndim; i++) {
PrimExpr min = load->indices[i];
PrimExpr extent = args[2 + i];
ranges_.push_back(Range::FromMinExtent(min, extent));
}
}If you want to silence the unused vmap parameter until needed, also mark it [[maybe_unused]] in both declaration and definition.
/* region.h (declaration)
RegionOp(Array<PrimExpr> args, [[maybe_unused]] BufferMap vmap);
*/
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) { | |
| size_t n = args.size(); | |
| size_t ndim = n - 2; | |
| auto load = args[0].as<BufferLoadNode>(); | |
| ICHECK(load); | |
| ICHECK(load->indices.size() == ndim) | |
| << "load->indices.size() = " << load->indices << " ndim = " << ndim; | |
| buffer_ = load->buffer; | |
| access_mask_ = static_cast<int>(*as_const_int(args[1])); | |
| for (size_t i = 0; i < ndim; i++) { | |
| PrimExpr min = load->indices[i]; | |
| PrimExpr extent = args[2 + i]; | |
| ranges_.push_back(Range::FromMinExtent(min, extent)); | |
| } | |
| } | |
| RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) { | |
| const size_t n = args.size(); | |
| ICHECK_GE(n, 2U) << "RegionOp expects at least 2 args: BufferLoad and access_mask."; | |
| const size_t ndim = n - 2; | |
| const auto* load = args[0].as<BufferLoadNode>(); | |
| ICHECK(load) << "RegionOp expects the first argument to be a BufferLoad."; | |
| ICHECK_EQ(load->indices.size(), ndim) | |
| << "load->indices.size() = " << load->indices.size() << ", ndim = " << ndim; | |
| buffer_ = load->buffer; | |
| const int64_t* mask = as_const_int(args[1]); | |
| ICHECK(mask) << "RegionOp expects access_mask to be an IntImm."; | |
| access_mask_ = static_cast<int>(*mask); | |
| for (size_t i = 0; i < ndim; i++) { | |
| PrimExpr min = load->indices[i]; | |
| PrimExpr extent = args[2 + i]; | |
| ranges_.push_back(Range::FromMinExtent(min, extent)); | |
| } | |
| } |
🤖 Prompt for AI Agents
In src/op/region.cc around lines 19-33, the constructor currently can underflow
ndim when args.size() < 2, assumes args[1] is an IntImm, and prints
load->indices object instead of its size; update it to first check args.size()
>= 2 and return or ICHECK failure if not, compute ndim = args.size() - 2 only
after that check, verify args[0] is a BufferLoadNode and load->indices.size() ==
ndim using load->indices.size() in the error message (not the object), validate
that args[1] is an IntImm (or use as_const_int result safely) before casting and
error if not, and consider marking the vmap parameter [[maybe_unused]] in both
declaration and definition to silence unused warnings.
…gement - Updated several operator classes (AtomicAdd, Copy, Gemm, etc.) to utilize the Node pattern for better memory management and encapsulation. - Changed constructors to initialize member variables through a node object, enhancing clarity and reducing direct member access. - Updated Clone methods to return TileOperator instances instead of unique pointers, aligning with the new design. - Refactored InferLayout and Lower methods to ensure consistency across operator implementations. - Adjusted header files to reflect the new class structure and removed deprecated code for a cleaner codebase.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 7
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (11)
src/transform/layout_inference.cc (3)
131-136: Missing declaration for ProveFragmentContains causes ODR/compile error across TUsThis TU calls ProveFragmentContains but it’s only defined in parallel.cc (no header declaration). Add a declaration to a shared header (e.g., src/op/parallel.h) or move the function to a common header like layout/utils.h.
Example header addition (parallel.h):
+ // Proves small_frag accesses are contained in large_frag for given indices. + bool ProveFragmentContains(Fragment small_frag, Fragment large_frag, + Array<PrimExpr> small_frag_indices, + Array<PrimExpr> large_frag_indices, + arith::Analyzer& analyzer);
324-336: Harden getBufferFromAccessPtr: unchecked casts may segfaultBoth branches assume node types and directly deref. Guard with as<...>() checks and verify mapping existence before returning.
Apply:
if (call->op.same_as(builtin::tvm_access_ptr())) { - auto var = call->args[1].as<Var>().value(); - return buffer_data_to_buffer_[var]; + if (const auto* v = call->args[1].as<VarNode>()) { + Var var = GetRef<Var>(v); + if (buffer_data_to_buffer_.count(var)) return buffer_data_to_buffer_[var]; + } + return std::nullopt; } else if (call->op.same_as(RegionOp::Get())) { - return call->args[0].as<BufferLoadNode>()->buffer; + if (const auto* bl = call->args[0].as<BufferLoadNode>()) { + return bl->buffer; + } + return std::nullopt; }
556-560: Typo breaks fragment-scope check (no enforcement happens)"local.framgent" → "local.fragment".
- if (buffer.scope() == "local.framgent") { + if (buffer.scope() == "local.fragment") {src/op/gemm_sp.cc (2)
294-297: Bug: continuity for A ignored in Hopper pathYou compute continuity but pass mat_continuous instead. This skews shared layout for transposed A.
- results.Set(A, makeGemmABLayoutHopper(mat_stride, mat_continuous, - mat_continuous, A->dtype.bits(), - trans_A ? 1 : 2)); + results.Set(A, makeGemmABLayoutHopper(mat_stride, mat_continuous, + continuity, A->dtype.bits(), + trans_A ? 1 : 2));
254-258: C buffer remap must be optional (may be absent)Direct indexing can yield an undefined Buffer and crash on access_ptr.
- auto C_buffer = T.buffer_remap[C]; + auto C_buffer = T.buffer_remap.count(C) ? T.buffer_remap[C] : C;src/op/gemm.cc (1)
601-605: Arity mismatch: set_num_inputs should be 4 for tl_gemmLower emits String + Aptr + Bptr + Cptr (4 inputs). Registry uses 5.
-TIR_REGISTER_TL_OP(Gemm, gemm) - .set_num_inputs(5) +TIR_REGISTER_TL_OP(Gemm, gemm) + .set_num_inputs(4)src/op/reduce.cc (1)
56-88: Fix potential overflow when building init constants.
Shifts on 32-bit literals overflow for 32/64-bit types; use min_value/max_value helpers.Apply this diff:
switch (type) { case ReduceType::kSum: return make_zero(dst->dtype); case ReduceType::kAbsSum: return make_zero(dst->dtype); case ReduceType::kMax: - if (is_int) { - return make_const(dst->dtype, -(1 << (bits - 1))); - } else if (is_uint) { + if (is_int) { + return min_value(dst->dtype); + } else if (is_uint) { return make_const(dst->dtype, 0); } else { return make_const(dst->dtype, -INFINITY); } case ReduceType::kMin: - if (is_int) { - return make_const(dst->dtype, (1 << (bits - 1)) - 1); - } else if (is_uint) { - return make_const(dst->dtype, (1 << bits) - 1); - } else { + if (is_int || is_uint) { + return max_value(dst->dtype); + } else { return make_const(dst->dtype, INFINITY); }src/op/atomic_add.cc (1)
131-136: Fix scalar-path indexing bug.
For all-ones shapes with rank > 1,{0}indices are wrong. Use MakeIndices to produce full-rank indices.Apply this diff:
- if (is_scalar) { - return For(Var("i"), 0, 1, ForKind::kSerial, - BufferStore(dst, BufferLoad(src, {0}), {0})); - } + if (is_scalar) { + Array<PrimExpr> sidx = MakeIndices(/*ivs=*/{}, /*src_dst=*/0); + Array<PrimExpr> didx = MakeIndices(/*ivs=*/{}, /*src_dst=*/1); + return For(Var("i"), 0, 1, ForKind::kSerial, + BufferStore(dst, BufferLoad(src, sidx), didx)); + }src/op/copy.cc (3)
69-72: Incorrect CUtensorMapDataType for int16Signed int16 must map to CU_TENSOR_MAP_DATA_TYPE_INT16, not UINT16.
- case 16: - tp = CU_TENSOR_MAP_DATA_TYPE_UINT16; - break; + case 16: + tp = CU_TENSOR_MAP_DATA_TYPE_INT16; + break;
560-604: Guard LDSM/STSM path on presence of local layout; otherwise fallbackLowerLDSMCopy assumes T.layout_map has an entry for local_tensor. If missing, Downcast(...) will fail. Add a guard (or JIT infer via ParallelOp) and fallback to LowerNormalCopy.
- Fragment local_layout = Downcast<Fragment>(T.layout_map[local_tensor]); + if (!T.layout_map.count(local_tensor)) { + return LowerNormalCopy(T, analyzer); + } + Fragment local_layout = Downcast<Fragment>(T.layout_map[local_tensor]);
1115-1116: Use the member stride instead of an undefined local variableThis line references stride that doesn’t exist in scope; use this->stride.
- desc.elem_stride = {1, stride, stride, 1}; + desc.elem_stride = {1, this->stride, this->stride, 1};
♻️ Duplicate comments (3)
src/op/operator.cc (1)
20-25: Unsafe cast to Op; handle non-Op calls gracefullySwitch to
as<OpNode>()+ GetRef with early return when not an Op. This prevents crashes on builtin/global calls.- Op op = call->op.as<Op>().value(); - if (op_map.count(op)) { - auto tile_op = op_map[op](call->args, vmap); + const auto* op_node = call->op.as<OpNode>(); + if (!op_node) { + return TileOperator(); + } + Op op = GetRef<Op>(op_node); + if (op_map.count(op)) { + auto tile_op = op_map[op](call->args, vmap); ICHECK(tile_op.defined()); return tile_op; }src/op/parallel.h (2)
73-75: Qualify And with tir:: to avoid ADL and make intent explicitAlso aligns with requiring <tvm/tir/op.h>.
- predicate_ = predicate_.defined() ? And(expr, predicate_.value()) : expr; + predicate_ = predicate_.defined() ? tir::And(expr, predicate_.value()) : expr;
13-14: Missing required headers; avoid relying on transitive includesThis header uses std::exception, std::string, std::unordered_set, tir::And, and arith::Analyzer but doesn’t include their headers.
#include "../layout/layout.h" #include "operator.h" + +// Explicit dependencies used in this header +#include <exception> +#include <string> +#include <unordered_set> +#include <tvm/tir/op.h> +#include <tvm/arith/analyzer.h>
🧹 Nitpick comments (38)
src/op/elem.h (3)
26-31: Add override specifiers to ensure interface conformanceMark these as overrides of TileOperatorNode to catch signature drift at compile-time.
- Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const; - LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const; + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; + LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const override; static const Op &Get(); - TileOperator Clone() const; + TileOperator Clone() const override;
16-16: Avoidusing namespacein headers
using namespace tir;in a public header pollutes includers. Prefer qualified names or type aliases.
2-4: Fix header doc typo“Elment-wise” → “element-wise”.
- * \brief Define elment-wise operators. + * \brief Define element-wise operators.src/op/elem.cc (3)
105-109: Duplicate InferLayout call
par_op->InferLayout(...)is invoked twice back-to-back with identical args in the local.fragment path. Remove the duplicate.- par_op->InferLayout({T.target, T.thread_bounds, T.layout_map}, - InferLevel::kFree); par_op->InferLayout({T.target, T.thread_bounds, T.layout_map}, InferLevel::kFree);
83-86: Stabilize loop var naming and dtypeSingle-char names via
char('i' + i)scale poorly and dtype derived fromregion[i]->extentcan be non-int or unexpected. Use deterministic names and an explicit int32 dtype for loop vars.- Var var = Var(std::string{char('i' + i)}, region[i]->extent->dtype); + Var var = Var("i" + std::to_string(i), DataType::Int(32));
1-5: Fix source file doc typo“Elment-wise” → “element-wise”.
- * Define elment-wise operators. + * Define element-wise operators.src/op/copy.h (6)
130-131: Mark InferLayout as overrideAligns with TileOperatorNode interface.
- LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const; + LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const override;
211-212: Mark Clone as overridePrevents accidental signature drift.
- TileOperator Clone() const; + TileOperator Clone() const override;
114-116: Use the strong type for eviction policyStore the scoped enum, not a raw int, to avoid accidental misuse.
- int eviction_policy; // Policy for cache eviction + EvictionPolicy eviction_policy{EvictionPolicy::kEvictNormal}; // Cache eviction policy
103-105: Consider generalizing coalesced_width type
IntImmas a field is awkward to default/construct; preferPrimExpr(or Optional) for flexibility.- IntImm coalesced_width; // Width (in elements) for coalesced memory access + PrimExpr coalesced_width; // Width (in elements) for coalesced memory access
254-266: Mark Conv2DIm2Col overridesAdd override on InferLayout and Clone.
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; - LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const; + LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const override; static const Op &Get(); - TileOperator Clone() const; + TileOperator Clone() const override;
19-19: Avoidusing namespacein headersHeader-wide
using namespace tir;leaks into transitive includers.src/op/operator.cc (2)
2-5: Fix file header typos“usd” → “used”.
- * Define operators usd in tile library. + * Define operators used in tile library.
30-33: Minor: simplify Evaluate cast usageOptional: bind the EvaluateNode once to avoid double
as<>calls.- if (stmt.as<Evaluate>() && stmt.as<EvaluateNode>()->value.as<CallNode>()) { - auto call = stmt.as<EvaluateNode>()->value.as<CallNode>(); + if (auto* eval = stmt.as<EvaluateNode>()) { + if (auto* call = eval->value.as<CallNode>()) { return ParseOperator(GetRef<Call>(call), vmap); - } + } + }src/op/region.h (3)
39-40: Mark Clone as overrideMatch TileOperatorNode interface.
- TileOperator Clone() const; + TileOperator Clone() const override;
10-10: Unify include stylePrefer
#include "operator.h"for consistency with other headers.-#include "./operator.h" +#include "operator.h"
19-19: Avoidusing namespacein headersPrevents namespace leakage to dependents.
src/transform/layout_inference.cc (4)
491-498: Catch exceptions by const reference to avoid slicing and copiesCatching by value can slice and incur copies.
- } catch (LayoutConflictException e) { + } catch (const LayoutConflictException& e) { // such an order fails, try others do_update = false; - } catch (NormalizeIterException e) { + } catch (const NormalizeIterException& e) {
447-451: Remove unused components_buffersIt is computed but never used.
- std::unordered_map<int, std::vector<Buffer>> components_buffers; - for (const auto &[buffer, infer_indices] : use_list_) { - int root = uf.Find(infer_indices[0]); - components_buffers[root].push_back(buffer); - }
83-85: Nit: fix function name in error textMessage says "run_infer_step" but method is RunInferStep.
- ICHECK(next.defined()) << "infer_list_[" << cur_infer_id - << "] is null inside run_infer_step."; + ICHECK(next.defined()) << "infer_list_[" << cur_infer_id + << "] is null inside RunInferStep.";
32-52: Duplicate using namespace tir;There are two using declarations (Line 32 and 51). Remove one.
-using namespace tir;src/op/gemm_sp.cc (2)
20-32: Remove dead helpertoPrimeFactors is unused.
-static std::vector<int> toPrimeFactors(int x) { - int i = 2; - std::vector<int> result; - while (x > 1) { - if (x % i == 0) { - x /= i; - result.push_back(i); - } else { - i++; - } - } - return result; -}
65-143: (Optional) Align divisibility checks with GemmConsider ICHECKs on M%16 and N%8 like in GemmNode::ComputeWarpPartition to fail fast on invalid tiles.
src/op/parallel.cc (1)
194-369: Solid inference flow; minor style nits onlyInference logic looks consistent with prior Operator-based code. No functional issues spotted.
- Consider early-return if indice_map_ empty.
src/op/atomic_add.h (1)
18-35: Clean up fields and add overrides
- args_ appears unused; remove to avoid confusion.
- Mark virtuals with override to catch signature drift.
-class AtomicAddNode : public TileOperatorNode { +class AtomicAddNode : public TileOperatorNode { public: - Array<PrimExpr> args_; + // args_ unused; remove if not needed. Buffer src, dst; Array<Range> src_range, dst_range; IntImm coalesced_width; mutable ParallelOp par_op_; static constexpr const char *_type_key = "tl.AtomicAdd"; TVM_DECLARE_FINAL_OBJECT_INFO(AtomicAddNode, TileOperatorNode); - Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const; - LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const; + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; + LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const override; static const Op &Get(); - TileOperator Clone() const; + TileOperator Clone() const override;src/op/gemm.cc (1)
22-35: Remove unused helpertoPrimeFactors is unused.
-static std::vector<int> toPrimeFactors(int x) { - int i = 2; - std::vector<int> result; - while (x > 1) { - if (x % i == 0) { - x /= i; - result.push_back(i); - } else { - i++; - } - } - return result; -}src/op/gemm_sp.h (1)
21-25: Unify GemmWarpPolicy type with Gemm.
Gemm exposes a top-level GemmWarpPolicy; GemmSP defines a nested enum. Prefer reusing the top-level to keep APIs consistent and avoid duplicate semantics.src/op/operator.h (1)
14-16: Header hygiene (minor).
Consider avoidingusing namespace tir;in a public header to limit namespace leakage. Use explicittir::instead.src/op/reduce.cc (3)
24-44: Defensive checks on argument types and vmap.
.as<...>().value()will crash on malformed inputs; vmap[] may yield undefined buffers. Add ICHECKs.Apply this diff:
- node->src = vmap[GetVarFromAccessPtr(args[0])]; - node->dst = vmap[GetVarFromAccessPtr(args[1])]; - std::string reduce_type = args[2].as<StringImm>().value()->value; - node->dim = args[3].as<IntImm>().value()->value; + auto v0 = GetVarFromAccessPtr(args[0]), v1 = GetVarFromAccessPtr(args[1]); + ICHECK(vmap.count(v0) && vmap.count(v1)); + node->src = vmap[v0]; + node->dst = vmap[v1]; + ICHECK(args[2].as<StringImm>()); + std::string reduce_type = args[2].as<StringImm>()->value; + ICHECK(args[3].as<IntImm>()); + node->dim = args[3].as<IntImm>()->value;
240-247: Prefer Target helper over raw arch string.
Use TargetIsHopper(T.target) for portability instead of"sm_90"string checks.Suggested change:
- bool has_arch = T.target->attrs.count("arch") > 0; - auto thread_offset = T.thread_bounds->min; - if (has_arch && Downcast<String>(T.target->attrs["arch"]) == "sm_90") { + auto thread_offset = T.thread_bounds->min; + if (TargetIsHopper(T.target)) {
264-266: Remove unused statement.
reduce_interthreadis constructed but not emitted.Apply this diff:
- Stmt reduce_interthread = BufferStore( - clear_buffer, BufferLoad(clear_buffer, dst_indices), dst_indices);src/op/atomic_add.cc (1)
62-74: Iterator var naming (nit).
Char arithmetic can collide past 26 dims. Consideri0, i1, …for robustness.src/op/gemm.h (1)
17-21: Apply same policy enum to GemmSP.
To avoid divergence, reuse this enum in GemmSP instead of a nested duplicate.src/op/reduce.h (2)
39-41: Add missing override specifier for Clone()Declare Clone() with override to enforce vtable correctness and catch signature drift.
- TileOperator Clone() const; + TileOperator Clone() const override;Also applies to: 41-45
25-33: Consider forward-declare vs include for Op, or include <tvm/ir/op.h>This header declares static Get() returning const Op&. Prefer including <tvm/ir/op.h> (or forward-declare tvm::Op) to make the dependency explicit.
#include "operator.h" +#include <tvm/ir/op.h>src/op/parallel.h (1)
68-69: Add missing override for Clone()Ensure it overrides TileOperatorNode::Clone().
- TileOperator Clone() const; + TileOperator Clone() const override;src/op/copy.cc (2)
13-21: Duplicate include of ../target/utils.hRemove one of the duplicates to avoid redundant work and potential ODR issues in some build setups.
#include "copy.h" -#include "../target/utils.h" #include "../transform/common/loop_fusion_utils.h" #include "../transform/common/loop_parallel_transform_utils.h" #include "../transform/loop_partition.h" #include "../transform/loop_vectorize.h" #include "region.h" #include "../target/cuda.h" #include "../target/utils.h"
327-361: Optional: collapse par_op_ warm-up in InferLayout to avoid duplicationYou already JIT-build par_op_ from MakeSIMTLoop; consider caching vector_size/thread partitioning outputs or exposing a helper to share with Lower paths if you later need layout in Lower.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (20)
src/op/atomic_add.cc(7 hunks)src/op/atomic_add.h(2 hunks)src/op/copy.cc(22 hunks)src/op/copy.h(4 hunks)src/op/elem.cc(5 hunks)src/op/elem.h(1 hunks)src/op/gemm.cc(7 hunks)src/op/gemm.h(2 hunks)src/op/gemm_sp.cc(4 hunks)src/op/gemm_sp.h(2 hunks)src/op/operator.cc(1 hunks)src/op/operator.h(2 hunks)src/op/parallel.cc(3 hunks)src/op/parallel.h(3 hunks)src/op/reduce.cc(7 hunks)src/op/reduce.h(1 hunks)src/op/region.cc(1 hunks)src/op/region.h(1 hunks)src/transform/layout_inference.cc(9 hunks)src/transform/lower_tile_op.cc(2 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- src/transform/lower_tile_op.cc
- src/op/region.cc
🧰 Additional context used
🧬 Code graph analysis (14)
src/op/operator.cc (3)
src/transform/layout_inference.cc (14)
op(39-45)op(39-39)op(294-322)op(294-294)op(346-369)op(346-346)op(371-388)op(371-371)op(390-399)op(390-390)op(553-565)op(553-553)expr(324-336)expr(324-324)src/op/operator.h (1)
TileOperator(60-93)tilelang/language/tir/op.py (1)
tvm_access_ptr(650-675)
src/op/region.h (2)
src/op/operator.h (2)
TileOperatorNode(53-94)TileOperator(60-93)src/op/region.cc (9)
Lower(49-51)Lower(49-49)InferLayout(53-56)InferLayout(53-54)IsFullRegion(39-47)IsFullRegion(39-39)Clone(34-37)Clone(34-34)RegionOp(14-32)
src/transform/layout_inference.cc (2)
src/op/parallel.h (1)
ParallelOp(92-99)src/transform/pipeline_planning.cc (10)
op(54-72)op(54-54)op(74-92)op(74-74)op(94-121)op(94-94)op(123-133)op(123-123)op(545-554)op(545-545)
src/op/gemm_sp.cc (3)
src/op/operator.cc (2)
GetVarFromAccessPtr(37-44)GetVarFromAccessPtr(37-37)src/op/gemm.cc (8)
Clone(69-72)Clone(69-69)ComputeWarpPartition(134-294)ComputeWarpPartition(134-136)Lower(391-426)Lower(391-391)InferLayout(452-599)InferLayout(452-453)src/op/gemm_sp.h (1)
GemmSP(47-52)
src/op/gemm.cc (6)
src/op/operator.cc (2)
GetVarFromAccessPtr(37-44)GetVarFromAccessPtr(37-37)src/op/elem.cc (6)
Clone(74-77)Clone(74-74)Lower(96-136)Lower(96-96)InferLayout(138-141)InferLayout(138-139)src/op/gemm_sp.cc (8)
Clone(60-63)Clone(60-60)ComputeWarpPartition(65-221)ComputeWarpPartition(66-67)Lower(223-265)Lower(223-223)InferLayout(267-319)InferLayout(267-268)src/op/parallel.cc (10)
Clone(161-164)Clone(161-161)op(96-106)op(96-96)op(108-116)op(108-108)Lower(166-169)Lower(166-167)InferLayout(194-369)InferLayout(194-195)src/op/reduce.cc (12)
Clone(46-49)Clone(46-46)Clone(51-54)Clone(51-51)Lower(130-297)Lower(130-130)Lower(394-418)Lower(394-394)InferLayout(299-369)InferLayout(299-300)InferLayout(420-423)InferLayout(420-421)src/op/gemm.h (1)
Gemm(61-66)
src/op/parallel.h (2)
src/op/parallel.cc (23)
ParallelOpNode(157-159)op(96-106)op(96-96)op(108-116)op(108-108)VisitStmt_(122-128)VisitStmt_(122-122)VisitStmt_(130-142)VisitStmt_(130-130)VisitExpr_(144-155)VisitExpr_(144-144)Lower(166-169)Lower(166-167)InferLayout(194-369)InferLayout(194-195)GetPredicate(371-377)GetPredicate(371-371)Clone(161-164)Clone(161-161)CompleteBufferFragment(379-404)CompleteBufferFragment(379-379)IsCommonAccessIndice(171-174)IsCommonAccessIndice(171-171)src/op/operator.h (2)
TileOperatorNode(53-94)TileOperator(60-93)
src/op/elem.cc (6)
src/op/operator.cc (2)
GetVarFromAccessPtr(37-44)GetVarFromAccessPtr(37-37)src/op/gemm.cc (6)
Clone(69-72)Clone(69-69)Lower(391-426)Lower(391-391)InferLayout(452-599)InferLayout(452-453)src/op/parallel.cc (10)
Clone(161-164)Clone(161-161)op(96-106)op(96-96)op(108-116)op(108-108)Lower(166-169)Lower(166-167)InferLayout(194-369)InferLayout(194-195)src/op/region.cc (6)
Clone(34-37)Clone(34-34)Lower(49-51)Lower(49-49)InferLayout(53-56)InferLayout(53-54)src/op/elem.h (1)
Fill(36-41)src/op/parallel.h (1)
InferLayout(55-61)
src/op/gemm.h (4)
src/op/elem.h (2)
tvm(13-44)tl(14-43)src/op/gemm_sp.h (2)
tl(13-55)GemmWarpPolicy(21-45)src/op/operator.h (2)
TileOperatorNode(53-94)TileOperator(60-93)src/op/gemm.cc (11)
Lower(391-426)Lower(391-391)InferLayout(452-599)InferLayout(452-453)Clone(69-72)Clone(69-69)GetGemmInst(74-88)GetGemmInst(74-74)ComputeWarpPartition(134-294)ComputeWarpPartition(134-136)Gemm(36-67)
src/op/operator.h (2)
src/op/elem.cc (2)
Clone(74-77)Clone(74-74)src/op/operator.cc (6)
GetVarFromAccessPtr(37-44)GetVarFromAccessPtr(37-37)ParseOperator(18-27)ParseOperator(18-18)ParseOperator(29-35)ParseOperator(29-29)
src/op/elem.h (2)
src/op/operator.h (2)
TileOperatorNode(53-94)TileOperator(60-93)src/op/elem.cc (9)
Lower(96-136)Lower(96-96)InferLayout(138-141)InferLayout(138-139)Clone(74-77)Clone(74-74)MakeSIMTLoop(79-94)MakeSIMTLoop(79-79)Fill(25-72)
src/op/reduce.cc (5)
src/op/operator.cc (2)
GetVarFromAccessPtr(37-44)GetVarFromAccessPtr(37-37)src/op/elem.cc (2)
Clone(74-77)Clone(74-74)src/op/gemm.cc (2)
Clone(69-72)Clone(69-69)src/op/parallel.cc (6)
Clone(161-164)Clone(161-161)op(96-106)op(96-96)op(108-116)op(108-108)src/op/reduce.h (1)
ReduceOp(47-52)
src/op/parallel.cc (3)
src/op/parallel.h (3)
ParallelOpNode(48-90)ParallelOp(92-99)InferLayout(55-61)src/op/elem.cc (6)
Clone(74-77)Clone(74-74)Lower(96-136)Lower(96-96)InferLayout(138-141)InferLayout(138-139)src/transform/layout_inference.cc (20)
op(39-45)op(39-39)op(294-322)op(294-294)op(346-369)op(346-346)op(371-388)op(371-371)op(390-399)op(390-390)op(553-565)op(553-553)op(567-624)op(567-567)op(626-635)op(626-626)buffer(338-344)buffer(338-338)level(174-189)level(174-176)
src/op/gemm_sp.h (4)
src/op/elem.h (1)
tl(14-43)src/op/gemm.h (1)
tl(13-69)src/op/operator.h (2)
TileOperatorNode(53-94)TileOperator(60-93)src/op/gemm_sp.cc (7)
Lower(223-265)Lower(223-223)InferLayout(267-319)InferLayout(267-268)Clone(60-63)Clone(60-60)GemmSP(34-58)
src/op/reduce.h (2)
src/op/operator.h (2)
TileOperatorNode(53-94)TileOperator(60-93)src/op/reduce.cc (20)
Lower(130-297)Lower(130-130)Lower(394-418)Lower(394-394)InferLayout(299-369)InferLayout(299-300)InferLayout(420-423)InferLayout(420-421)Clone(46-49)Clone(46-46)Clone(51-54)Clone(51-51)MakeInitValue(56-88)MakeInitValue(56-56)MakeReduce(90-110)MakeReduce(90-90)MakeCodegenReducer(112-128)MakeCodegenReducer(112-112)ReduceOp(24-44)CumSumOp(376-392)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: build-test-amd
- GitHub Check: build-test-nvidia
🔇 Additional comments (31)
src/op/elem.h (1)
28-28: DropFillNode::Get()– please verify removalI ran an initial search for any references to
FillNode::Get()and found no matches, but an empty grep isn’t definitive proof there are no indirect or scattered calls. Before removing the declaration, please manually confirm no call sites remain by running something like:
•rg -n "FillNode::Get" -C3 src/
•rg -nP '\bFillNode\s*::\s*Get' -C2 .Once you’ve ensured it’s unused, apply this optional refactor:
--- a/src/op/elem.h +++ b/src/op/elem.h @@ -28 +28 @@ - static const Op &Get();Ensure the wrapper class (via
TIR_REGISTER_TL_OP(FillNode)) continues to expose the op’sGet(), then drop the Node-level declaration to avoid potential ODR/link errors.src/op/parallel.cc (2)
157-169: LGTM: Node wrapper, Clone, and Lower wiringConstructor/Clone/Lower correctly set up ParallelOpNode and preserve root.
47-81: Incorrect suggestion:ProveFragmentContainsis already declared inparallel.hThe declaration for
ProveFragmentContainsalready exists insrc/op/parallel.h, so no new header entry is necessary. Instead, ensure any source using it includes the correct header.• Declaration found at:
src/op/parallel.h:29–32
• Usage in layout inference:
src/transform/layout_inference.cc:131– verify this file includessrc/op/parallel.hto resolve any missing-declaration errors.Likely an incorrect or invalid review comment.
src/op/atomic_add.h (1)
48-53: Wrapper surface looks correctConstructor and Get() declaration match the node-based pattern.
src/op/gemm.cc (2)
391-426: Lower path looks consistent with node fields and target gatingString composition and call emission align with the new node design.
452-599: Inference paths per-arch look correctHopper/CDNA/SM variants mirror prior logic; completed_ guard in place.
src/op/gemm_sp.h (2)
10-10: Include switch to operator.h looks correct.
Aligns with the new TileOperator API.
47-51: Registration Verified
TheTIR_REGISTER_TL_OP(GemmSP, gemm_sp)macro is present, soGemmSP::Get()and the builder resolution are correctly registered.• src/op/gemm_sp.cc:321 –
TIR_REGISTER_TL_OP(GemmSP, gemm_sp)src/op/operator.h (3)
53-65: TileOperator base interface looks solid.
Pure-virtual Lower/InferLayout/Clone with TVM object plumbing is consistent and minimal.
75-77: ParseOperator signature migration to TileOperator.
Returning TileOperator simplifies downstream use; implementation matches.
78-91: Registration macro update is correct.
Returning Entry(args, vmap) as TileOperator matches the new OpBuilderFunc type.src/op/reduce.cc (8)
46-49: Clone() implementation is correct.
Standard node copy and wrap.
51-54: CumSum Clone() implementation is correct.
Matches the project pattern.
90-110: MakeReduce logic reads clean.
Casts to align dtypes and correct cases for AbsMax/AbsSum.
112-128: Reducer codegen names OK.
AbsSum/AbsMax both map to MaxOp/SumOp as expected.
299-369: InferLayout logic is thorough.
Compatibility checks and conflict diagnostics look good.
376-392: CumSum ctor validations LGTM.
Dim bounds check and node wiring look correct.
394-418: CumSum Lower: scope gating is clear.
Shared memory lowering path only; fatal for fragments is intentional.
420-423: CumSum InferLayout stub OK.
No inference yet; safe default.src/op/atomic_add.cc (7)
7-8: Include path tweaks look right.
Local includes for atomic_add.h and region.h match the new structure.
57-60: Clone() implementation is correct.
Standard node copy and wrap.
78-95: Index construction LGTM.
Handles unit-extent dims correctly and validates arity.
97-126: Predicate construction LGTM.
Symbolic-bound checks minimize redundant conditions.
191-219: Lower path is cohesive.
ParallelOp-based inference + vectorization looks fine; predicate-guarded emission is correct.
221-242: InferLayout caching via par_op_ is good.
Fragment-layout equality check prevents silent mismatches.
37-55: RegionOp accessor methods are present and correctly exported
Verified in src/op/region.h that RegionOpNode defines both accessors and they are exposed through the RegionOp wrapper:
- Line 34:
const Buffer &GetBuffer() const { return buffer_; }- Line 35:
const Array<Range> &GetRanges() const { return ranges_; }No changes required.
src/op/gemm.h (4)
10-10: Include switch to operator.h looks correct.
Matches new base types.
17-21: Top-level GemmWarpPolicy is a good move.
Promotes reuse across ops.
23-59: GemmNode fields and API align with refactor.
Type key, object info, overrides, and state look consistent with gemm.cc.
61-66: Wrapper Gemm handle LGTM.
Constructor + Get() pattern matches registration macro.src/op/copy.cc (1)
38-99: Verify CUDA 12+ CUtensorMapDataType mappings for fp8 typesMapping float8_e4m3/e5m2 to CU_TENSOR_MAP_DATA_TYPE_UINT8 may be intentional; please confirm against the latest CUDA docs.
Would you like me to open a follow-up to sync these enums with CUDA 12.4+ definitions?
| AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) { | ||
| ObjectPtr<AtomicAddNode> node = make_object<AtomicAddNode>(); | ||
| Array<Range> rgs[2]; | ||
| Buffer bf[2]; | ||
| for (int i = 0; i < 2; i++) { | ||
| auto expr = args[i]; | ||
| auto call = expr.as<CallNode>(); | ||
| ICHECK(call); | ||
| auto region = RegionOp(call->args, vmap); | ||
| rgs[i] = region.GetRanges(); | ||
| bf[i] = region.GetBuffer(); | ||
| rgs[i] = region->GetRanges(); | ||
| bf[i] = region->GetBuffer(); | ||
| } | ||
| std::tie(this->src, this->dst) = std::tie(bf[0], bf[1]); | ||
| std::tie(this->src_range, this->dst_range) = std::tie(rgs[0], rgs[1]); | ||
| std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]); | ||
| std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]); | ||
| if (args.size() >= 3) { | ||
| coalesced_width = Downcast<IntImm>(args[2]); | ||
| node->coalesced_width = Downcast<IntImm>(args[2]); | ||
| } | ||
| data_ = std::move(node); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Harden RegionOp extraction.
Guard RegionOp construction and vmap lookups to avoid nullptr buffers.
Apply this diff:
for (int i = 0; i < 2; i++) {
auto expr = args[i];
auto call = expr.as<CallNode>();
ICHECK(call);
- auto region = RegionOp(call->args, vmap);
- rgs[i] = region->GetRanges();
- bf[i] = region->GetBuffer();
+ auto region = RegionOp(call->args, vmap);
+ ICHECK(region.defined()) << "Failed to parse RegionOp for arg " << i;
+ rgs[i] = region->GetRanges();
+ bf[i] = region->GetBuffer();
}📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) { | |
| ObjectPtr<AtomicAddNode> node = make_object<AtomicAddNode>(); | |
| Array<Range> rgs[2]; | |
| Buffer bf[2]; | |
| for (int i = 0; i < 2; i++) { | |
| auto expr = args[i]; | |
| auto call = expr.as<CallNode>(); | |
| ICHECK(call); | |
| auto region = RegionOp(call->args, vmap); | |
| rgs[i] = region.GetRanges(); | |
| bf[i] = region.GetBuffer(); | |
| rgs[i] = region->GetRanges(); | |
| bf[i] = region->GetBuffer(); | |
| } | |
| std::tie(this->src, this->dst) = std::tie(bf[0], bf[1]); | |
| std::tie(this->src_range, this->dst_range) = std::tie(rgs[0], rgs[1]); | |
| std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]); | |
| std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]); | |
| if (args.size() >= 3) { | |
| coalesced_width = Downcast<IntImm>(args[2]); | |
| node->coalesced_width = Downcast<IntImm>(args[2]); | |
| } | |
| data_ = std::move(node); | |
| } | |
| AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) { | |
| ObjectPtr<AtomicAddNode> node = make_object<AtomicAddNode>(); | |
| Array<Range> rgs[2]; | |
| Buffer bf[2]; | |
| for (int i = 0; i < 2; i++) { | |
| auto expr = args[i]; | |
| auto call = expr.as<CallNode>(); | |
| ICHECK(call); | |
| auto region = RegionOp(call->args, vmap); | |
| ICHECK(region.defined()) << "Failed to parse RegionOp for arg " << i; | |
| rgs[i] = region->GetRanges(); | |
| bf[i] = region->GetBuffer(); | |
| } | |
| std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]); | |
| std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]); | |
| if (args.size() >= 3) { | |
| node->coalesced_width = Downcast<IntImm>(args[2]); | |
| } | |
| data_ = std::move(node); | |
| } |
🤖 Prompt for AI Agents
In src/op/atomic_add.cc around lines 37 to 55, the RegionOp construction and
buffer map lookups assume non-null RegionOp and non-null buffers; harden this by
validating call casting, ensuring RegionOp(...) returns a non-null pointer, and
checking that region->GetBuffer() is non-null before using it. Replace raw
ICHECK(call) with an explicit check that call is valid and return/raise a clear
error if not; after RegionOp(call->args, vmap) verify region is non-null and
that region->GetRanges() and region->GetBuffer() are valid (ICHECK or
conditional error) for each argument; only assign node->src/node->dst and ranges
after successful validation; keep the existing handling for args[2] but guard
the Downcast with a check that args.size() >= 3 and that args[2] is an IntImm
before casting.
| if (args.size() >= 4) { | ||
| this->disable_tma = Downcast<Bool>(args[3]); | ||
| node->disable_tma = Downcast<Bool>(args[3]); | ||
| } | ||
| if (args.size() >= 5) { | ||
| this->eviction_policy = args[4].as<IntImmNode>()->value; | ||
| node->eviction_policy = args[4].as<IntImmNode>()->value; | ||
| } | ||
| data_ = std::move(node); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Parse disable_tma as a bool, not Bool/ObjectRef
Downcasting to Bool is brittle for PrimExpr inputs. Use IntImm->value -> bool.
- if (args.size() >= 4) {
- node->disable_tma = Downcast<Bool>(args[3]);
- }
+ if (args.size() >= 4) {
+ node->disable_tma = static_cast<bool>(Downcast<IntImm>(args[3])->value);
+ }📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| if (args.size() >= 4) { | |
| this->disable_tma = Downcast<Bool>(args[3]); | |
| node->disable_tma = Downcast<Bool>(args[3]); | |
| } | |
| if (args.size() >= 5) { | |
| this->eviction_policy = args[4].as<IntImmNode>()->value; | |
| node->eviction_policy = args[4].as<IntImmNode>()->value; | |
| } | |
| data_ = std::move(node); | |
| if (args.size() >= 4) { | |
| node->disable_tma = static_cast<bool>(Downcast<IntImm>(args[3])->value); | |
| } | |
| if (args.size() >= 5) { | |
| node->eviction_policy = args[4].as<IntImmNode>()->value; | |
| } | |
| data_ = std::move(node); |
🤖 Prompt for AI Agents
In src/op/copy.cc around lines 135-141, the code downcasts args[3] to Bool which
is brittle for PrimExpr inputs; instead parse the IntImm value and convert to
bool. Replace the Downcast<Bool>(args[3]) usage with retrieving
args[3].as<IntImmNode> and set node->disable_tma to (that->value != 0), ensuring
you access as<IntImmNode>() safely (assume an IntImm is provided) so the flag
comes from the integer value rather than an ObjectRef Bool.
| Conv2DIm2ColOp::Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap) { | ||
| src = vmap[GetVarFromAccessPtr(args[0])]; | ||
| dst = vmap[GetVarFromAccessPtr(args[1])]; | ||
| nhw_step = args[2]; | ||
| c_step = args[3]; | ||
| kernel = args[4].as<IntImm>().value()->value; | ||
| stride = args[5].as<IntImm>().value()->value; | ||
| dilation = args[6].as<IntImm>().value()->value; | ||
| padding = args[7].as<IntImm>().value()->value; | ||
| eviction_policy = args[8].as<IntImm>().value()->value; | ||
| ObjectPtr<Conv2DIm2ColOpNode> node = make_object<Conv2DIm2ColOpNode>(); | ||
| node->src = vmap[GetVarFromAccessPtr(args[0])]; | ||
| node->dst = vmap[GetVarFromAccessPtr(args[1])]; | ||
| node->nhw_step = args[2]; | ||
| node->c_step = args[3]; | ||
| node->kernel = args[4].as<IntImm>().value()->value; | ||
| node->stride = args[5].as<IntImm>().value()->value; | ||
| node->dilation = args[6].as<IntImm>().value()->value; | ||
| node->padding = args[7].as<IntImm>().value()->value; | ||
| node->eviction_policy = args[8].as<IntImm>().value()->value; | ||
| data_ = std::move(node); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix IntImm extraction for Conv2DIm2ColOp ctor
Current .as().value()->value is invalid; use Downcast(...)->value.
- node->kernel = args[4].as<IntImm>().value()->value;
- node->stride = args[5].as<IntImm>().value()->value;
- node->dilation = args[6].as<IntImm>().value()->value;
- node->padding = args[7].as<IntImm>().value()->value;
- node->eviction_policy = args[8].as[IntImm>().value()->value;
+ node->kernel = Downcast<IntImm>(args[4])->value;
+ node->stride = Downcast<IntImm>(args[5])->value;
+ node->dilation = Downcast<IntImm>(args[6])->value;
+ node->padding = Downcast<IntImm>(args[7])->value;
+ node->eviction_policy = Downcast<IntImm>(args[8])->value;📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| Conv2DIm2ColOp::Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap) { | |
| src = vmap[GetVarFromAccessPtr(args[0])]; | |
| dst = vmap[GetVarFromAccessPtr(args[1])]; | |
| nhw_step = args[2]; | |
| c_step = args[3]; | |
| kernel = args[4].as<IntImm>().value()->value; | |
| stride = args[5].as<IntImm>().value()->value; | |
| dilation = args[6].as<IntImm>().value()->value; | |
| padding = args[7].as<IntImm>().value()->value; | |
| eviction_policy = args[8].as<IntImm>().value()->value; | |
| ObjectPtr<Conv2DIm2ColOpNode> node = make_object<Conv2DIm2ColOpNode>(); | |
| node->src = vmap[GetVarFromAccessPtr(args[0])]; | |
| node->dst = vmap[GetVarFromAccessPtr(args[1])]; | |
| node->nhw_step = args[2]; | |
| node->c_step = args[3]; | |
| node->kernel = args[4].as<IntImm>().value()->value; | |
| node->stride = args[5].as<IntImm>().value()->value; | |
| node->dilation = args[6].as<IntImm>().value()->value; | |
| node->padding = args[7].as<IntImm>().value()->value; | |
| node->eviction_policy = args[8].as<IntImm>().value()->value; | |
| data_ = std::move(node); | |
| } | |
| Conv2DIm2ColOp::Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap) { | |
| ObjectPtr<Conv2DIm2ColOpNode> node = make_object<Conv2DIm2ColOpNode>(); | |
| node->src = vmap[GetVarFromAccessPtr(args[0])]; | |
| node->dst = vmap[GetVarFromAccessPtr(args[1])]; | |
| node->nhw_step = args[2]; | |
| node->c_step = args[3]; | |
| node->kernel = Downcast<IntImm>(args[4])->value; | |
| node->stride = Downcast<IntImm>(args[5])->value; | |
| node->dilation = Downcast<IntImm>(args[6])->value; | |
| node->padding = Downcast<IntImm>(args[7])->value; | |
| node->eviction_policy = Downcast<IntImm>(args[8])->value; | |
| data_ = std::move(node); | |
| } |
🤖 Prompt for AI Agents
In src/op/copy.cc around lines 1050 to 1062, the constructor uses
.as<IntImm>().value()->value which is invalid; replace each
args[N].as<IntImm>().value()->value with tir::Downcast<IntImm>(args[N])->value
(or simply Downcast<IntImm>(args[N])->value if tir namespace is imported) so the
IntImm is correctly downcast and its .value accessed; update lines for kernel
(args[4]), stride (args[5]), dilation (args[6]), padding (args[7]), and
eviction_policy (args[8]) accordingly.
| ObjectPtr<GemmSPNode> node = make_object<GemmSPNode>(); | ||
| node->A = vmap[GetVarFromAccessPtr(args[0])]; | ||
| node->E = vmap[GetVarFromAccessPtr(args[1])]; | ||
| node->B = vmap[GetVarFromAccessPtr(args[2])]; | ||
| node->C = vmap[GetVarFromAccessPtr(args[3])]; | ||
| node->trans_A = args[4].as<Bool>().value(); | ||
| node->trans_B = args[5].as<Bool>().value(); | ||
| node->M = args[6].as<IntImm>().value()->value; | ||
| node->N = args[7].as<IntImm>().value()->value; | ||
| node->K = args[8].as<IntImm>().value()->value; | ||
| node->policy = static_cast<GemmSPNode::GemmWarpPolicy>( | ||
| args[9].as<IntImm>().value()->value); | ||
| node->clear_accum = args[10].as<Bool>().value(); | ||
| if (args.size() > 11) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Validate vmap keys before indexing
Using vmap[GetVarFromAccessPtr(...)] without checking presence can produce undefined Buffers.
- node->A = vmap[GetVarFromAccessPtr(args[0])];
- node->E = vmap[GetVarFromAccessPtr(args[1])];
- node->B = vmap[GetVarFromAccessPtr(args[2])];
- node->C = vmap[GetVarFromAccessPtr(args[3])];
+ auto vA = GetVarFromAccessPtr(args[0]); ICHECK(vmap.count(vA)); node->A = vmap[vA];
+ auto vE = GetVarFromAccessPtr(args[1]); ICHECK(vmap.count(vE)); node->E = vmap[vE];
+ auto vB = GetVarFromAccessPtr(args[2]); ICHECK(vmap.count(vB)); node->B = vmap[vB];
+ auto vC = GetVarFromAccessPtr(args[3]); ICHECK(vmap.count(vC)); node->C = vmap[vC];📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| ObjectPtr<GemmSPNode> node = make_object<GemmSPNode>(); | |
| node->A = vmap[GetVarFromAccessPtr(args[0])]; | |
| node->E = vmap[GetVarFromAccessPtr(args[1])]; | |
| node->B = vmap[GetVarFromAccessPtr(args[2])]; | |
| node->C = vmap[GetVarFromAccessPtr(args[3])]; | |
| node->trans_A = args[4].as<Bool>().value(); | |
| node->trans_B = args[5].as<Bool>().value(); | |
| node->M = args[6].as<IntImm>().value()->value; | |
| node->N = args[7].as<IntImm>().value()->value; | |
| node->K = args[8].as<IntImm>().value()->value; | |
| node->policy = static_cast<GemmSPNode::GemmWarpPolicy>( | |
| args[9].as<IntImm>().value()->value); | |
| node->clear_accum = args[10].as<Bool>().value(); | |
| if (args.size() > 11) { | |
| ObjectPtr<GemmSPNode> node = make_object<GemmSPNode>(); | |
| - node->A = vmap[GetVarFromAccessPtr(args[0])]; | |
| - node->E = vmap[GetVarFromAccessPtr(args[1])]; | |
| - node->B = vmap[GetVarFromAccessPtr(args[2])]; | |
| auto vA = GetVarFromAccessPtr(args[0]); ICHECK(vmap.count(vA)); node->A = vmap[vA]; | |
| auto vE = GetVarFromAccessPtr(args[1]); ICHECK(vmap.count(vE)); node->E = vmap[vE]; | |
| auto vB = GetVarFromAccessPtr(args[2]); ICHECK(vmap.count(vB)); node->B = vmap[vB]; | |
| auto vC = GetVarFromAccessPtr(args[3]); ICHECK(vmap.count(vC)); node->C = vmap[vC]; | |
| node->trans_A = args[4].as<Bool>().value(); | |
| node->trans_B = args[5].as<Bool>().value(); | |
| node->M = args[6].as<IntImm>().value()->value; | |
| node->N = args[7].as<IntImm>().value()->value; | |
| node->K = args[8].as<IntImm>().value()->value; | |
| node->policy = static_cast<GemmSPNode::GemmWarpPolicy>( | |
| args[9].as<IntImm>().value()->value); | |
| node->clear_accum = args[10].as<Bool>().value(); | |
| if (args.size() > 11) { |
| class GemmSPNode : public TileOperatorNode { | ||
| public: | ||
| GemmSP(Array<PrimExpr> args, BufferMap vmap); | ||
| Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final; | ||
| LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final; | ||
| static const Op &Get(); | ||
| Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const; | ||
| LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const; | ||
| enum class GemmWarpPolicy { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add TVM object type info + mark overrides.
GemmSPNode is missing _type_key and TVM_DECLARE_FINAL_OBJECT_INFO, and the vfuncs aren’t marked override. Add them to enable reflection and catch signature drift.
Apply this diff:
class GemmSPNode : public TileOperatorNode {
public:
- Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const;
- LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const;
+ static constexpr const char* _type_key = "tl.GemmSP";
+ TVM_DECLARE_FINAL_OBJECT_INFO(GemmSPNode, TileOperatorNode);
+
+ Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
+ LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const override;
@@
- TileOperator Clone() const;
+ TileOperator Clone() const override;
private:
mutable bool completed_ = false;
};Also applies to: 41-45
🤖 Prompt for AI Agents
In src/op/gemm_sp.h around lines 17-21 (and also apply same changes to lines
41-45), GemmSPNode is missing TVM object type info and its virtual functions
aren't marked override; add a static constexpr const char* _type_key with an
appropriate unique value (e.g. "op.GemmSPNode"), add
TVM_DECLARE_FINAL_OBJECT_INFO(GemmSPNode, TileOperatorNode) in the class body,
and mark the Lower and InferLayout method declarations with the C++ override
specifier; apply the identical pattern to the other node class at lines 41-45 so
both classes support TVM reflection and catch signature drift.
| Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) { | ||
| Aptr = args[0]; | ||
| Bptr = args[1]; | ||
| Cptr = args[2]; | ||
| A = vmap[GetVarFromAccessPtr(Aptr)]; | ||
| B = vmap[GetVarFromAccessPtr(Bptr)]; | ||
| C = vmap[GetVarFromAccessPtr(Cptr)]; | ||
| trans_A = args[3].as<Bool>().value(); | ||
| trans_B = args[4].as<Bool>().value(); | ||
| M = args[5].as<IntImm>().value()->value; | ||
| N = args[6].as<IntImm>().value()->value; | ||
| K = args[7].as<IntImm>().value()->value; | ||
| policy = static_cast<GemmWarpPolicy>(args[8].as<IntImm>().value()->value); | ||
| clear_accum = args[9].as<Bool>().value(); | ||
| stride_A = args[10].as<IntImm>().value()->value; | ||
| stride_B = args[11].as<IntImm>().value()->value; | ||
| offset_A = args[12].as<IntImm>().value()->value; | ||
| offset_B = args[13].as<IntImm>().value()->value; | ||
| ObjectPtr<GemmNode> node = make_object<GemmNode>(); | ||
|
|
||
| node->Aptr = args[0]; | ||
| node->Bptr = args[1]; | ||
| node->Cptr = args[2]; | ||
| node->A = vmap[GetVarFromAccessPtr(node->Aptr)]; | ||
| node->B = vmap[GetVarFromAccessPtr(node->Bptr)]; | ||
| node->C = vmap[GetVarFromAccessPtr(node->Cptr)]; | ||
| node->trans_A = args[3].as<Bool>().value(); | ||
| node->trans_B = args[4].as<Bool>().value(); | ||
| node->M = args[5].as<IntImm>().value()->value; | ||
| node->N = args[6].as<IntImm>().value()->value; | ||
| node->K = args[7].as<IntImm>().value()->value; | ||
| node->policy = | ||
| static_cast<GemmWarpPolicy>(args[8].as<IntImm>().value()->value); | ||
| node->clear_accum = args[9].as<Bool>().value(); | ||
| node->stride_A = args[10].as<IntImm>().value()->value; | ||
| node->stride_B = args[11].as<IntImm>().value()->value; | ||
| node->offset_A = args[12].as<IntImm>().value()->value; | ||
| node->offset_B = args[13].as<IntImm>().value()->value; | ||
| if (args.size() > 14) { | ||
| kPack = args[14].as<IntImm>().value()->value; | ||
| if (kPack != 1 && kPack != 2) { | ||
| node->kPack = args[14].as<IntImm>().value()->value; | ||
| if (node->kPack != 1 && node->kPack != 2) { | ||
| ICHECK(false) << "kPack must be 1 or 2"; | ||
| } | ||
| } | ||
| if (args.size() > 15) { | ||
| wg_wait = args[15].as<IntImm>().value()->value; | ||
| node->wg_wait = args[15].as<IntImm>().value()->value; | ||
| } | ||
| data_ = std::move(node); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Validate vmap keys before indexing
Guard vmap lookups to avoid undefined Buffers on malformed calls.
- node->A = vmap[GetVarFromAccessPtr(node->Aptr)];
- node->B = vmap[GetVarFromAccessPtr(node->Bptr)];
- node->C = vmap[GetVarFromAccessPtr(node->Cptr)];
+ auto vA = GetVarFromAccessPtr(node->Aptr); ICHECK(vmap.count(vA)); node->A = vmap[vA];
+ auto vB = GetVarFromAccessPtr(node->Bptr); ICHECK(vmap.count(vB)); node->B = vmap[vB];
+ auto vC = GetVarFromAccessPtr(node->Cptr); ICHECK(vmap.count(vC)); node->C = vmap[vC];📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) { | |
| Aptr = args[0]; | |
| Bptr = args[1]; | |
| Cptr = args[2]; | |
| A = vmap[GetVarFromAccessPtr(Aptr)]; | |
| B = vmap[GetVarFromAccessPtr(Bptr)]; | |
| C = vmap[GetVarFromAccessPtr(Cptr)]; | |
| trans_A = args[3].as<Bool>().value(); | |
| trans_B = args[4].as<Bool>().value(); | |
| M = args[5].as<IntImm>().value()->value; | |
| N = args[6].as<IntImm>().value()->value; | |
| K = args[7].as<IntImm>().value()->value; | |
| policy = static_cast<GemmWarpPolicy>(args[8].as<IntImm>().value()->value); | |
| clear_accum = args[9].as<Bool>().value(); | |
| stride_A = args[10].as<IntImm>().value()->value; | |
| stride_B = args[11].as<IntImm>().value()->value; | |
| offset_A = args[12].as<IntImm>().value()->value; | |
| offset_B = args[13].as<IntImm>().value()->value; | |
| ObjectPtr<GemmNode> node = make_object<GemmNode>(); | |
| node->Aptr = args[0]; | |
| node->Bptr = args[1]; | |
| node->Cptr = args[2]; | |
| node->A = vmap[GetVarFromAccessPtr(node->Aptr)]; | |
| node->B = vmap[GetVarFromAccessPtr(node->Bptr)]; | |
| node->C = vmap[GetVarFromAccessPtr(node->Cptr)]; | |
| node->trans_A = args[3].as<Bool>().value(); | |
| node->trans_B = args[4].as<Bool>().value(); | |
| node->M = args[5].as<IntImm>().value()->value; | |
| node->N = args[6].as<IntImm>().value()->value; | |
| node->K = args[7].as<IntImm>().value()->value; | |
| node->policy = | |
| static_cast<GemmWarpPolicy>(args[8].as<IntImm>().value()->value); | |
| node->clear_accum = args[9].as<Bool>().value(); | |
| node->stride_A = args[10].as<IntImm>().value()->value; | |
| node->stride_B = args[11].as<IntImm>().value()->value; | |
| node->offset_A = args[12].as<IntImm>().value()->value; | |
| node->offset_B = args[13].as<IntImm>().value()->value; | |
| if (args.size() > 14) { | |
| kPack = args[14].as<IntImm>().value()->value; | |
| if (kPack != 1 && kPack != 2) { | |
| node->kPack = args[14].as<IntImm>().value()->value; | |
| if (node->kPack != 1 && node->kPack != 2) { | |
| ICHECK(false) << "kPack must be 1 or 2"; | |
| } | |
| } | |
| if (args.size() > 15) { | |
| wg_wait = args[15].as<IntImm>().value()->value; | |
| node->wg_wait = args[15].as<IntImm>().value()->value; | |
| } | |
| data_ = std::move(node); | |
| } | |
| Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) { | |
| ObjectPtr<GemmNode> node = make_object<GemmNode>(); | |
| node->Aptr = args[0]; | |
| node->Bptr = args[1]; | |
| node->Cptr = args[2]; | |
| auto vA = GetVarFromAccessPtr(node->Aptr); ICHECK(vmap.count(vA)); node->A = vmap[vA]; | |
| auto vB = GetVarFromAccessPtr(node->Bptr); ICHECK(vmap.count(vB)); node->B = vmap[vB]; | |
| auto vC = GetVarFromAccessPtr(node->Cptr); ICHECK(vmap.count(vC)); node->C = vmap[vC]; | |
| node->trans_A = args[3].as<Bool>().value(); | |
| node->trans_B = args[4].as<Bool>().value(); | |
| node->M = args[5].as<IntImm>().value()->value; | |
| node->N = args[6].as<IntImm>().value()->value; | |
| node->K = args[7].as<IntImm>().value()->value; | |
| node->policy = | |
| static_cast<GemmWarpPolicy>(args[8].as<IntImm>().value()->value); | |
| node->clear_accum = args[9].as<Bool>().value(); | |
| node->stride_A = args[10].as<IntImm>().value()->value; | |
| node->stride_B = args[11].as<IntImm>().value()->value; | |
| node->offset_A = args[12].as<IntImm>().value()->value; | |
| node->offset_B = args[13].as<IntImm>().value()->value; | |
| if (args.size() > 14) { | |
| node->kPack = args[14].as<IntImm>().value()->value; | |
| if (node->kPack != 1 && node->kPack != 2) { | |
| ICHECK(false) << "kPack must be 1 or 2"; | |
| } | |
| } | |
| if (args.size() > 15) { | |
| node->wg_wait = args[15].as<IntImm>().value()->value; | |
| } | |
| data_ = std::move(node); | |
| } |
🤖 Prompt for AI Agents
In src/op/gemm.cc around lines 36 to 67, the code indexes vmap with
GetVarFromAccessPtr(...) without checking the key exists which can yield
undefined Buffers on malformed calls; before assigning node->A/node->B/node->C,
call GetVarFromAccessPtr for each pointer, check vmap.find(var) != vmap.end()
(or equivalent) and assert/ICHECK with a clear message if missing, then use the
found iterator to assign the Buffer; apply the same guarded lookup pattern for
any future vmap accesses to avoid undefined behavior.
| private: | ||
| PrimExpr MakeInitValue() const; | ||
| PrimExpr MakeReduce(const PrimExpr &a, const PrimExpr &b) const; | ||
| std::string MakeCodegenReducer() const; | ||
| }; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Header uses std::string but doesn’t include
MakeCodegenReducer() returns std::string; include to avoid fragile transitive includes.
#include "operator.h"
+#include <string>📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| private: | |
| PrimExpr MakeInitValue() const; | |
| PrimExpr MakeReduce(const PrimExpr &a, const PrimExpr &b) const; | |
| std::string MakeCodegenReducer() const; | |
| }; | |
| #include "operator.h" | |
| #include <string> |
🤖 Prompt for AI Agents
In src/op/reduce.h around lines 41 to 45, the header declares a function
returning std::string but does not include the <string> header; add #include
<string> to this header (near the other includes) so std::string is defined and
you don't rely on fragile transitive includes.
…el operation cloning - Updated the Clone methods in AtomicAddNode and CopyNode to ensure that the parallel operation (par_op_) is properly cloned when defined, improving the integrity of cloned objects. - Refactored the FillNode class to use ParallelOp directly instead of std::make_unique, streamlining the creation of parallel operations. - Made minor adjustments in layout inference and other related methods for consistency and clarity.
- Eliminated the call to the global function "tl.fill.lower" in the FillNode::Lower method, streamlining the code and improving clarity. - Retained the core functionality of the method while enhancing maintainability by reducing unnecessary dependencies.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (6)
src/op/atomic_add.cc (3)
131-137: Scalar path is not atomic and not additive.Current scalar branch does a plain store (dst = src[0]). It must perform atomic add to dst.
For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { Array<IterVar> loop_vars = MakeIterVars(); bool is_scalar = loop_vars.size() == 0; if (is_scalar) { - return For(Var("i"), 0, 1, ForKind::kSerial, - BufferStore(dst, BufferLoad(src, {0}), {0})); + Array<PrimExpr> args; + args.push_back(StringImm("AtomicAdd")); + auto dst_ref = BufferLoad(dst, {0}); + args.push_back(Call(DataType::Handle(), builtin::address_of(), {dst_ref})); + PrimExpr val = BufferLoad(src, {0}); + if (src->dtype != dst->dtype) val = Cast(dst->dtype, val); + args.push_back(val); + return For(Var("i"), 0, 1, ForKind::kSerial, + Evaluate(Call(dst->dtype, builtin::call_extern(), args))); }
155-181: Don’t take address_of an if-then-else; guard the atomic call instead.
address_of()must receive a BufferLoad. Wrapping it withif_then_elseis invalid and can miscompile. Guard the call with the destination predicate (and optionally source).PrimExpr src_value = BufferLoad(src, src_indices); if (src->dtype != dst->dtype) src_value = Cast(dst->dtype, src_value); - if (src_predicate.defined()) - src_value = if_then_else(src_predicate, src_value, make_zero(dst->dtype)); - - PrimExpr dst_value = BufferLoad(dst, dst_indices); - if (dst_predicate.defined()) - dst_value = if_then_else(dst_predicate, dst_value, make_zero(dst->dtype)); - - Call address_of_value = - tvm::tir::Call(DataType::Handle(), builtin::address_of(), {dst_value}); - - new_args.push_back(address_of_value); + // Optional: keep src_value zeroing, but do not wrap the address. + if (src_predicate.defined()) + src_value = if_then_else(src_predicate, src_value, make_zero(dst->dtype)); + + PrimExpr dst_ref = BufferLoad(dst, dst_indices); + Call dst_ptr = + Call(DataType::Handle(), builtin::address_of(), {dst_ref}); + + new_args.push_back(dst_ptr); new_args.push_back(src_value); Call atomicadd_call = tvm::tir::Call(dst->dtype, builtin::call_extern(), new_args); - Stmt body = tvm::tir::Evaluate(atomicadd_call); + Stmt body = tvm::tir::Evaluate(atomicadd_call); + if (dst_predicate.defined()) + body = IfThenElse(dst_predicate, body);
194-215: Vectorizing generic atomic loops is unsafe. Use atomic-aware vectorization or skip.
VectorizeLoopmay reorder/merge memory ops and break atomicity. Prefer the specializedVectorizeAtomicAdd(...), or avoid vectorization here.- // TODO(@dyq): buggy implementation, need to fix - // vectorized_thread_loop = VectorizeAtomicAdd( - // thread_loop, thread_var, thread_bounds, GetArchInt(target)); - auto vectorized_thread_loop = VectorizeLoop(thread_loop); + // Preserve atomic semantics + auto vectorized_thread_loop = + VectorizeAtomicAdd(thread_loop, thread_var, thread_bounds, GetArchInt(target));If
VectorizeAtomicAddis not ready, keepthread_loopwithout vectorization.src/transform/layout_inference.cc (2)
331-335: Guard RegionOp path in getBufferFromAccessPtr.Assume/verify arg[0] is a BufferLoad to avoid nullptr deref.
- } else if (call->op.same_as(RegionOp::Get())) { - return call->args[0].as<BufferLoadNode>()->buffer; + } else if (call->op.same_as(RegionOp::Get())) { + auto bl = call->args[0].as<BufferLoadNode>(); + ICHECK(bl) << "RegionOp expects BufferLoad as arg0"; + return bl->buffer; }
561-566: Typo in memory scope breaks validation."local.framgent" → "local.fragment".
- if (buffer.scope() == "local.framgent") { + if (buffer.scope() == "local.fragment") {src/op/copy.cc (1)
38-98: Fix int16 mapping in to_CUtensorMapDataType.INT16 currently maps to UINT16.
} else if (dtype.is_int()) { switch (dtype.bits()) { case 64: tp = CU_TENSOR_MAP_DATA_TYPE_INT64; break; case 32: tp = CU_TENSOR_MAP_DATA_TYPE_INT32; break; case 16: - tp = CU_TENSOR_MAP_DATA_TYPE_UINT16; + tp = CU_TENSOR_MAP_DATA_TYPE_INT16; break; case 8: tp = CU_TENSOR_MAP_DATA_TYPE_UINT8; break;
♻️ Duplicate comments (3)
src/op/atomic_add.cc (1)
37-55: Validate args size and harden RegionOp extraction.
- Missing args.size() guard before indexing.
- RegionOp/buffer assumptions are unchecked; can crash on malformed calls.
AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) { + ICHECK_GE(args.size(), 2) << "AtomicAdd expects at least 2 args"; ObjectPtr<AtomicAddNode> node = make_object<AtomicAddNode>(); Array<Range> rgs[2]; Buffer bf[2]; for (int i = 0; i < 2; i++) { auto expr = args[i]; auto call = expr.as<CallNode>(); - ICHECK(call); - auto region = RegionOp(call->args, vmap); - rgs[i] = region->GetRanges(); - bf[i] = region->GetBuffer(); + ICHECK(call) << "AtomicAdd arg" << i << " is not a Call"; + auto region = RegionOp(call->args, vmap); + ICHECK(region.defined()) << "Failed to parse RegionOp for arg " << i; + rgs[i] = region->GetRanges(); + bf[i] = region->GetBuffer(); } std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]); std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]); if (args.size() >= 3) { node->coalesced_width = Downcast<IntImm>(args[2]); } data_ = std::move(node); }src/op/copy.cc (2)
115-142: Validate args and harden RegionOp extraction. Also fix flag parsing.
- Add args.size() check before indexing.
- ICHECK RegionOp is defined.
- Parse disable_tma from IntImm->value (Bool downcast is brittle).
- Prefer safe IntImm downcast for eviction_policy.
Copy::Copy(Array<PrimExpr> args, BufferMap vmap) { + ICHECK_GE(args.size(), 2) << "Copy expects at least 2 args"; ObjectPtr<CopyNode> node = make_object<CopyNode>(); Array<Range> rgs[2]; Buffer bf[2]; for (int i = 0; i < 2; i++) { auto expr = args[i]; auto call = expr.as<CallNode>(); ICHECK(call); - auto region = RegionOp(call->args, vmap); - rgs[i] = region->GetRanges(); - bf[i] = region->GetBuffer(); + auto region = RegionOp(call->args, vmap); + ICHECK(region.defined()) << "Failed to parse RegionOp for arg " << i; + rgs[i] = region->GetRanges(); + bf[i] = region->GetBuffer(); } std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]); std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]); if (args.size() >= 3) { auto coalesced_width = Downcast<IntImm>(args[2]); if (coalesced_width->value > 0) { node->coalesced_width = coalesced_width; } } if (args.size() >= 4) { - node->disable_tma = Downcast<Bool>(args[3]); + node->disable_tma = static_cast<bool>(Downcast<IntImm>(args[3])->value); } if (args.size() >= 5) { - node->eviction_policy = args[4].as<IntImmNode>()->value; + node->eviction_policy = Downcast<IntImm>(args[4])->value; } data_ = std::move(node); }
1051-1064: Fix IntImm extraction in Conv2DIm2ColOp ctor.
.as<IntImm>().value()->valueis invalid; use Downcast.node->nhw_step = args[2]; node->c_step = args[3]; - node->kernel = args[4].as<IntImm>().value()->value; - node->stride = args[5].as<IntImm>().value()->value; - node->dilation = args[6].as<IntImm>().value()->value; - node->padding = args[7].as<IntImm>().value()->value; - node->eviction_policy = args[8].as<IntImm>().value()->value; + node->kernel = Downcast<IntImm>(args[4])->value; + node->stride = Downcast<IntImm>(args[5])->value; + node->dilation = Downcast<IntImm>(args[6])->value; + node->padding = Downcast<IntImm>(args[7])->value; + node->eviction_policy = Downcast<IntImm>(args[8])->value;
🧹 Nitpick comments (4)
src/transform/layout_inference.cc (2)
83-85: Clearer error text.Minor nit: “null inside run_infer_step” → “undefined inside RunInferStep”.
- ICHECK(next.defined()) << "infer_list_[" << cur_infer_id - << "] is null inside run_infer_step."; + ICHECK(next.defined()) << "infer_list_[" << cur_infer_id + << "] is undefined inside RunInferStep.";
258-265: Avoid std::move here (readability).You don’t reuse infer_list_[i] afterward in this loop; moving obscures state. Keep a const handle.
- TileOperator base_infer = std::move(infer_list_[i]); + const TileOperator& base_infer = infer_list_[i];src/op/copy.cc (2)
372-391: Doc nit and condition fix suggestion.Comment for Bulk Store step 2 says “shared.dyn and local.fragment” but code checks src(shared)→dst(global). Update the comment to avoid confusion.
- * source is shared.dyn, destination is global memory, and both buffers have the - * same data type. + * source is shared.dyn (or shared), destination is global memory, and both + * buffers have the same data type.
515-548: GPU path OK; CPU vectorization unconditional.Consider guarding CPU vectorization behind a heuristic if it hurts small loops.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (5)
src/op/atomic_add.cc(7 hunks)src/op/copy.cc(22 hunks)src/op/elem.cc(5 hunks)src/op/parallel.h(2 hunks)src/transform/layout_inference.cc(11 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- src/op/parallel.h
- src/op/elem.cc
🧰 Additional context used
🧬 Code graph analysis (3)
src/op/atomic_add.cc (4)
src/op/atomic_add.h (1)
AtomicAdd(48-53)src/op/region.cc (1)
RegionOp(14-32)src/op/copy.cc (16)
Clone(144-150)Clone(144-144)Clone(1065-1068)Clone(1065-1065)MakeIterVars(158-170)MakeIterVars(158-158)MakeIndices(181-198)MakeIndices(181-182)MakePredicate(209-237)MakePredicate(209-211)MakeSIMTLoop(247-292)MakeSIMTLoop(247-247)InferLayout(330-363)InferLayout(330-331)InferLayout(1248-1251)InferLayout(1248-1249)src/op/parallel.h (2)
ParallelOp(117-124)InferLayout(70-77)
src/op/copy.cc (4)
src/op/region.cc (5)
RegionOp(14-32)Clone(34-37)InferLayout(53-56)Lower(49-51)Lower(49-49)src/op/atomic_add.cc (6)
Clone(57-63)Clone(57-57)InferLayout(224-244)InferLayout(224-225)Lower(194-222)Lower(194-194)src/op/copy.h (1)
Clone(211-274)src/op/operator.cc (2)
GetVarFromAccessPtr(37-44)GetVarFromAccessPtr(37-37)
src/transform/layout_inference.cc (3)
src/op/parallel.h (1)
ParallelOp(117-124)src/op/parallel.cc (4)
op(96-106)op(96-96)op(108-116)op(108-108)src/transform/pipeline_planning.cc (10)
op(54-72)op(54-54)op(74-92)op(74-74)op(94-121)op(94-94)op(123-133)op(123-123)op(545-554)op(545-545)
🔇 Additional comments (10)
src/transform/layout_inference.cc (5)
18-18: Include RegionOp is correct.Region-based access will be needed for operator parsing and buffer discovery.
209-224: Good defensive checks on queue seeding.The added .defined() checks prevent UB later in inference.
349-356: OK to construct ParallelOp directly.Matches the TileOperator migration and enables cloning later.
417-424: Clone semantics look right.Backing up infer_list_ by Clone() preserves internal par_op_ state per node.
599-626: Predicate wrapping logic LGTM.Conditional wrapping of partitioned/vectorized loop is correct.
src/op/copy.cc (5)
330-363: Lazy par_op_ initialization LGTM.Defers building ParallelOp until needed; matches other operators.
455-470: GetCopyInst gating is correct.Order of preference TMA→LDSM/STSM→normal is sensible.
561-741: LDSM/STSM lowering preconditions are solid.Strict checks with clean fallbacks to normal copy. Good.
754-1011: TMA bulk copy stride checks look good; minor registration mismatch.You read optional eviction_policy from args[4], but the op is registered with 4 inputs. Either bump to 5 or drop the arg.
TIR_REGISTER_TL_OP(Copy, copy) - .set_num_inputs(4) + .set_num_inputs(5) // src, dst, coalesced_width, disable_tma, eviction_policyIf you intend to keep 4, remove parsing/use of eviction_policy here.
1248-1251: No-op InferLayout for Conv2DIm2Col is acceptable.This op doesn’t contribute layouts; returning {} is fine.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
src/op/operator.h (1)
22-22: Avoidusing namespacein a header.This pollutes
tvm::tlwith all of TIR and leaks into all includers. Prefer selectiveusingor fully qualified names.Apply:
-using namespace tir; +using tir::Buffer; +using tir::Call; +using tir::PrimExpr; +using tir::Range; +using tir::Stmt; +using tir::Var;src/op/elem.cc (2)
31-38: Null-deref risk and inconsistent checks onRampNodefields.
ramp->stride.as<IntImmNode>()may be null; also preferICHECKconsistently.- if (const auto *ramp = index.as<RampNode>()) { - CHECK(ramp->stride.as<IntImmNode>()->value == 1) - << "Only stride 1 ramps are supported"; - const auto *lanes = ramp->lanes.as<IntImmNode>(); - CHECK(lanes) - << "Scalable vectors not supported in BufferRegion conversion"; + if (const auto* ramp = index.as<RampNode>()) { + const auto* stride = ramp->stride.as<IntImmNode>(); + ICHECK(stride && stride->value == 1) + << "Only stride 1 ramps are supported"; + // Disallow scalable lanes if not a constant int. + ICHECK(ramp->lanes.as<IntImmNode>()) + << "Scalable vectors not supported in BufferRegion conversion"; node->region.push_back(Range::FromMinExtent(ramp->base, ramp->lanes)); } else {
96-111: DuplicateInferLayoutcall; remove redundancy.
par_op->InferLayout(...)is called twice back-to-back.- par_op->InferLayout({T.target, T.thread_bounds, T.layout_map}, - InferLevel::kFree); - par_op->InferLayout({T.target, T.thread_bounds, T.layout_map}, - InferLevel::kFree); + par_op->InferLayout({T.target, T.thread_bounds, T.layout_map}, + InferLevel::kFree);
🧹 Nitpick comments (4)
src/op/operator.h (2)
51-66: Add a virtual destructor to the base node.Explicit virtual dtor is standard for polymorphic bases and avoids analyzer warnings.
class TileOperatorNode: public Object { public: virtual Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const = 0; virtual LayoutMap InferLayout(const LayoutInferArgs& T, InferLevel level) const = 0; virtual TileOperator Clone() const = 0; static constexpr const char* _type_key = "tl.TileOperator"; TVM_DECLARE_BASE_OBJECT_INFO(TileOperatorNode, Object); + virtual ~TileOperatorNode() = default; };
74-78: Pass by const-reference to avoid copies.These are lightweight but still better as const-refs; also aligns with TVM style.
-Var GetVarFromAccessPtr(const PrimExpr &expr); - -TileOperator ParseOperator(Call call, BufferMap vmap); -TileOperator ParseOperator(Stmt stmt, BufferMap vmap); +Var GetVarFromAccessPtr(const PrimExpr& expr); + +TileOperator ParseOperator(const Call& call, const BufferMap& vmap); +TileOperator ParseOperator(const Stmt& stmt, const BufferMap& vmap);src/op/elem.cc (2)
128-129: PreferICHECK(false)overLOG(FATAL)for invariant.Minor consistency nit with surrounding checks.
- LOG(FATAL) << "Unsupported scope " << dst.scope(); + ICHECK(false) << "Unsupported scope " << dst.scope();
79-94: Use explicit Int32 for loop vars.
In FillNode::MakeSIMTLoop (src/op/elem.cc:83), replace the derived dtype with a fixed 32-bit integer to match other loops (e.g., fused loops in loop_fusion_utils.h):- Var var = Var(std::string{char('i' + i)}, region[i]->extent->dtype); + Var var = Var(std::string{char('i' + i)}, DataType::Int(32));
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (4)
src/op/copy.cc(22 hunks)src/op/elem.cc(5 hunks)src/op/operator.h(2 hunks)src/transform/lower_tile_op.cc(2 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- src/transform/lower_tile_op.cc
- src/op/copy.cc
🧰 Additional context used
🧬 Code graph analysis (2)
src/op/elem.cc (3)
src/op/operator.cc (2)
GetVarFromAccessPtr(37-44)GetVarFromAccessPtr(37-37)src/op/elem.h (1)
Fill(36-41)src/op/parallel.h (1)
ParallelOp(117-124)
src/op/operator.h (2)
src/op/copy.h (1)
Clone(211-274)src/op/operator.cc (6)
GetVarFromAccessPtr(37-44)GetVarFromAccessPtr(37-37)ParseOperator(18-27)ParseOperator(18-18)ParseOperator(29-35)ParseOperator(29-29)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: build-test-nvidia
- GitHub Check: build-test-amd
🔇 Additional comments (9)
src/op/operator.h (4)
14-15: Good addition of required TIR includes.These are necessary for Stmt and op attribute types.
68-71: LGTM on ObjectRef surface.Wrapper follows TVM Object/Ref pattern.
79-79: Confirmffi::TypedFunctiondefinition
The alias at src/op/operator.h:79 referencesffi::TypedFunction, but nonamespace ffiwas found in the codebase. Ensure you include the header that definesffi::TypedFunction(e.g. the appropriatetvm/ffi.h), or replace it withruntime::TypedPackedFuncfor consistency with existing op builders.
81-92: Verify operator constructors for TLOpBuilder
TheTIR_REGISTER_TL_OPmacro’s builder lambda doesEntry(args, vmap). Confirm every operator subclass (e.g., ReduceOp, CumSumOp, Copy, Conv2DIm2ColOp, AtomicAdd, etc.) defines a constructorClassName(Array<PrimExpr>, BufferMap)to match that call.src/op/elem.cc (5)
50-54: Type-safe cast to dst dtype.LGTM; avoids dtype mismatch on value.
71-77: Clone implementation is fine.Copy-construct node and wrap; OK.
116-127: LGTM on shared path lowering.Partition + vectorize with predicate guard is consistent with fragment path.
132-135: InferLayout stub is acceptable for now.Return empty map; OK until layout inference is implemented.
137-141: Registry chaining looks correct.
TIR_REGISTER_TL_OP(Fill, fill)chaining with.set_num_inputsand opaque effect is fine.
| Fill::Fill(Array<PrimExpr> args, BufferMap vmap) { | ||
| ObjectPtr<FillNode> node = make_object<FillNode>(); | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Validate argument count early.
Guard against out-of-bounds access on args[1].
Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<FillNode> node = make_object<FillNode>();
+ ICHECK_EQ(args.size(), 2) << "Fill expects 2 args: dst, value";📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| Fill::Fill(Array<PrimExpr> args, BufferMap vmap) { | |
| ObjectPtr<FillNode> node = make_object<FillNode>(); | |
| Fill::Fill(Array<PrimExpr> args, BufferMap vmap) { | |
| ObjectPtr<FillNode> node = make_object<FillNode>(); | |
| ICHECK_EQ(args.size(), 2) << "Fill expects 2 args: dst, value"; |
🤖 Prompt for AI Agents
In src/op/elem.cc around lines 25-27, the constructor uses args[1] without
validating args length; add an early guard that checks args.size() >= 2 and
handle the error path (e.g., DCHECK/CHECK_GE or throw a descriptive
exception/log fatal and return) before any access to args[1] to prevent
out-of-bounds access.
| node->dst = buffer_load->buffer; | ||
| } else { | ||
| dst = vmap[GetVarFromAccessPtr(args[0])]; | ||
| for (int i = 0; i < dst->shape.size(); i++) { | ||
| region.push_back(Range(0, dst->shape[i])); | ||
| node->dst = vmap[GetVarFromAccessPtr(args[0])]; | ||
| for (int i = 0; i < node->dst->shape.size(); i++) { | ||
| node->region.push_back(Range(0, node->dst->shape[i])); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing vmap lookup check; potential null Buffer deref.
Accessing node->dst->shape without confirming dst.defined() can crash.
- node->dst = vmap[GetVarFromAccessPtr(args[0])];
- for (int i = 0; i < node->dst->shape.size(); i++) {
+ Var buf_var = GetVarFromAccessPtr(args[0]);
+ ICHECK(vmap.count(buf_var)) << "Missing buffer mapping for " << buf_var;
+ node->dst = vmap[buf_var];
+ ICHECK(node->dst.defined());
+ for (int i = 0; i < static_cast<int>(node->dst->shape.size()); i++) {
node->region.push_back(Range(0, node->dst->shape[i]));
}📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| node->dst = buffer_load->buffer; | |
| } else { | |
| dst = vmap[GetVarFromAccessPtr(args[0])]; | |
| for (int i = 0; i < dst->shape.size(); i++) { | |
| region.push_back(Range(0, dst->shape[i])); | |
| node->dst = vmap[GetVarFromAccessPtr(args[0])]; | |
| for (int i = 0; i < node->dst->shape.size(); i++) { | |
| node->region.push_back(Range(0, node->dst->shape[i])); | |
| } | |
| node->dst = buffer_load->buffer; | |
| } else { | |
| Var buf_var = GetVarFromAccessPtr(args[0]); | |
| ICHECK(vmap.count(buf_var)) << "Missing buffer mapping for " << buf_var; | |
| node->dst = vmap[buf_var]; | |
| ICHECK(node->dst.defined()); | |
| for (int i = 0; i < static_cast<int>(node->dst->shape.size()); i++) { | |
| node->region.push_back(Range(0, node->dst->shape[i])); | |
| } |
🤖 Prompt for AI Agents
In src/op/elem.cc around lines 42 to 47, the code assumes
vmap[GetVarFromAccessPtr(args[0])] returns a defined buffer and directly
accesses node->dst->shape; add a safe lookup: fetch the buffer pointer into a
temporary, check that it is non-null / defined before assigning to node->dst or
dereferencing its shape, and handle the missing entry by logging/throwing an
explanatory error or returning early; only when the buffer is valid, set
node->dst and populate node->region using its shape.
| ICHECK(node->region.size() == node->dst->shape.size()) | ||
| << "region size = " << node->region.size() | ||
| << " != " << node->dst->shape.size(); | ||
| for (int i = 0; i < node->region.size(); i++) { | ||
| // bound check if region is static | ||
| if (region[i]->min.as<IntImm>()) { | ||
| int64_t min = Downcast<IntImm>(region[i]->min)->value; | ||
| if (node->region[i]->min.as<IntImm>()) { | ||
| int64_t min = Downcast<IntImm>(node->region[i]->min)->value; | ||
| ICHECK_GE(min, 0) << "region[" << i << "] = " << min << " < 0"; | ||
| } | ||
| if (region[i]->extent.as<IntImm>()) { | ||
| int64_t extent = Downcast<IntImm>(region[i]->extent)->value; | ||
| ICHECK_LE(extent, Downcast<IntImm>(dst->shape[i])->value) | ||
| << "region[" << i << "] = " << extent << " > " << dst->shape[i]; | ||
| if (node->region[i]->extent.as<IntImm>()) { | ||
| int64_t extent = Downcast<IntImm>(node->region[i]->extent)->value; | ||
| ICHECK_LE(extent, Downcast<IntImm>(node->dst->shape[i])->value) | ||
| << "region[" << i << "] = " << extent << " > " << node->dst->shape[i]; | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix IntImm node checks and avoid assuming static dst shape.
Use IntImmNode in as<> and guard dst->shape[i] before Downcast.
- for (int i = 0; i < node->region.size(); i++) {
+ for (int i = 0; i < static_cast<int>(node->region.size()); i++) {
// bound check if region is static
- if (node->region[i]->min.as<IntImm>()) {
- int64_t min = Downcast<IntImm>(node->region[i]->min)->value;
+ if (node->region[i]->min.as<IntImmNode>()) {
+ int64_t min = Downcast<IntImm>(node->region[i]->min)->value;
ICHECK_GE(min, 0) << "region[" << i << "] = " << min << " < 0";
}
- if (node->region[i]->extent.as[IntImm>()) {
- int64_t extent = Downcast<IntImm>(node->region[i]->extent)->value;
- ICHECK_LE(extent, Downcast<IntImm>(node->dst->shape[i])->value)
- << "region[" << i << "] = " << extent << " > " << node->dst->shape[i];
+ if (node->region[i]->extent.as<IntImmNode>()) {
+ int64_t extent = Downcast<IntImm>(node->region[i]->extent)->value;
+ if (auto* s = node->dst->shape[i].as<IntImmNode>()) {
+ ICHECK_LE(extent, s->value)
+ << "region[" << i << "] extent " << extent
+ << " > static shape " << s->value;
+ }
}
}Optionally, use analyzer->CanProve for symbolic shapes.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| ICHECK(node->region.size() == node->dst->shape.size()) | |
| << "region size = " << node->region.size() | |
| << " != " << node->dst->shape.size(); | |
| for (int i = 0; i < node->region.size(); i++) { | |
| // bound check if region is static | |
| if (region[i]->min.as<IntImm>()) { | |
| int64_t min = Downcast<IntImm>(region[i]->min)->value; | |
| if (node->region[i]->min.as<IntImm>()) { | |
| int64_t min = Downcast<IntImm>(node->region[i]->min)->value; | |
| ICHECK_GE(min, 0) << "region[" << i << "] = " << min << " < 0"; | |
| } | |
| if (region[i]->extent.as<IntImm>()) { | |
| int64_t extent = Downcast<IntImm>(region[i]->extent)->value; | |
| ICHECK_LE(extent, Downcast<IntImm>(dst->shape[i])->value) | |
| << "region[" << i << "] = " << extent << " > " << dst->shape[i]; | |
| if (node->region[i]->extent.as<IntImm>()) { | |
| int64_t extent = Downcast<IntImm>(node->region[i]->extent)->value; | |
| ICHECK_LE(extent, Downcast<IntImm>(node->dst->shape[i])->value) | |
| << "region[" << i << "] = " << extent << " > " << node->dst->shape[i]; | |
| } | |
| ICHECK(node->region.size() == node->dst->shape.size()) | |
| << "region size = " << node->region.size() | |
| << " != " << node->dst->shape.size(); | |
| for (int i = 0; i < static_cast<int>(node->region.size()); i++) { | |
| // bound check if region is static | |
| if (node->region[i]->min.as<IntImmNode>()) { | |
| int64_t min = Downcast<IntImm>(node->region[i]->min)->value; | |
| ICHECK_GE(min, 0) << "region[" << i << "] = " << min << " < 0"; | |
| } | |
| if (node->region[i]->extent.as<IntImmNode>()) { | |
| int64_t extent = Downcast<IntImm>(node->region[i]->extent)->value; | |
| if (auto* s = node->dst->shape[i].as<IntImmNode>()) { | |
| ICHECK_LE(extent, s->value) | |
| << "region[" << i << "] extent " << extent | |
| << " > static shape " << s->value; | |
| } | |
| } | |
| } |
🤖 Prompt for AI Agents
In src/op/elem.cc around lines 56 to 69, the code incorrectly uses as<IntImm>()
and unconditionally Downcasts node->dst->shape[i], which assumes static shape;
change checks to use as<IntImmNode>() (or as<IntImmNode*>) when testing for
constant IntImm, and only Downcast/Downcast<IntImm> the dst->shape[i] after
guarding that dst->shape[i]->as<IntImmNode>() is non-null (or otherwise handle
symbolic shapes); optionally, for non-constant extents use the
analyzer->CanProve checks to validate bounds instead of assuming static
integers.
Docstrings generation was requested by @LeiWang1999. * #763 (comment) The following files were modified: * `src/op/atomic_add.cc` * `src/op/atomic_add.h` * `src/op/copy.cc` * `src/op/copy.h` * `src/op/elem.cc` * `src/op/elem.h` * `src/op/gemm.cc` * `src/op/gemm.h` * `src/op/gemm_sp.cc` * `src/op/gemm_sp.h` * `src/op/operator.cc` * `src/op/operator.h` * `src/op/parallel.cc` * `src/op/parallel.h` * `src/op/reduce.cc` * `src/op/reduce.h` * `src/op/region.cc` * `src/op/region.h` * `src/transform/layout_inference.cc` * `src/transform/lower_tile_op.cc`
|
Note Generated docstrings for this pull request at #770 |
* 📝 Add docstrings to `pytile_0826` Docstrings generation was requested by @LeiWang1999. * #763 (comment) The following files were modified: * `src/op/atomic_add.cc` * `src/op/atomic_add.h` * `src/op/copy.cc` * `src/op/copy.h` * `src/op/elem.cc` * `src/op/elem.h` * `src/op/gemm.cc` * `src/op/gemm.h` * `src/op/gemm_sp.cc` * `src/op/gemm_sp.h` * `src/op/operator.cc` * `src/op/operator.h` * `src/op/parallel.cc` * `src/op/parallel.h` * `src/op/reduce.cc` * `src/op/reduce.h` * `src/op/region.cc` * `src/op/region.h` * `src/transform/layout_inference.cc` * `src/transform/lower_tile_op.cc` * lint fix --------- Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
* 📝 Add docstrings to `pytile_0826` Docstrings generation was requested by @LeiWang1999. * tile-ai/tilelang#763 (comment) The following files were modified: * `src/op/atomic_add.cc` * `src/op/atomic_add.h` * `src/op/copy.cc` * `src/op/copy.h` * `src/op/elem.cc` * `src/op/elem.h` * `src/op/gemm.cc` * `src/op/gemm.h` * `src/op/gemm_sp.cc` * `src/op/gemm_sp.h` * `src/op/operator.cc` * `src/op/operator.h` * `src/op/parallel.cc` * `src/op/parallel.h` * `src/op/reduce.cc` * `src/op/reduce.h` * `src/op/region.cc` * `src/op/region.h` * `src/transform/layout_inference.cc` * `src/transform/lower_tile_op.cc` * lint fix --------- Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
* 📝 Add docstrings to `pytile_0826` Docstrings generation was requested by @LeiWang1999. * tile-ai/tilelang#763 (comment) The following files were modified: * `src/op/atomic_add.cc` * `src/op/atomic_add.h` * `src/op/copy.cc` * `src/op/copy.h` * `src/op/elem.cc` * `src/op/elem.h` * `src/op/gemm.cc` * `src/op/gemm.h` * `src/op/gemm_sp.cc` * `src/op/gemm_sp.h` * `src/op/operator.cc` * `src/op/operator.h` * `src/op/parallel.cc` * `src/op/parallel.h` * `src/op/reduce.cc` * `src/op/reduce.h` * `src/op/region.cc` * `src/op/region.h` * `src/transform/layout_inference.cc` * `src/transform/lower_tile_op.cc` * lint fix --------- Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
* [Index] Relocate Int64 Auto Promoter to ConfigBitWidth Pass, removing it from FlattenBuffer (#714) * Update submodule 'tvm' to commit e11521e6936a827efa334588d29571fbb4620107 * Refactor inject_pipeline.cc to enhance pipeline body rewriting and condition handling - Introduced a new function to replace IfThenElse nodes with their then_case while preserving attributes. - Streamlined the PipelineBodyRewriter to improve buffer access rewriting and async state management. - Enhanced the handling of pipeline loop conditions and added support for predicate conditions in the pipeline body. - Removed obsolete code and improved overall code clarity and maintainability. * lint fix * Refactor return statements in inject_pipeline.cc to remove unnecessary std::move calls - Updated return statements in multiple methods to return objects directly instead of using std::move, improving code clarity and potentially avoiding unnecessary moves. - Ensured consistent handling of BufferStore and BufferLoad nodes during pipeline transformations. * test fix * Enhance global read detection in pipeline planning - Updated the handling of global reads to account for condition expressions within IfThenElse nodes, ensuring accurate identification of global memory accesses. - Introduced a new flag to track whether the visitor is within a condition expression, improving the correctness of buffer access analysis. - Refactored the VisitStmt_ method to properly handle the structure of IfThenElse nodes, enhancing the clarity and maintainability of the code. * Add IndexLegalizer to enforce int64 for out-of-bound indices - Introduced the IndexLegalizer class to ensure that indices in BufferStore and BufferLoad nodes are promoted to int64 when they exceed their type bounds. - Refactored the Int64Promoter logic from flatten_buffer.cc into IndexLegalizer, improving code organization and reusability. - Updated the ConfigIndexBitwidth pass to apply IndexLegalizer after rewriting the body, enhancing the handling of index bitwidths in transformations. * [CI] Bind build-test CI to NVIDIA as AMD runners are being introduced (#718) * Update submodule 'tvm' to commit e11521e6936a827efa334588d29571fbb4620107 * Rename build-test job to build-test-nvidia and specify nvidia as a runner label in CI workflow. * Update CI workflow to specify 'nvidia' as an additional runner label for the format-check job. * fix: NVRTC backend (#717) * fix: NVRTC backend * fix: CI --------- Co-authored-by: LeiWang1999 <leiwang1999@outlook.com> * [CUDA] Init support for sm_120 (#716) * Init support for sm120 * fmt * resolve comments * unify mma gemm * fmt --------- Co-authored-by: LeiWang1999 <leiwang1999@outlook.com> * [CI] fix docs ci (#720) * [Chore] fix typos (#719) * chore: fix typos * chore: fix ruff * chore: fix clang-format * [CI][AMD] Add AMD GPU CI and fix some related bugs (#694) * [Enhancement] Refactor buffer index handling for improved precision and clarity (#668) - Enhanced buffer index handling to address precision issues by removing redundant operations. - Streamlined the logic for determining buffer overlaps, ensuring more accurate conflict detection. - Updated related documentation to reflect changes in buffer management practices. * Remove obsolete test script for AMD example, streamlining the examples directory. * Remove unused dtype_size variable in AMD example script to streamline code. * Add input configuration file and update AMD example script for enhanced flexibility - Introduced a new input.txt file for configurable parameters. - Modified the example_amd_flash_attn_fwd.py script to allow for a wider range of configurations, including additional options for num_stages, enable_rasterization, and k_pack. - Streamlined the main function for better clarity and organization. - Added a new test script to facilitate running the example with specified parameters. * Remove input configuration file and obsolete test script; enhance AMD example with swizzle layout annotations - Deleted input.txt and test.sh files as they are no longer needed. - Updated example_amd_flash_attn_fwd.py to include swizzle layout annotations for shared memory, improving bank conflict avoidance. - Reintroduced swizzle usage in the kernel for better performance. * Refactor AMD example script for FlashAttention-2 - Updated function names for clarity, changing `get_v2_configs` to `get_configs` and `fast_flashattn_v2` to `fast_flashattn`. - Streamlined the main function by renaming `main_v2` to `main` and adjusting the corresponding calls. - Removed outdated comments and improved code organization for better readability. * Refactor formatting in AMD FlashAttention example script - Improved code readability by adjusting line breaks and indentation in the `fast_flashattn` function. - Streamlined the `main` function parameter formatting for consistency. - Removed unnecessary blank lines to enhance overall code organization. * Update example_amd_flash_attn_fwd.py * Update AMD FlashAttention example and TVM submodule - Added a new example script `example_amd_flash_attn_fwd_k_block.py` for FlashAttention with K-blocking support. - Enhanced `example_amd_flash_attn_fwd.py` by expanding configuration options for block sizes and threads. - Updated the TVM submodule to the latest commit for improved functionality. - Introduced a new test script `test.sh` to facilitate running the new example with specified parameters. * Add CI workflow for automated format checking and testing - Introduced a new GitHub Actions workflow in `amd_ci.yml` to automate format checks and testing for pull requests. - The workflow includes steps for setting up a Python environment, running format checks, and executing tests. - Removed obsolete example script `example_amd_flash_attn_fwd_k_block.py` and test script `test.sh` to streamline the examples directory. * Rename CI workflow from "CI" to "AMD CI" for clarity and specificity. * Update AMD CI workflow to include copying PyTorch, TorchVision, and Torchaudio packages to the virtual environment for improved dependency management. * Update AMD CI workflow to install pytest directly instead of using requirements-test.txt * Update AMD CI workflow to remove 'flash-attn' from requirements and install dependencies from requirements-test.txt * Refactor AMD CI workflow to enhance clarity in removing 'flash-attn' from requirements-test.txt before installation * Remove Torchaudio package copying from AMD CI workflow to streamline dependency management. * Refactor AMD CI workflow to remove the format-check job and streamline the build-test process by directly copying PyTorch and TorchVision packages to the virtual environment. * Add installation of ROCm in AMD CI workflow - Included a step to execute the `install_rocm.sh` script for improved setup. - Removed unnecessary blank line for better readability in the workflow script. * Remove installation step for ROCm in AMD CI workflow to simplify the setup process. * Update AMD CI workflow to run specific test file with verbose output instead of all tests. * Add new tilelang built-in operations for AMD architecture - Introduced `tvm_mfma`, `tvm_mfma_store`, `tvm_rdna_wmma`, and `tvm_rdna_wmma_store` built-in operations to enhance support for matrix multiplication and storage in tilelang. - Each operation is configured with the appropriate number of inputs and marked as opaque in terms of call effects. * Enhance autotuner configurations and GEMM operations in AMD example - Updated block sizes and num_split_q parameters in `get_configs` for improved autotuning. - Modified `T.gemm` calls in `fast_flashattn` to utilize `GemmWarpPolicy.FullRow`, optimizing performance for matrix multiplications. * Update autotuner configurations in AMD example for enhanced performance - Refined block sizes, thread counts, and added new parameters in `get_configs` to optimize autotuning. - Adjusted `fast_flashattn` function to incorporate new parameters for panel size and coalesced widths, improving memory access patterns. * Enhance autotuner configurations and memory handling in AMD example - Expanded block sizes and thread counts in `get_configs` for improved autotuning capabilities. - Updated `fast_flashattn` to utilize a new shared memory allocation strategy, optimizing memory access patterns during GEMM operations. * Refine autotuner configurations and memory usage in AMD example - Reduced block sizes and adjusted thread counts in `get_configs` for optimized autotuning. - Updated `fast_flashattn` to utilize register fragments for accumulation, minimizing LDS usage and enhancing performance during GEMM operations. * Update autotuner configurations in AMD example for enhanced performance - Expanded block sizes and thread counts in `get_configs` to improve autotuning capabilities. - Adjusted `num_split_q` and `v_coalesced_width` parameters for better optimization during GEMM operations. * Enhance autotuner configurations and GEMM operations in AMD example - Expanded thread counts in `get_configs` to include higher values for improved autotuning. - Updated `fast_flashattn` to adjust accumulation logic and ensure proper handling of causal conditions, optimizing performance during matrix multiplications. * Update AMD CI workflow and remove obsolete test script - Modified the CI workflow to run on multiple environments: self-hosted, amd, and gpu. - Deleted the outdated `test.sh` script from the examples directory, streamlining the project structure. * Remove TVM subproject from 3rdparty directory * Refactor configuration generation and accumulation logic in AMD example - Reformatted the `get_configs` function for improved readability by aligning parameters. - Adjusted the `fast_flashattn` function to enhance clarity in the conditional logic for accumulation, ensuring better handling of causal conditions. * Enhance AMD CI workflow with additional logging and setup steps - Added echo statements to provide feedback during the CI process, indicating when the environment is running on an AMD GPU, copying necessary packages, and installing requirements. - Improved clarity in the workflow by explicitly stating when the project is being installed and when tests are being executed. * Comment out package copying in AMD CI workflow to prevent potential issues during environment setup * Update AMD CI workflow to install nightly versions of PyTorch and remove obsolete package copying steps * Enhance BuildTileLangHIP function by adding whitespace for improved readability * Refactor kTVMGridConstant definition for clarity and remove unnecessary comment * Update TVM subproject to latest commit a64a5926a6e59f5417ef2501f9d88b467337cf6a * lint fix * Update AMD CI workflow to use requirements-rocm.txt for dependency installation * fix ci * Remove dependency on format-check from AMD CI workflow * fix ci * fix ci * fix ci * Remove format-check job from AMD CI workflow * Add torch to requirements-rocm.txt and remove explicit pip install commands from AMD CI workflow * Add dependency on format-check job in AMD CI workflow * Add format-check job to AMD CI workflow * Update format-check job in AMD CI workflow to run on self-hosted environment * Enhance format-check job in AMD CI workflow with improved Python environment setup and automatic commit of lint changes * Update amd_ci.yml --------- Co-authored-by: xinxyxiao <xinyxiao@amd.com> Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Co-authored-by: LeiWang1999 <leiwang1999@outlook.com> * [Carver][Bugfix] Correct score function for warp tile selection in tensorcore policy (#724) * [Carver][Bugfix] Correct score function for warp tile selection in tensorcore policy * [Typo] Correct architecture selection for CUDA and CDNA * [Refactor] Refactor CUDA code generation to simplify eviction policy handling (#721) * Update submodule 'tvm' to commit e11521e6936a827efa334588d29571fbb4620107 * Refactor CUDA code generation to simplify eviction policy handling - Updated `VisitExpr_` methods in `codegen_cuda.cc` to use default eviction policy for `tma_load`, `tma_load_im2col`, and `tma_store` functions, reducing complexity. - Removed conditional assembly code for `EVICT_NORMAL` in `copy_sm90.h`, streamlining the assembly calls for tensor memory operations. * lint fix * [Language] Introduce `StridedTensor` to support non contigious torch inputs (#722) * Update submodule 'tvm' to commit e11521e6936a827efa334588d29571fbb4620107 * Support strided tensors * Refactor target attribute helper functions for improved clarity * No code changes made in proxy.py and setup.py * lint fix * lint fix via gemini * lint fix * test fix * test fix * lint fix * Update wrapper.py * test fix * Enhance test for InjectSoftwarePipeline by adding LowerOpaqueBlock transformation and updating expected function signature to use match_buffer for better clarity. * lint fix --------- Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com> * [Enhancement][Bugfix] Fix bug in warp specialized pass and add gemm_sr fallback support for Hopper (#712) * bug fix and support gemm_sr fallback for hopper * Update gemm.cc --------- Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Co-authored-by: LeiWang1999 <leiwang1999@outlook.com> * 📝 Add docstrings to `fix` (#726) Docstrings generation was requested by @LeiWang1999. * https://github.com/tile-ai/tilelang/pull/712#issuecomment-3190680851 The following files were modified: * `src/op/gemm.cc` * `src/tl_templates/cuda/gemm_sm90.h` * `src/transform/warp_specialized_rewriter.cc` Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> * [CI] Fix AMD CI (#729) * [Enhancement] Refactor buffer index handling for improved precision and clarity (#668) - Enhanced buffer index handling to address precision issues by removing redundant operations. - Streamlined the logic for determining buffer overlaps, ensuring more accurate conflict detection. - Updated related documentation to reflect changes in buffer management practices. * Remove obsolete test script for AMD example, streamlining the examples directory. * Remove unused dtype_size variable in AMD example script to streamline code. * Add input configuration file and update AMD example script for enhanced flexibility - Introduced a new input.txt file for configurable parameters. - Modified the example_amd_flash_attn_fwd.py script to allow for a wider range of configurations, including additional options for num_stages, enable_rasterization, and k_pack. - Streamlined the main function for better clarity and organization. - Added a new test script to facilitate running the example with specified parameters. * Remove input configuration file and obsolete test script; enhance AMD example with swizzle layout annotations - Deleted input.txt and test.sh files as they are no longer needed. - Updated example_amd_flash_attn_fwd.py to include swizzle layout annotations for shared memory, improving bank conflict avoidance. - Reintroduced swizzle usage in the kernel for better performance. * Refactor AMD example script for FlashAttention-2 - Updated function names for clarity, changing `get_v2_configs` to `get_configs` and `fast_flashattn_v2` to `fast_flashattn`. - Streamlined the main function by renaming `main_v2` to `main` and adjusting the corresponding calls. - Removed outdated comments and improved code organization for better readability. * Refactor formatting in AMD FlashAttention example script - Improved code readability by adjusting line breaks and indentation in the `fast_flashattn` function. - Streamlined the `main` function parameter formatting for consistency. - Removed unnecessary blank lines to enhance overall code organization. * Update example_amd_flash_attn_fwd.py * Enhance AMD example script and update CI workflows - Improved the `example_amd_flash_attn_fwd.py` script for better clarity and organization. - Added new CI workflows for AMD and documentation publishing. - Updated various requirements files to include necessary dependencies. - Introduced new test cases and examples for better coverage and functionality. - Refactored existing code for improved readability and maintainability. * Remove redundant tool cache cleanup step in AMD CI workflow * Remove `torch` dependency from `requirements-rocm.txt` to streamline requirements. --------- Co-authored-by: xinxyxiao <xinyxiao@amd.com> Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> * [Feature] Low-bit twiddling dequantization and FP4 GEMM (#725) * [Dequant] Add bit-twiddling dequantize cuda for fp4-->bf16 * [Dequant] Add extern call and serial dequantization * [Dequant] Parallel Dequant wait for fence debug. * [Scale] Add scale matrix to mxfp4 gemm * [Remove] Remove fence-buggy example and some generated source cuda code * [MXFP4] Update initial version of MXFP4 GEMM * [Scale] Add scale to latest mxfp4 gemm * [Lint] * [BugFix] Load Scale, disabe TMA to recover performance * [Lint] * [Lint] * [Scale] Use L2 to hold Scale and enable TMA will slightly boost performance * [Lint] * Update example_dequant_gemm_bf16_fp4_hopper_serial.py * Remove deprecated dequantization examples for BF16 and MXFP4 in the dequantize_gemm directory. * Refactor dequantization examples for improved readability and consistency. Adjusted formatting in matmul function and added spacing for clarity. Updated function signatures and comments for better understanding. * Refactor index_to_coordinates usage in bitnet example and update dequantization example configurations. Removed the custom index_to_coordinates function and replaced it with the built-in version. Adjusted block_K parameter in dequantization example for consistency. * lint fix * ci fix * Remove non-existent example * [BugFix] Add smem swizzle to recover performance of TMA * [BugFix] Enough reg for producer when threads=512 --------- Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Co-authored-by: LeiWang1999 <leiwang1999@outlook.com> * 📝 Add docstrings to `mxfp4` (#732) * 📝 Add docstrings to `mxfp4` Docstrings generation was requested by @LeiWang1999. * https://github.com/tile-ai/tilelang/pull/725#issuecomment-3191656561 The following files were modified: * `examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py` * `examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py` * `examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py` * `examples/dequantize_gemm/utils.py` * `examples/gemm/example_gemm_autotune.py` * `tilelang/intrinsics/utils.py` * `tilelang/language/__init__.py` * `tilelang/language/utils.py` * `tilelang/quantize/mxfp.py` * `tilelang/quantize/quantization.py` * [Lint] More accurate docstring * [Lint] --------- Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Co-authored-by: tzj-fxz <tzjfxz@gmail.com> * [Refactor] Refactor env into a more flexible version (#740) * Fix environment variable name for compilation print setting in `env.py` * Remove deprecated test file for warp specialized pass configuration and refactor environment variable access in `env.py` to utilize a centralized `EnvVar` class for better management and clarity. * lint fix * Refactor cache check to use `env.is_cache_enabled()` for consistency in `tuner.py` * [Enhancement] Add stride index validation in CythonKernelWrapper (#743) * Introduced an assertion to ensure that the stride index is within the valid range of tensor dimensions in `cython_wrapper.pyx`. * This change prevents potential out-of-bounds errors when accessing tensor dimensions, enhancing the robustness of the code. * [Bugfix]:Fix atomic add auto vectorize memory access out of bound error (#742) * [Bugfix]:Fix atomic add auto vectorize memory access out of bound error * Update atomicadd_vectorize.cc * format * 📝 Add docstrings to PR #744 (#745) * 📝 Add docstrings to `main` Docstrings generation was requested by @LeiWang1999. * https://github.com/tile-ai/tilelang/pull/742#issuecomment-3205103559 The following files were modified: * `src/transform/atomicadd_vectorize.cc` * lint fix --------- Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Co-authored-by: LeiWang1999 <leiwang1999@outlook.com> * [Refactor] Refactor barrier management (#744) * Introduce Barrier * Enhance CUDA kernel with new barrier management and post-processing support - Added a new CUDA kernel implementation in `example_mla_decode.py` for improved performance with shared memory barriers. - Refactored barrier handling in `codegen_cuda.cc` and `codegen_hip.cc` to utilize a more flexible mbarrier structure. - Updated intrinsic definitions from `ptx_stmatirx` to `ptx_stmatrix` across multiple files for consistency. - Introduced additional print statements for debugging in the lowering phase of the TileLang engine. - Enhanced the overall structure and readability of the codebase. * Remove unused barrier handling code in CUDA and HIP code generators to streamline the implementation. This change enhances code clarity and reduces complexity in the barrier management logic. * Enhance barrier management in TileLang - Introduced a new intrinsic `allocate_barrier` for dynamic barrier allocation in the TileLang framework. - Updated CUDA code generation to support the new barrier structure, allowing for improved synchronization in shared memory. - Refactored existing barrier handling logic to accommodate the new intrinsic and streamline code. - Added print statements for debugging purposes in various examples and the lowering phase of the TileLang engine. - Removed deprecated memory scope handling code to enhance clarity and maintainability. * lint fix * lint fix * Remove `allocate_barrier` intrinsic and related code from TileLang to streamline barrier management. This includes updates to CUDA code generation and the removal of associated Python wrappers, enhancing code clarity and maintainability. * Refactor logging in JITKernel to improve kernel compilation tracking - Removed unused import of `torch.backends` in the example file. - Introduced logging for kernel compilation in `JITKernel`, replacing print statements with structured logging for better traceability and debugging. - Added an assertion to ensure the presence of the `global_symbol` attribute in the kernel function. * Refactor dequantization tests and update barrier function - Removed the test for `example_dequant_gemm_bf16_fp4_hopper_serial` to streamline the testing suite. - Updated the `mbarrier_cp_async_arrive` function to support both pointer and non-pointer types, enhancing flexibility in barrier management. * Update CI configuration to increase pytest parallelism from 4 to 8 threads for improved test execution speed. * Fix typos in rasterization parameters and update import path for cached module - Corrected the spelling of `enable_rasteration` to `enable_rasterization` in the matmul function and its usage. - Updated the import statement for the `cached` module to reflect the new path in the cache submodule. - Added `StridedTensor` import in the language module for enhanced tensor functionality. * Update ci.yml * [Refactor] Merge bulk copy into copy and improve layout inference for bulk copy (#746) * [Refactor] Merge bulk copy into copy and refactor layout inference for bulk copy * Deleted the `bulk_copy` operator implementation and its header file as it is no longer needed. * Introduced a new function `cuTensorMapType()` to return the data type for CUDA tensor mapping. * Updated related files to reflect these changes, ensuring that the codebase remains clean and maintainable. * lint fix * Fix typos in intrinsic names and remove unused print statement in block_sparse_attn_tilelang.py. Updated references from `ptx_ldmatirx` to `ptx_ldmatrix` across multiple files for consistency. * remove bulk copy * Refactor copy and atomic add operations to support TMA lower configuration - Updated `GetCopyInst` to accept a `disable_tma_lower` parameter, allowing for conditional usage of TMA in bulk load/store operations. - Modified `Lower` method in `Copy` to incorporate the new TMA configuration. - Refactored `AtomicAdd::Lower` to streamline layout inference and vectorization logic. - Removed unused `disable_tma_lower` field from `LowerArgs` structure for clarity. - Enhanced atomic add vectorization by replacing the buggy implementation with a more robust loop vectorization approach. * Enhance TMA bulk copy logic in `LowerBulkCopy` method - Added a condition to set `desc.swizzle` to `CU_TENSOR_MAP_SWIZZLE_NONE` when `shared_layout` matches `linear_layout`, improving clarity in layout handling. - Updated warning log to provide more detailed information about fallback scenarios, including source and destination buffer names and shapes, enhancing debugging capabilities. * lint fix * Remove fallback logging for non-swizzled global layout in `LowerBulkCopy` method to streamline the bulk copy logic. This change enhances code clarity by eliminating unnecessary warning messages related to inner box dimensions. * Enhance reshape kernel compilation in `run_reshape` and `run_reshape_smem_1d_2_2d` functions - Updated the `tl.compile` method to include `pass_configs` that disable TMA lower and warp specialization, addressing shared memory layout transformation limitations. - Added TODO comments to indicate the need for further improvements in shared memory handling. * Update `native_sparse_attention` function to include TMA configuration options - Added `pass_configs` to the JIT decorator to disable TMA lower and warp specialization, addressing potential issues with shared memory layout transformations. - Updated comments to clarify modifications in tensor shapes for inference, specifically setting `q` sequence length to 1. * Refactor JIT decorator formatting in `native_sparse_attention` function - Improved readability by reformatting the JIT decorator parameters for `native_sparse_attention`, ensuring consistent style across the codebase. - No functional changes were made; this update focuses on code clarity and maintainability. * Enhance thread management and logging in TileLang compilation - Added a method to check if printing is enabled during compilation, improving control over logging behavior. - Updated the JIT kernel class to utilize the new method for logging compilation status, ensuring consistent and clear output. - Added comments to clarify the purpose of changes and improve code readability. * Add warp specialization scope and refactor register management in TileLang - Introduced a new constant `kWarpSpecializationScope` in `builtin.h` for better attribute management. - Removed the `SetMaxNRegCollector` class and its related logic from `warp_specialized_rewriter.cc`, streamlining the warp specialization process. - Added functions `annotate_producer_reg_dealloc` and `annotate_consumer_reg_alloc` in `builtin.py` to facilitate register management. - Implemented `AnnotateWarpGroupRegAlloc` in `__init__.py` to inject register allocation calls into warp-specialized functions, enhancing the overall register handling in the compilation process. * Refactor test for InjectSetMaxNReg pass in TileLang - Improved readability by restructuring conditional checks and assertions in the test cases. - Enhanced clarity in the collection of `set_max_nreg` calls by simplifying the logic. - Ensured consistent formatting and spacing throughout the test functions for better maintainability. * Enhance bulk copy and store checks in `Copy` class - Updated scope validation for source and destination tensors in `CheckBulkLoad` and `CheckBulkStore` methods to include both `shared.dyn` and `shared` as valid options. - Modified `CheckLDSMCopy` and `CheckSTSMCopy` methods to accommodate the new scope validation, ensuring compatibility with shared memory configurations. - Improved logging in `LowerBulkCopy` to provide clearer warnings regarding unsupported swizzle layouts, including source and destination names for better debugging. * lint fix * [Refactor] Merge ThreadPartialSync and ThreadStorageSync (#741) * Remove `thread_partial_sync.cc` and refactor `thread_storage_sync.cc` to streamline synchronization handling. Introduce `thread_sync_types.h` for thread-bound key definitions and reserved named barriers. Update related logic in `ThreadSyncInserter` and `TileLangThreadSync` for improved clarity and efficiency. * Remove `sync_thread_partial` references and related documentation from the codebase. Update CUDA and HIP code generation files to eliminate calls to the removed function. Refactor `__sync_thread_partial` to `sync_thread_partial` in CUDA common header for consistency. * Remove unused import of `bulk_copy.h` in `codegen_hip.cc` to enhance code clarity and maintainability. * Add import of `bulk_copy.h` in `codegen_hip.cc` to support new functionality. * typo fix * Update data type in reduce_sum tests from float16 to float32 for consistency and clarity. Remove redundant dtype tests and streamline run functions. Enhance reshape kernel compilation with pass configurations to address shared memory layout issues. * lint fix * test fix * Enhance CI configuration by adding verbose output to pip install command for better visibility during installation. * use ninja instead of make * Add CMake configuration step for Ninja build system in setup.py * Update pyproject.toml to include additional build dependencies: build, torch, tox, auditwheel, patchelf, and ninja. * Enhance CI configuration by adding verbose output to pytest commands for improved test visibility. * Update pyproject.toml to add Cython as a build dependency. Enhance thread storage synchronization in thread_storage_sync.cc by introducing new thread variable handling and improving index disjointness checks. * Update data type in cumulative sum tests from float16 to float32 for consistency. Modify run_cumsum function to utilize the updated dtype and enhance result validation with assertions. Adjust test cases accordingly. * Refactor storage access handling by introducing buffer data mapping in TileLangStorageAccessVisitor. Enhance access entry structure to include pointer access flag. Update thread storage synchronization to accommodate new buffer data mappings. Adjust quickstart example to print kernel source for debugging purposes. * Refactor linear index conversion in TileLangStorageAccessVisitor to utilize the analyzer for simplification. Update buffer index calculations to ensure consistent simplification of range expressions. * bugfix * Refactor buffer index calculation in TileLangStorageAccessVisitor to simplify access handling. Removed unused buffer mapping logic, ensuring consistent buffer index generation with a default ramp. * Refactor TileLangStorageAccessVisitor to replace buffer indices with buffer ranges for improved pointer access handling. Update AccessEntry structure to include buffer_ranges and adjust thread storage synchronization logic to account for pointer access conflicts. * Refactor thread storage synchronization to replace 'shared.dyn' with 'shared' for consistency in memory allocation. Update related test cases to reflect this change and ensure proper functionality. * [Enhancement] Optimize loop body handling in IR (#749) - Updated the loop body construction in `ir.cc` to conditionally include an output statement based on the analyzable condition of the `waves` variable. - This change enhances performance by avoiding unnecessary statement wrapping when the condition is met, improving the efficiency of loop execution. Co-authored-by: LeiWang1999 <leiwang1999@outlook.com> * [MXFP4] Fix bugs and optimize exponential operation (#750) * [MXFP4] Fix bugs - Optimize exp2 with shift operation to boost performance - Fix bug of simple dequantization function call - Fix bug of scaling factor with bias * [Lint] --------- Co-authored-by: LeiWang1999 <leiwang1999@outlook.com> * [Enhancement] Add DispatchInstruction specialization for fp8 types in gemm_sm90.h (#751) - Introduced specialized DispatchInstruction templates for fp8_e4_t and fp8_e5_t types, enhancing support for new data formats in CUDA GEMM operations. - Each specialization defines the corresponding MMA and MMA_Group types, optimizing performance for specific configurations. * [Enhancement] Add shape checking for reduce options (#748) * Add shape checking for reduce options * lint fix * Handle special case reducing into shape-1 tensor Allow reducing [X, d, Y] into [X, Y] or [X, 1, Y] --------- Co-authored-by: LeiWang1999 <leiwang1999@outlook.com> * [Bugfix] Add missing FP8 header include (#752) * [Enhancement] Add DispatchInstruction specialization for fp8 types in gemm_sm90.h - Introduced specialized DispatchInstruction templates for fp8_e4_t and fp8_e5_t types, enhancing support for new data formats in CUDA GEMM operations. - Each specialization defines the corresponding MMA and MMA_Group types, optimizing performance for specific configurations. Co-authored-by: LeiWang1999 <leiwang1999@outlook.com> * [Enhancement] Include cuda_fp8.h in gemm_sm90.h - Added the inclusion of the "cuda_fp8.h" header file to support new data formats in CUDA GEMM operations, enhancing compatibility with recent updates for fp8 types. Co-authored-by: LeiWang1999 <leiwang1999@outlook.com> * lint fix * [Refactor] Remove unused tl_shuffle_elect and related functions from common.h - Deleted the `tl_shuffle_elect` function and its associated comments to streamline the codebase. - Added inclusion of "intrin.h" for improved intrinsic support in CUDA operations. - Cleaned up the file by removing unnecessary template parameters and functions, enhancing clarity and maintainability. * lint fix * [Refactor] Update header inclusions in common.h and gemm_sm90.h - Removed the inclusion of "intrin.h" from common.h to streamline dependencies. - Added "intrin.h" inclusion in gemm_sm90.h to ensure intrinsic support for CUDA operations, enhancing functionality and maintainability. * bug fix * [MXFP4] Add bias to MXFP4 GEMM kernel (#753) * [MXFP4] Add bias to gemm kernel * [Lint] * [Lint] Rename "bias" to "Bias" * [Bugfix][WS] Consider loop min extent when computing phase id (#754) * Update test parameters and remove debug print statement - Adjusted test cases in `test_tilelang_dynamic_symbolic_bench.py` to use smaller matrix sizes (1024x1024) for improved performance and quicker execution. - Removed a debug print statement from `phase.py` to clean up the code and enhance clarity. * Refactor loop stack management in warp_specialized_rewriter - Introduced a new `LoopInfo` struct to encapsulate loop variable details, including `loop_var`, `extent`, and `min`, enhancing clarity and maintainability. - Updated the `loop_stack_` to utilize `LoopInfo` instead of a pair, improving type safety and readability. - Adjusted linear index calculations to account for the new structure, ensuring correct behavior in loop transformations. * [Typo] Remove `disable_cache` in some tests (#755) * Update test parameters and remove debug print statement - Adjusted test cases in `test_tilelang_dynamic_symbolic_bench.py` to use smaller matrix sizes (1024x1024) for improved performance and quicker execution. - Removed a debug print statement from `phase.py` to clean up the code and enhance clarity. * Refactor loop stack management in warp_specialized_rewriter - Introduced a new `LoopInfo` struct to encapsulate loop variable details, including `loop_var`, `extent`, and `min`, enhancing clarity and maintainability. - Updated the `loop_stack_` to utilize `LoopInfo` instead of a pair, improving type safety and readability. - Adjusted linear index calculations to account for the new structure, ensuring correct behavior in loop transformations. * Remove unused `torch.backends` import and `tilelang.disable_cache()` calls from multiple test files to enhance code clarity and maintainability. * [README] Update GDN README for clarity and add acknowledgements (#758) - Improved formatting and clarity of the GDN kernel implementation description. - Updated requirement section to list dependencies in a clearer format. - Added an acknowledgements section to credit the developers and the Xiaomi LLM-Core Team for their contributions. * cutlass v4.2.0 supporting cuda 13 (#760) * [Feature] Add 1D TMA support (#761) * [Feature] Add 1D TMA support - Check the contiguous conditions of 1D TMA copy - Add new interface and params order of `tma_load` and `tma_store` call - Add 1D `tma_store` interface in sm90 template - Add elementwise kernel for 1D TMA example * [Lint] * [BugFix] Add conditions for 1D TMA copy on non-swizzle shared tensors * [Lint] * [BugFix] 1D TMA load * [README] Update GDN README for clarity and add acknowledgements (#758) - Improved formatting and clarity of the GDN kernel implementation description. - Updated requirement section to list dependencies in a clearer format. - Added an acknowledgements section to credit the developers and the Xiaomi LLM-Core Team for their contributions. * cutlass v4.2.0 supporting cuda 13 (#760) * [Lint] * [Lint] * [MXFP4] Add test for bf16&mxfp4 gemm * [BugFix] * [Lint] --------- Co-authored-by: Yu Cheng <54519279+chengyupku@users.noreply.github.com> Co-authored-by: Johnny <johnnync13@gmail.com> * [Example] Add vertical slash sparse attention pattern (#762) * upd sparse attn * lint * rename * update test file * update benchmark * lint * update benchmark * [Bugfix] Address PassContext contamination from CI and fix incorrect rewrites in warp specialized pass (#767) * fix ci and pass bug * fix * try * lint * [MXFP4] Add 1D TMA copy for Scale tensor in MXFP4 GEMM (#766) * [TMA] Add 1D TMA copy for Scale tensor * [Lint] * [Test] Add test for kernel * [BugFix] * hot fix blackwell (#768) * [Refactor] Refactor `Operator` into `TileOperator` and with tvm reflection (#763) * Refactor operator classes to inherit from TileOperator and update layout inference methods - Changed base class of several operator classes (AtomicAdd, Copy, Gemm, etc.) from Operator to TileOperator for better alignment with tile operations. - Updated InferLayout and Lower methods to use 'override' specifier for clarity and consistency. - Adjusted header inclusions to replace "op.h" with "operator.h" across multiple files for improved organization. - Added missing layout inference implementations for Fill and Conv2DIm2ColOp. - Removed deprecated op.cc and op.h files to streamline the codebase. * lint fix * Refactor operator classes to use Node pattern and improve memory management - Updated several operator classes (AtomicAdd, Copy, Gemm, etc.) to utilize the Node pattern for better memory management and encapsulation. - Changed constructors to initialize member variables through a node object, enhancing clarity and reducing direct member access. - Updated Clone methods to return TileOperator instances instead of unique pointers, aligning with the new design. - Refactored InferLayout and Lower methods to ensure consistency across operator implementations. - Adjusted header files to reflect the new class structure and removed deprecated code for a cleaner codebase. * Enhance Clone methods in AtomicAdd and Copy classes to support parallel operation cloning - Updated the Clone methods in AtomicAddNode and CopyNode to ensure that the parallel operation (par_op_) is properly cloned when defined, improving the integrity of cloned objects. - Refactored the FillNode class to use ParallelOp directly instead of std::make_unique, streamlining the creation of parallel operations. - Made minor adjustments in layout inference and other related methods for consistency and clarity. * Refactor FillNode::Lower method to remove unused global function call - Eliminated the call to the global function "tl.fill.lower" in the FillNode::Lower method, streamlining the code and improving clarity. - Retained the core functionality of the method while enhancing maintainability by reducing unnecessary dependencies. * [Reducer] Introduce `alloc_reducer` to separate inter and intra warp reduction (#757) * [Enhancement] Introduce finalize_reducer operator and layout reducer support - Added `FinalizeReducer` operator to handle reduction finalization in the TileLang framework, allowing for efficient reduction operations. - Implemented layout inference for local.reducer buffers, enhancing the handling of layout mappings and reducing complexity in buffer management. - Updated `setup.py` to include logging for build directory paths, improving build process visibility. - Enhanced atomic operations with new functions for atomic max, min, load, and store, providing more robust atomicity control in memory operations. - Refactored parallel loop handling to incorporate reducer information, ensuring proper management of reduction operations in parallel contexts. - Cleaned up test cases by removing unnecessary cache disabling and optimizing test parameters for better performance. * Refactor code formatting and improve readability in multiple files - Cleaned up whitespace in `setup.py` to enhance logging clarity. - Reformatted `AtomicMax` and `AtomicMin` functions in `common.h` for better alignment and readability. - Adjusted `debug_print_var` function in `debug.h` to improve code structure and maintainability. - Enhanced readability of the `atomic_add` function in `customize.py` by breaking long lines for better clarity. * Remove debug print statements from `copy.cc` and `inject_tma_barrier.cc` to enhance code clarity and maintainability. * [Enhancement] Disable reuse of small arrays in shared memory allocation - Added logic to prevent the reuse of small arrays (<= 32 bits) in `merge_shared_memory_allocations.cc`, ensuring they are lowered to registers in LLVM for improved performance and memory management. * Refactor `setup.py` to remove duplicate logging statements and enhance clarity. Update `finalize_reducer` function documentation in `reduce.py` to include detailed parameter and return descriptions, improving code readability and maintainability. * Refactor `finalize_reducer` and `reduce` functions to remove redundant target checks. Simplified conditionals by retaining only the `TargetIsHopper` check, enhancing code clarity and maintainability. * bug fix * Add thread checks workaround for replicated cases * Remove the is_one check * fix lint error * lint fix * Update autotune tests to use smaller matrix sizes for improved performance and reliability * [Refactor] Update FinalizeReducer to FinalizeReducerOp and adjust related methods - Refactored FinalizeReducer class to FinalizeReducerOp, updating constructor and method signatures for consistency with the new TileOperator structure. - Enhanced layout inference and cloning methods in FinalizeReducerOpNode. - Updated test_example_flash_attention.py to call test_example_gqa_bwd instead of tilelang.testing.main. - Adjusted header inclusions for improved organization and clarity across multiple files. * [Refactor] Update atomic operations in common.h and modify test_example_flash_attention.py - Enhanced atomic operations (Add, Min, Max) in common.h to handle half and bfloat16 types more efficiently. - Updated test_example_flash_attention.py to call test_example_gqa_bwd instead of tilelang.testing.main, improving test organization. * [Refactor] Simplify CopyNode::LowerBulkCopy logic and update test execution - Removed redundant checks for contiguous memory access in CopyNode::LowerBulkCopy, streamlining the logic for TMA copy operations. - Updated test_tilelang_kernel_gemm.py to comment out the main testing function and call a specific test for i8i8i32 tensor operations instead, improving test focus. --------- Co-authored-by: Huanqi Cao <caohuanqi@deepseek.com> Co-authored-by: Freebase6912 <amid-gauze-racing@duck.com> * 📝 Add docstrings to `pytile_0826` (#770) * 📝 Add docstrings to `pytile_0826` Docstrings generation was requested by @LeiWang1999. * https://github.com/tile-ai/tilelang/pull/763#issuecomment-3224197814 The following files were modified: * `src/op/atomic_add.cc` * `src/op/atomic_add.h` * `src/op/copy.cc` * `src/op/copy.h` * `src/op/elem.cc` * `src/op/elem.h` * `src/op/gemm.cc` * `src/op/gemm.h` * `src/op/gemm_sp.cc` * `src/op/gemm_sp.h` * `src/op/operator.cc` * `src/op/operator.h` * `src/op/parallel.cc` * `src/op/parallel.h` * `src/op/reduce.cc` * `src/op/reduce.h` * `src/op/region.cc` * `src/op/region.h` * `src/transform/layout_inference.cc` * `src/transform/lower_tile_op.cc` * lint fix --------- Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Co-authored-by: LeiWang1999 <leiwang1999@outlook.com> * [Bugfix]:Fix atomic add auto vectorize negative optimization (#765) * [Bugfix]:Fix atomic add auto vectorize negative optimization * fixbug * format * fix bug * 📝 Add docstrings to `reducer_0825` (#772) * 📝 Add docstrings to `reducer_0825` Docstrings generation was requested by @LeiWang1999. * https://github.com/tile-ai/tilelang/pull/757#issuecomment-3219088118 The following files were modified: * `setup.py` * `src/op/builtin.h` * `src/op/finalize_reducer.cc` * `src/op/finalize_reducer.h` * `src/op/parallel.cc` * `src/op/parallel.h` * `src/op/reduce.cc` * `src/target/codegen_cuda.cc` * `src/tl_templates/cuda/common.h` * `src/transform/layout_inference.cc` * `src/transform/layout_reducer.cc` * `src/transform/layout_reducer.h` * `src/transform/merge_shared_memory_allocations.cc` * `src/transform/storage_access.cc` * `src/transform/warp_specialized_rewriter.cc` * `testing/python/autotune/test_tilelang_autotune_with_inputs.py` * `tilelang/engine/phase.py` * `tilelang/language/customize.py` * `tilelang/language/reduce.py` * `tilelang/transform/__init__.py` * lint fix * lint fix --------- Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Co-authored-by: LeiWang1999 <leiwang1999@outlook.com> * Allow fill global buffer (#774) * Allow fill global buffer * fix lint error * [BugFix] Refactor the op check in LowerTileOp pass using the member function instead of string match (#771) * [BugFix] Refactor the op check in LowerTileOp pass using the member function instead of string match * [Lint] * add bf16 exp fallback (#776) * [Lint] Introduce clang-tidy into format.sh (#777) * [Refactor] Update Clang-Tidy Checks and Improve Code Consistency - Enhanced .clang-tidy configuration by adding specific checks for better bug detection and performance optimization. - Refactored function signatures across multiple files to use `const` references for parameters, improving performance and code clarity. - Updated various methods to ensure consistent handling of parameters, particularly in `AddPredicate`, `Substitute`, and `PlanLoopPartition` functions. - Improved readability by replacing size checks with `empty()` method calls in several locations, ensuring clearer intent in the code. - General code cleanup and adherence to best practices for better maintainability. * [Refactor] Enhance Code Consistency and Clang-Tidy Configuration - Updated .clang-tidy configuration to include additional checks for improved code quality and performance. - Refactored function signatures across multiple files to use `const` references, enhancing performance and clarity. - Replaced size checks with `empty()` method calls in various locations for clearer intent. - Improved handling of parameters in several functions, ensuring consistent usage of `std::move` where applicable. - General code cleanup to adhere to best practices and improve maintainability. * [Refactor] Integrate Clang-Tidy Checks and Enhance Code Consistency - Added clang-tidy checks to the format script for improved code quality assurance. - Refactored function signatures across multiple files to consistently use `const` references, enhancing performance and clarity. - Updated the requirements-lint.txt file to include clang-tidy as a dependency. - General code cleanup to adhere to best practices and improve maintainability. * [CI] Update AMD CI Workflow to Include Build Directory Creation - Added steps to create a build directory and configure CMake with ROCm support during the format check process. - Ensured cleanup of the build directory after the format check to maintain a clean workspace. * [Refactor] Remove Unused Member Variables in AtomicAddNode and CopyNode - Removed the `args_` member variable from both `AtomicAddNode` and `CopyNode` classes to streamline the code and eliminate unnecessary data members. - This change enhances code clarity and maintainability by focusing on relevant attributes for each class. * [Refactor] Update Clang-Tidy Integration and Code Improvements - Modified the format script to include the `-fix` option in the clang-tidy command for automatic code fixes. - Refactored the `AtomicAddVectorizePlanner` class to improve variable handling and consistency, including changes to member variable types and function signatures. - Enhanced code clarity by removing unnecessary `std::move` calls and ensuring consistent usage of types across the class. - General code cleanup to adhere to best practices and improve maintainability. * [Refactor] Improve Parameter Handling and Consistency in AtomicAddVectorize - Updated function signatures in `AtomicAddVectorizePlanResult` and `AtomicAddVectorizeRewriter` to use `const` references and `std::move` for better performance and clarity. - Enhanced the `UpdateVectorSize` method to accept `const Array<PrimExpr>&` for improved efficiency. - General code cleanup to maintain consistency and adhere to best practices. * [CI] Add Git Submodule Initialization to CI Workflow - Included a step to initialize and update git submodules recursively in the CI workflow. - This change ensures that all necessary submodules are available during the format check process, improving build reliability. * [CI] Add Git Submodule Update Step to Format Check - Included a command to initialize and update git submodules recursively in the CI workflow during the format check process. - This enhancement ensures that all required submodules are available, contributing to improved build reliability. * [Refactor] Update Function Signatures in AtomicAddVectorize - Modified the `VectorizeAtomicAdd` function signature to use `const` references for `thread_var` and `thread_bounds`, enhancing performance and code clarity. - This change aligns with previous refactoring efforts to improve parameter handling and consistency across the codebase. * [Cache] Introduce detailed target information for the disk kernel cache (#780) * Fix type hint for target_host parameter in compile function to allow None value * Refactor target handling in compile function to utilize determine_target for improved clarity and consistency * Update PrintConst function in codegen_cuda.cc to use hexfloat format for bfloat16 and float8/float4 types, while adding scientific notation comments for clarity. This change enhances the representation of floating-point constants in the generated code. * Refactor PrintType function in codegen_cuda.cc to remove unnecessary failure conditions for floating-point types with lane counts greater than 4. This change simplifies the logic and improves code clarity. * Enhance benchmark_matmul.py to conditionally print Reference TFlops only if ref_latency is not None. Update param.py to ensure target is converted to string for consistency. Refactor tuner.py to utilize determine_target for improved clarity in target handling. * Remove automatic commit and push step from AMD and NVIDIA CI workflows to streamline the process and avoid unnecessary commits. * [Example]Adds example for top-k operation (#775) * [Example]Adds example for top-k operation Adds an example demonstrating the top-k operation using tilelang * format * Adds topk tilelang example test * fix lint * [Math] Dispatch `T.rsqrt(x)` into cuda intrin instead of `1 / T.sqrt(x)` (#781) * Fix type hint for target_host parameter in compile function to allow None value * Refactor target handling in compile function to utilize determine_target for improved clarity and consistency * Update PrintConst function in codegen_cuda.cc to use hexfloat format for bfloat16 and float8/float4 types, while adding scientific notation comments for clarity. This change enhances the representation of floating-point constants in the generated code. * Refactor PrintType function in codegen_cuda.cc to remove unnecessary failure conditions for floating-point types with lane counts greater than 4. This change simplifies the logic and improves code clarity. * Enhance benchmark_matmul.py to conditionally print Reference TFlops only if ref_latency is not None. Update param.py to ensure target is converted to string for consistency. Refactor tuner.py to utilize determine_target for improved clarity in target handling. * Remove automatic commit and push step from AMD and NVIDIA CI workflows to streamline the process and avoid unnecessary commits. * Add intrin_rule source files to CMakeLists.txt and implement hrsqrt function for half_t in common.h * lint fix * remove cmake dep in pyproject as it may lead to different cmake paths in diff stages * lint fix * Add cmake dependency to pyproject.toml and improve build logging in setup.py * [CI] Adds pytest-durations for test timing (#782) * [Ci] Adds pytest-durations for test timing Adds `pytest-durations` to the test requirements and configures pytest to display test durations. This helps in identifying slow-running tests and optimizing the test suite for faster feedback. * add amd ci durations * Removes flash_attn installation from CI * [Refactor] Support python reflection for tile operators (#783) * Implement Fill operator and related reflection methods in TileLang - Added Fill operator implementation in `fill.cc` and `fill.h` for element-wise filling of buffers. - Introduced reflection methods for Fill, AtomicAdd, Copy, Conv2DIm2Col, FinalizeReducer, Gemm, and Parallel operators to enhance introspection capabilities. - Updated relevant files to register reflection methods and ensure proper initialization in static blocks. - Removed outdated comments and unnecessary code in various operator files to improve clarity and maintainability. - Added new Python bindings for the Fill operator in `tilelang/ir/fill.py` and updated the module imports accordingly. * Refactor operator reflection methods and improve code clarity - Updated reflection methods for AtomicAdd, Copy, FinalizeReducer, Gemm, and Parallel operators to enhance readability by using `empty()` instead of size checks. - Consolidated static initialization blocks for various operators to a single line for improved consistency. - Cleaned up whitespace and formatting in multiple files to adhere to coding standards and improve maintainability. - Added new Python bindings for operators in the `tilelang/ir` module, ensuring proper registration and organization of imports. * Refactor GEMM and AtomicAdd operations for improved clarity - Updated the `GetArchInt` function in `atomic_add.cc` to use `std::string` and `std::stoi` for better readability and type safety. - Removed unnecessary variables and comments in `gemm_sp.cc` and `gemm.cc` to streamline the `ComputeWarpPartition` method. - Cleaned up the `layout_reducer.cc` file by removing unused variable declarations, enhancing code clarity. - Added import for the `ir` module in `tilelang/__init__.py` to ensure proper organization of module imports. * Remove deprecated operator files from the tilelang IR module - Deleted files for Fill, AtomicAdd, Copy, Gemm, GemmSP, FinalizeReducer, Parallel, Reduce, and Region operators to streamline the codebase. - This cleanup enhances maintainability by removing unused code and improving overall organization of the module. * Refactor imports in tilelang IR module for improved organization - Updated import statements in `tilelang/ir.py` to reflect changes in the TVM library structure, enhancing clarity and maintainability of the codebase. * lint fix * Refactor GEMM and GEMM-SP operations to enhance clarity and maintainability - Updated the `Gemm` and `GemmSP` classes to utilize a new `GemmWarpPolicy` object for warp partitioning, improving encapsulation and readability. - Removed deprecated `ComputeWarpPartition` methods and replaced them with calls to the new policy object, streamlining the code. - Cleaned up comments and unnecessary code in `gemm.cc`, `gemm_sp.cc`, and related header files to enhance overall clarity. - Introduced a new `GemmWarpPolicyNode` class to manage warp policy attributes and methods, facilitating better organization of related functionalities. - Updated reflection methods to include the new policy structure, ensuring proper registration and introspection capabilities. * Refactor Reduce operation to utilize ReduceType class for improved clarity and maintainability - Replaced multiple conditional checks for reduce types with a single ReduceType object, simplifying the code structure. - Introduced a new ReduceTypeNode class to encapsulate reduce type logic and methods, enhancing organization. - Updated MakeInitValue, MakeReduce, and Lower methods to leverage the new ReduceType class, improving readability. - Added Python bindings for the ReduceType class in tilelang IR module to ensure proper registration and usability. * comment * Refactor operator header files for improved readability - Cleaned up formatting and whitespace in `atomic_add.h`, `copy.h`, `fill.h`, `reduce.cc`, and `reduce.h` to enhance code clarity. - Consolidated comments and adjusted line breaks for better organization and maintainability across multiple operator definitions. * Refactor MakeReduce method in ReduceOpNode for clarity - Updated the parameter name in the MakeReduce method from `rhs` to `b` and assigned it to `rhs` for improved readability. - This change enhances the clarity of the method's purpose and aligns with the overall refactoring efforts in the Reduce operation. * Update Reduce operation type checks for consistency - Changed string comparisons for reduce types in the MakeReduce method from "abs_sum" to "abssum" and "abs_max" to "absmax" for uniformity. - This adjustment enhances the clarity and consistency of the reduce type handling in the codebase. * [AMD] Fix amd tir&add examples (#784) * [Enhancement] Refactor buffer index handling for improved precision and clarity (#668) - Enhanced buffer index handling to address precision issues by removing redundant operations. - Streamlined the logic for determining buffer overlaps, ensuring more accurate conflict detection. - Updated related documentation to reflect changes in buffer management practices. * Remove obsolete test script for AMD example, streamlining the examples directory. * Remove unused dtype_size variable in AMD example script to streamline code. * Add input configuration file and update AMD example script for enhanced flexibility - Introduced a new input.txt file for configurable parameters. - Modified the example_amd_flash_attn_fwd.py script to allow for a wider range of configurations, including additional options for num_stages, enable_rasterization, and k_pack. - Streamlined the main function for better clarity and organization. - Added a new test script to facilitate running the example with specified parameters. * Remove input configuration file and obsolete test script; enhance AMD example with swizzle layout annotations - Deleted input.txt and test.sh files as they are no longer needed. - Updated example_amd_flash_attn_fwd.py to include swizzle layout annotations for shared memory, improving bank conflict avoidance. - Reintroduced swizzle usage in the kernel for better performance. * Refactor AMD example script for FlashAttention-2 - Updated function names for clarity, changing `get_v2_configs` to `get_configs` and `fast_flashattn_v2` to `fast_flashattn`. - Streamlined the main function by renaming `main_v2` to `main` and adjusting the corresponding calls. - Removed outdated comments and improved code organization for better readability. * Refactor formatting in AMD FlashAttention example script - Improved code readability by adjusting line breaks and indentation in the `fast_flashattn` function. - Streamlined the `main` function parameter formatting for consistency. - Removed unnecessary blank lines to enhance overall code organization. * Update example_amd_flash_attn_fwd.py * Enhance AMD example script and update CI workflows - Improved the `example_amd_flash_attn_fwd.py` script for better clarity and organization. - Added new CI workflows for AMD and documentation publishing. - Updated various requirements files to include necessary dependencies. - Introduced new test cases and examples for better coverage and functionality. - Refactored existing code for improved readability and maintainability. * Remove redundant tool cache cleanup step in AMD CI workflow * Remove `torch` dependency from `requirements-rocm.txt` to streamline requirements. * Add new AMD FlashAttention example and test script - Introduced `example_amd_flash_attn_bwd.py` for backward attention computation using TileLang. - Added `test.sh` script to facilitate running the new example with specified parameters. - Enhanced the overall structure and organization of the example for better clarity and usability. * Update configurations in `example_amd_flash_attn_fwd.py` for autotuner - Reduced the number of threads and `num_split_q` options for improved performance. - Adjusted `panel_size` options to streamline configuration settings. * Update submodule 'tvm' to commit 6ccc74f622c7ec4ac25d430d0f6546e7b9edb217 * Update submodule 'tvm' to commit 14ff70ab142b9e5a31bbf9c7923c8a697d41e86c * Add example for AMD Flash Attention backward pass implementation - Introduced a new example script `example_amd_flash_attn_bwd.py` demonstrating the forward and backward operations of Flash Attention using TileLang. - Implemented JIT-compiled functions for both forward and backward passes, including preprocessing and postprocessing steps. - Added a main function to facilitate testing and benchmarking of the attention mechanism with configurable parameters. - Included reference implementation for validation against PyTorch's attention mechanism. This addition enhances the examples directory by providing a comprehensive guide for users to understand and utilize Flash Attention in their applications. * Enhance AMD Flash Attention example with additional testing capabilities - Updated `example_amd_flash_attn_bwd.py` to include more comprehensive testing features for the Flash Attention implementation. - Improved the main function to allow for better parameter configuration and benchmarking. - Added validation checks against PyTorch's attention mechanism to ensure accuracy and reliability of the example. This update aims to provide users with a more robust tool for understanding and utilizing Flash Attention in their applications. * Update submodule TVM to commit a64a5926a6e59f5417ef2501f9d88b467337cf6a * Refactor HIP intrinsic rules to CUDA - Updated file name from `intrin_rule_hip.cc` to `intrin_rule_cuda.cc` to reflect the change in focus from HIP to CUDA intrinsic rules. - Adjusted include paths for better organization and clarity in the code structure. * Update AMD CI workflow to uninstall specific PyTorch packages before installation - Removed the installation of `flash_attn==2.5.8` to streamline the CI process. - Added a step to uninstall `torch`, `torchvision`, and `torchaudio` prior to installing pre-release versions, ensuring compatibility and reducing potential conflicts. * Remove unused shared memory allocations in AMD Flash Attention backward example - Eliminated the allocation of shared memory for `dv_shared` and `dk_shared` in `example_amd_flash_attn_bwd.py` to streamline memory usage and improve performance. - This change focuses on optimizing the backward pass implementation by reducing unnecessary memory overhead. * Remove unnecessary pip uninstall command from AMD CI workflow - Eliminated the step to uninstall `torch`, `torchvision`, and `torchaudio` in the AMD CI workflow, as it is no longer required for the installation of pre-release versions. - This change simplifies the CI process and reduces potential overhead during package management. * Refactor DispatchHIPWarpActiveMask function in HIP intrinsic rules - Updated the return statement to use std::string for concatenation in the case of 16-bit types, improving code clarity. - Added a null check for the CallNode pointer in DispatchHIPWarpActiveMask to enhance robustness and prevent potential dereferencing issues. * Refactor formatting of HIP intrinsic rule registrations - Adjusted the formatting of TVM_REGISTER_OP calls for better readability by aligning method chaining. - No functional changes were made; this update focuses on code style improvements to enhance maintainability. * Update file na…
Summary by CodeRabbit
New Features
Refactor
Chores
Documentation