Merged · 17 commits · Changes from all commits
2 changes: 1 addition & 1 deletion benchmark/matmul/benchmark_matmul_sp.py
@@ -192,7 +192,7 @@ def main(

# Clear out the accumulation buffer
T.clear(C_local)
T.no_set_max_nreg()
T.disable_warp_group_reg_alloc()

T.use_swizzle(panel_size=10, enable=enable_rasterization)
T.annotate_layout({
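
For readers tracking the API change across these files: T.no_set_max_nreg() is renamed to T.disable_warp_group_reg_alloc(), which, judging by the old and new names, opts a kernel out of the warp-group register reallocation (setmaxnreg) that warp-specialized kernels otherwise request. A minimal sketch of where the call sits in a TileLang kernel; the GEMM body, tile sizes, and buffer names are illustrative and assume the current tilelang.language API rather than anything added by this PR:

import tilelang
import tilelang.language as T

@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M=128, block_N=128, block_K=32,
           dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype),
             C: T.Tensor((M, N), accum_dtype)):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M),
                      threads=256) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
            # Renamed intrinsic (was T.no_set_max_nreg()): keep the default
            # register budget instead of letting warp groups re-partition it.
            T.disable_warp_group_reg_alloc()

            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)

            T.copy(C_local, C[by * block_M, bx * block_N])

    return main
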
@@ -52,7 +52,7 @@ def main_no_split(
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
T.no_set_max_nreg()
T.disable_warp_group_reg_alloc()
loop_range = T.ceildiv(seqlen_kv, block_N)
for k in T.Pipelined(loop_range, num_stages=2):
T.copy(KV[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], qKV_shared)
13 changes: 9 additions & 4 deletions examples/deepseek_nsa/example_tilelang_nsa_decode.py
@@ -8,7 +8,14 @@
tilelang.testing.set_random_seed(42)


@tilelang.jit(out_idx=[-1])
# TODO(lei): workaround, as threads is not divisible by warp group size,
# auto warp specialization may have some bugs.
@tilelang.jit(
out_idx=[-1],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
def native_sparse_attention(
batch,
heads,
@@ -22,7 +29,7 @@ def native_sparse_attention(
if scale is None:
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
# Modified shapes for inference (q has seq_len=1)
# Modified shapes for inference (q has seq_len=1)a
q_shape = [batch, 1, heads, dim] # Changed seq_len to 1
kv_shape = [batch, seq_len, head_kv, dim]
block_indices_shape = [batch, 1, head_kv, selected_blocks] # Changed seq_len to 1
@@ -167,8 +174,6 @@ def main():
block_counts=block_counts,
block_size=block_size,
)
print("out", out)
print("ref", ref)
torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2)


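
The TODO added to this example explains the decorator change: the kernel's thread count is not a multiple of the warp-group size, so automatic warp specialization (and the TMA lowering that rides on it) is disabled through pass_configs. A hedged sketch of how that condition could be checked explicitly; the helper name, the 128-thread warp-group constant (4 warps of 32), and the usage line are assumptions for illustration, not code from this PR:

import tilelang

WARP_GROUP_SIZE = 128  # assumed: one warp group = 4 warps x 32 threads

def nsa_decode_pass_configs(threads: int) -> dict:
    """Return the workaround pass configs only when the launch cannot be
    split into whole warp groups, mirroring the TODO above."""
    if threads % WARP_GROUP_SIZE != 0:
        return {
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        }
    return {}

# usage (hypothetical):
# @tilelang.jit(out_idx=[-1], pass_configs=nsa_decode_pass_configs(threads))
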
@@ -338,7 +338,7 @@ def main(
C_shared: tilelang.layout.make_swizzled_layout(C_shared),
})
if threads == 512:
T.no_set_max_nreg()
T.disable_warp_group_reg_alloc()

T.clear(C_local)
for k in T.Pipelined(K // block_K, num_stages=num_stages):
8 changes: 3 additions & 5 deletions examples/flash_attention/example_gqa_bwd.py
@@ -1,7 +1,6 @@
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import argparse

@@ -340,11 +339,10 @@ def main(BATCH: int = 1,
dK_ref, K.grad = K.grad.clone(), None
dV_ref, V.grad = V.grad.clone(), None

assert torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(O, O_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)

def run():
O_ref.backward(dO, retain_graph=True)
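
On the assertion change in this file: torch.testing.assert_close keeps the same rtol/atol tolerances as the old assert torch.allclose(...) lines, but on failure it reports how many elements mismatch and the greatest absolute/relative difference, and it also checks dtype and device by default. A small, self-contained illustration (the tensors are made up):

import torch

ref = torch.tensor([1.0, 2.0, 3.0])
out = ref + 1e-3  # within the tolerances below

# Old style: a bare AssertionError with no detail about what differed.
assert torch.allclose(out, ref, rtol=1e-2, atol=1e-2)

# New style: same check, but a failing comparison raises an error that
# lists the mismatched-element count and the largest abs/rel difference.
torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)
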
2 changes: 1 addition & 1 deletion examples/gdn/example_chunk_o.py
@@ -122,7 +122,7 @@ def kernel(

T.clear(A_fragment)
T.clear(O_fragment)
T.no_set_max_nreg()
T.disable_warp_group_reg_alloc()
for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
T.copy(
Q[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK],
2 changes: 1 addition & 1 deletion examples/gdn/example_chunk_scaled_dot_kkt.py
@@ -101,7 +101,7 @@ def kernel(
})

T.fill(A_fragment, 0)
T.no_set_max_nreg()
T.disable_warp_group_reg_alloc()
for i_s in T.Parallel(block_S):
Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh]

2 changes: 1 addition & 1 deletion examples/gdn/example_wy_fast.py
@@ -107,7 +107,7 @@ def kernel(
U_Beta_shared: tilelang.layout.make_swizzled_layout(U_Beta_shared),
})

T.no_set_max_nreg()
T.disable_warp_group_reg_alloc()
for i_s in T.Parallel(block_S):
Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh]
G_shared[i_s] = T.exp(G[bb, bs * block_S + i_s, bh])
1 change: 0 additions & 1 deletion examples/seer_attention/block_sparse_attn_tilelang.py
@@ -178,7 +178,6 @@ def test_topk_sparse_attention():
# Run tilelang kernel
kernel = blocksparse_flashattn(
BATCH, N_HEADS, SEQ_LEN, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
print(kernel.get_kernel_source())
tilelang_output = kernel(q, k, v, block_mask.to(torch.int8))

# Compute reference
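
The unconditional print of the generated kernel source is dropped from this test. If the dump is still occasionally useful, it can be gated rather than restored; a small sketch, where the environment-variable name is an arbitrary choice and kernel is the compiled object returned by blocksparse_flashattn(...) above:

import os

# Dump the generated source only when explicitly requested, e.g.
#   TILELANG_DUMP_SOURCE=1 python block_sparse_attn_tilelang.py
if os.environ.get("TILELANG_DUMP_SOURCE") == "1":
    print(kernel.get_kernel_source())
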
30 changes: 14 additions & 16 deletions src/op/atomic_add.cc
@@ -182,27 +182,25 @@ For AtomicAdd::MakeSIMTLoop(arith::Analyzer *analyzer) const {

Stmt AtomicAdd::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
Target target = T.target;
bool is_cpu_target = target->GetTargetDeviceType() == kDLCPU;
auto simt_loop = MakeSIMTLoop(analyzer);
auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));
For vectorized_thread_loop;
auto par_op = std::make_unique<ParallelOp>(fused_loop);

if (!is_cpu_target) {
std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
InferLevel::kFree};
for (auto level : levels) {
par_op->InferLayout(
{T.target, T.thread_bounds, T.layout_map, T.buffer_remap}, level);
}
auto loop_layout = par_op->GetLoopLayout();
Var thread_var = T.thread_var;
Range thread_bounds = T.thread_bounds;
auto thread_loop =
PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout);
vectorized_thread_loop = VectorizeAtomicAdd(
thread_loop, thread_var, thread_bounds, GetArchInt(target));
std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
InferLevel::kFree};
for (auto level : levels) {
par_op->InferLayout(
{T.target, T.thread_bounds, T.layout_map, T.buffer_remap}, level);
}
auto loop_layout = par_op->GetLoopLayout();
Var thread_var = T.thread_var;
Range thread_bounds = T.thread_bounds;
auto thread_loop =
PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout);
// TODO(@dyq): buggy implementation, need to fix
// vectorized_thread_loop = VectorizeAtomicAdd(
// thread_loop, thread_var, thread_bounds, GetArchInt(target));
auto vectorized_thread_loop = VectorizeLoop(thread_loop);

Comment on lines +200 to 204

💡 Verification agent

🧩 Analysis chain

Replacing the specialized AtomicAdd vectorizer with the generic VectorizeLoop likely breaks vectorization semantics

The previous path used VectorizeAtomicAdd which:

  • plans a vector size based on arch and dtype,
  • and rewrites call_extern("AtomicAdd", ...) into AtomicAddx2/AtomicAddx4 with pointer-to-src conventions.

Switching to VectorizeLoop(thread_loop) bypasses that rewriter. Unless VectorizeLoop knows how to handle AtomicAdd and emit AtomicAddx{2,4}, vectorization will either be skipped or produce invalid IR for the extern call.

Reinstate the specialized vectorizer here until VectorizeLoop subsumes it:

-  // TODO(@dyq): buggy implementation, need to fix
-  // vectorized_thread_loop = VectorizeAtomicAdd(
-  //     thread_loop, thread_var, thread_bounds, GetArchInt(target));
-  auto vectorized_thread_loop = VectorizeLoop(thread_loop);
+  // Use the specialized AtomicAdd vectorizer pending a robust generic path.
+  auto vectorized_thread_loop =
+      VectorizeAtomicAdd(thread_loop, thread_var, thread_bounds,
+                         GetArchInt(target));
  • Confirm that AtomicAddx2/AtomicAddx4 are defined/linked in the runtime.

  • If you intend to keep VectorizeLoop, point me to the lowering that handles AtomicAdd so I can review that path.


🏁 Script executed:

#!/bin/bash
rg -nP '\bAtomicAddx(2|4)\b' -g 'src/**|include/**' -C2

Length of output: 206


🏁 Script executed:

#!/bin/bash
# Properly search for AtomicAddx2/4 definitions or references across the entire repository
rg -nP '\bAtomicAddx(2|4)\b' -C2 || echo "No matches found in repository"

# Additionally, constrain search to src and include directories
rg -nP '\bAtomicAddx(2|4)\b' -g 'src/**' -g 'include/**' -C2 || echo "No matches found in src/ or include/"

Length of output: 4981


🏁 Script executed:

#!/bin/bash
# Search for the definition and usage of VectorizeLoop in the source
rg -n -C3 'VectorizeLoop' -g 'src/**' || echo "VectorizeLoop not found in src/"

Length of output: 6630


Reinstate the specialized AtomicAdd vectorizer to preserve correct semantics

The change from VectorizeAtomicAdd to a plain VectorizeLoop bypasses the custom rewriting that emits AtomicAddx2/AtomicAddx4, leading to either skipped vectorization or invalid IR for the extern call.

• We verified that AtomicAddx2 and AtomicAddx4 are indeed defined and used in the codebase:

  • In the generic customization layer (tilelang), they’re generated by calls in tilelang/language/customize.py.
  • The specialized mutator in src/transform/atomicadd_vectorize.cc pushes "AtomicAddx2"/"AtomicAddx4" into the call arguments based on vector size.
  • Their device implementations live in src/tl_templates/cuda/common.h, defined as TL_DEVICE void AtomicAddx2(…) and …AtomicAddx4(…).

• The generic VectorizeLoop implementation (in src/transform/loop_vectorize.cc) contains no special handling for extern calls to "AtomicAdd"; it merely vectorizes based on arithmetic patterns, so it won’t rewrite or emit the correct AtomicAddxN calls.

Action: Revert to the specialized vectorizer call in src/op/atomic_add.cc until the generic path is extended to cover AtomicAdd.

--- a/src/op/atomic_add.cc
+++ b/src/op/atomic_add.cc
@@ -200,7 +200,10 @@ Stmt MakeAtomicAddOp(...){
-  // TODO(@dyq): buggy implementation, need to fix
-  // vectorized_thread_loop = VectorizeAtomicAdd(
-  //     thread_loop, thread_var, thread_bounds, GetArchInt(target));
-  auto vectorized_thread_loop = VectorizeLoop(thread_loop);
+  // Use the specialized AtomicAdd vectorizer to emit AtomicAddx2/4 calls
+  auto vectorized_thread_loop =
+      VectorizeAtomicAdd(thread_loop, thread_var, thread_bounds,
+                         GetArchInt(target));
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
// TODO(@dyq): buggy implementation, need to fix
// vectorized_thread_loop = VectorizeAtomicAdd(
// thread_loop, thread_var, thread_bounds, GetArchInt(target));
auto vectorized_thread_loop = VectorizeLoop(thread_loop);
// Use the specialized AtomicAdd vectorizer to emit AtomicAddx2/4 calls
auto vectorized_thread_loop =
VectorizeAtomicAdd(thread_loop, thread_var, thread_bounds,
GetArchInt(target));
🤖 Prompt for AI Agents
In src/op/atomic_add.cc around lines 200 to 204, the code currently uses
VectorizeLoop which skips the specialized AtomicAdd rewriting and produces
incorrect IR for extern AtomicAdd calls; restore the original specialized
vectorizer by calling VectorizeAtomicAdd(thread_loop, thread_var, thread_bounds,
GetArchInt(target)) (and remove or update the TODO/commented buggy note) so the
mutator emits AtomicAddx2/AtomicAddx4 as before until the generic vectorizer is
enhanced to handle AtomicAdd.

if (par_op->GetPredicate(T.thread_var).defined()) {
return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
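
For context on what this lowering path serves: at the Python level, per-element atomic accumulation is typically written with T.atomic_add, and AtomicAdd::Lower is what turns the resulting loop into extern AtomicAdd calls that the specialized vectorizer can fuse into AtomicAddx2/AtomicAddx4. A rough sketch of such a producer kernel; the buffer names and shapes are illustrative, and the exact T.atomic_add signature is assumed from the tilelang/language/customize.py layer cited in the review:

import tilelang
import tilelang.language as T

@tilelang.jit
def tile_accumulate(M, N, block_M=128, block_N=128, dtype="float32"):
    # Assumes M and N are multiples of the block sizes.
    @T.prim_func
    def main(Src: T.Tensor((M, N), dtype), Dst: T.Tensor((M, N), dtype)):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M),
                      threads=128) as (bx, by):
            for i, j in T.Parallel(block_M, block_N):
                # Each update lowers to an extern AtomicAdd call; the
                # specialized vectorizer discussed above is what may fuse
                # adjacent lanes into AtomicAddx2 / AtomicAddx4.
                T.atomic_add(Dst[by * block_M + i, bx * block_N + j],
                             Src[by * block_M + i, bx * block_N + j])

    return main
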
4 changes: 3 additions & 1 deletion src/op/builtin.cc
@@ -29,6 +29,8 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kPtxasRegisterUsageLevel, Integer);
TVM_REGISTER_PASS_CONFIG_OPTION(kEnablePTXASVerboseOutput, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableShuffleElect, Bool);

DataType cuTensorMapType() { return DataType::UInt(8, 128); }

#define TIR_DEFINE_TL_BUILTIN(OpName) \
const Op &OpName() { \
static const Op &op = Op::Get("tl." #OpName); \
@@ -78,7 +80,7 @@ TIR_DEFINE_TL_BUILTIN(mbarrier_expect_tx)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(ptx_ldmatirx)
TIR_DEFINE_TL_BUILTIN(ptx_ldmatrix)
.set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
16 changes: 13 additions & 3 deletions src/op/builtin.h
@@ -15,6 +15,8 @@ namespace tl {

namespace attr {
static constexpr const char *kPaddingMap = "padding_map";
static constexpr const char *kWarpSpecializationScope =
"kWarpSpecializationScope";
} // namespace attr

static constexpr const char *kDebugMergeSharedMemoryAllocations =
@@ -54,6 +56,14 @@ static constexpr const char *kDisableDynamicTailSplit =
*/
static constexpr const char *kDynamicAlignment = "tl.dynamic_alignment";

/*!
* \brief Get the type of the CUDA tensor map
*
* DataType cuTensorMapType()
*
*/
DataType cuTensorMapType();

/*!
* \brief tvm intrinsics for TMADescriptor creation for tiled load
*
@@ -138,15 +148,15 @@ TVM_DLL const Op &mbarrier_expect_tx();
/*!
* \brief tvm intrinsics for ldmatrix
*
* ptx_ldmatirx(transposed, num, shared_addr, local_addr)
* ptx_ldmatrix(transposed, num, shared_addr, local_addr)
*
*/
TVM_DLL const Op &ptx_ldmatirx();
TVM_DLL const Op &ptx_ldmatrix();

/*!
* \brief tvm intrinsics for stmatrix
*
* ptx_ldmatirx(transposed, num, shared_addr, int32_values...)
* ptx_ldmatrix(transposed, num, shared_addr, int32_values...)
*
*/
TVM_DLL const Op &ptx_stmatrix();