Merged
2 changes: 0 additions & 2 deletions examples/deepseek_v32/fp8_lighting_indexer.py
@@ -136,8 +136,6 @@ def mqa_attn_return_logits_kernel(
 cu_k_s_min = T.alloc_local([1], index_dtype)
 cu_k_e_max = T.alloc_local([1], index_dtype)

-T.no_set_max_nreg()
-
 cu_k_s_min[0] = 2147483647
 cu_k_e_max[0] = -2147483648

25 changes: 22 additions & 3 deletions src/transform/annotate_warp_group_reg_alloc.cc
@@ -59,6 +59,27 @@ class SetMaxNRegCollector : public StmtExprVisitor {
 bool warp_specialized_ = false;
 };

+class SimtCopyDetector : public StmtExprVisitor {
+public:
+static bool Detect(const Stmt &stmt) {
+SimtCopyDetector detector;
+detector.VisitStmt(stmt);
+return detector.has_simt_copy_;
+}
+
+private:
+void VisitStmt_(const BufferStoreNode *op) final {
+auto scope =
+runtime::StorageScope::Create(GetPtrStorageScope(op->buffer->data));
+if (scope.to_string() != "global") {
+has_simt_copy_ = true;
+}
+StmtExprVisitor::VisitStmt_(op);
+}
Comment on lines +71 to +78
⚠️ Potential issue | 🟠 Major

Over-broad SIMT copy detection disables register hints everywhere.

SimtCopyDetector now sets has_simt_copy_ for every BufferStore whose scope is anything other than "global". In warp-specialized producers, the stores almost always target shared memory—even when they're the result of plain computation rather than a SIMT copy from global. As a consequence, SetMaxNRegInjector will see has_simt_copy == true for nearly all producer branches and skip re-inserting set_max_nreg, effectively turning the pass off and regressing performance.

Please tighten the detection to match actual SIMT copies (e.g., shared stores fed solely by global loads, mirroring the logic already used in WarpSpecializedRewriter) instead of treating every non-global store as a copy.

🤖 Prompt for AI Agents
In src/transform/annotate_warp_group_reg_alloc.cc around lines 71-78, the
detector currently marks any non-"global" BufferStore as a SIMT copy. Change it
to set has_simt_copy_ only when the store writes to non-global storage AND the
stored value is composed solely of loads from global buffers (mirroring the
WarpSpecializedRewriter logic). Implement a small recursive helper that walks
op->value and returns true only if every BufferLoad encountered satisfies
runtime::StorageScope::Create(GetPtrStorageScope(load->buffer->data)).to_string()
== "global" (pure arithmetic and casts between those loads are allowed), and use
that helper in the VisitStmt_ check; otherwise do not set has_simt_copy_. Keep
the call to the base StmtExprVisitor::VisitStmt_(op).
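
To make the difference concrete, here is a minimal sketch in Python over a toy IR, not TVM's real node classes. `Const`, `Load`, `Op`, `Store` and both predicate functions are hypothetical names invented for illustration. The requirement that at least one load be present is also an assumption on my part, added so that a constant fill into shared memory is not counted as a copy.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class Const:
    """Toy stand-in for an immediate value."""
    value: float


@dataclass(frozen=True)
class Load:
    """Toy stand-in for a BufferLoad; `scope` is the buffer's storage scope."""
    scope: str


@dataclass(frozen=True)
class Op:
    """Toy stand-in for pure arithmetic or a cast over sub-expressions."""
    args: Tuple["Expr", ...]


Expr = Union[Const, Load, Op]


@dataclass(frozen=True)
class Store:
    """Toy stand-in for a BufferStore into a buffer with the given scope."""
    scope: str
    value: Expr


def loads_in(expr):
    """Yield every Load reachable from expr."""
    if isinstance(expr, Load):
        yield expr
    elif isinstance(expr, Op):
        for arg in expr.args:
            yield from loads_in(arg)


def is_simt_copy_naive(store):
    # Behaviour flagged above: any store to a non-global scope counts.
    return store.scope != "global"


def is_simt_copy_tightened(store):
    # Suggested behaviour: a non-global store whose value reads only global
    # memory. Requiring at least one load is an assumption of this sketch.
    loads = list(loads_in(store.value))
    return (store.scope != "global"
            and len(loads) > 0
            and all(ld.scope == "global" for ld in loads))


copy = Store("shared", Load("global"))
compute = Store("shared", Op((Load("local.fragment"), Const(2.0))))
for name, st in [("copy", copy), ("compute", compute)]:
    print(name, is_simt_copy_naive(st), is_simt_copy_tightened(st))
```

Under the naive rule both stores are flagged, so a producer that merely writes computed results to shared memory would lose its register hint. Under the tightened rule only the global-to-shared copy is flagged.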


+bool has_simt_copy_{false};
+};
+
 class SetMaxNRegInjector : public StmtExprMutator {
 public:
 static PrimFunc Inject(PrimFunc f) {
@@ -113,9 +134,7 @@ class SetMaxNRegInjector : public StmtExprMutator {
 auto dec_reg_stmt = Evaluate(0);

 // Only inject if we have valid register hints and no SIMT copy
-// For now, we assume no SIMT copy detection is available here
-// TODO: Add SIMT copy detection if needed
-bool has_simt_copy = false; // Placeholder
+bool has_simt_copy = SimtCopyDetector::Detect(producer_body);

 if (dec_reg >= 0 && inc_reg >= 0 && !has_simt_copy) {
 auto inc_reg_num =

3 changes: 2 additions & 1 deletion tilelang/engine/phase.py
@@ -135,7 +135,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
 mod = tilelang.transform.MultiVersionBuffer()(mod)
 mod = tilelang.transform.WarpSpecialized()(mod)
 mod = tilelang.transform.InjectTmaBarrier()(mod)
-mod = tilelang.transform.AnnotateWarpGroupRegAlloc()(mod)
 # if tma is not enabled, we can also do pipeline planning
 # to get better performance with async copy
 mod = tilelang.transform.PipelinePlanning()(mod)
@@ -206,6 +205,8 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
 # Inject PTX async copy must behind the thread sync pass
 # as ptx async copy won't be recognized as a valid buffer load
 mod = tilelang.transform.InjectPTXAsyncCopy()(mod)
+if allow_tma_and_warp_specialized(pass_ctx=pass_ctx, target=target):
+    mod = tilelang.transform.AnnotateWarpGroupRegAlloc()(mod)
 mod = tilelang.transform.MakePackedAPI()(mod)
 mod = tilelang.transform.LowerDeviceKernelLaunch()(mod)