216 changes: 216 additions & 0 deletions examples/attention_sink/benchmark_gqa_sink_fwd.py
@@ -0,0 +1,216 @@
import torch
import argparse
from tilelang.profiler import do_bench
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor
from example_gqa_sink_fwd_bhsd_wgmma_pipelined import flashattn, ref_program, gen_inputs


@triton.jit
def triton_kernel(
Q,
K,
V,
Sinks,
sm_scale,
Out,
Z,
H,
N_Q_CTX,
N_KV_CTX,
Comment on lines +18 to +21

🛠️ Refactor suggestion | 🟠 Major

Remove unused parameters.

Parameters Z, N_Q_CTX, and N_KV_CTX are declared but never used in the function body. Remove them to reduce clutter.

Apply this diff:

-    Z,
-    H,
-    N_Q_CTX,
-    N_KV_CTX,
+    H,
     HEAD_DIM: tl.constexpr,

Then update the caller in triton_program (around lines 113-114) to remove these arguments from the kernel invocation.
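
A sketch of the matching call-site change, based on the invocation later in this file (illustrative only, not a committable suggestion):

    triton_kernel[grid](
        TensorDescriptor.from_tensor(Q, [1, 1, BLOCK_M, head_dim]),
        TensorDescriptor.from_tensor(K, [1, 1, BLOCK_N, head_dim]),
        TensorDescriptor.from_tensor(V, [1, 1, BLOCK_N, head_dim]),
        Sinks,
        1.0 / head_dim**0.5,
        TensorDescriptor.from_tensor(o, [1, 1, BLOCK_M, head_dim]),
        n_heads,  # Z dropped; N_Q_CTX/N_KV_CTX keyword arguments removed to match the new signature
        HEAD_DIM=head_dim,
        groups=groups,
        BANDWIDTH=window_size,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        start_q=seq_kv - seq_q)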

🧰 Tools
🪛 Ruff (0.14.0)

18-18: Unused function argument: Z (ARG001)
20-20: Unused function argument: N_Q_CTX (ARG001)
21-21: Unused function argument: N_KV_CTX (ARG001)

🤖 Prompt for AI Agents
In examples/attention_sink/benchmark_gqa_sink_fwd.py around lines 18 to 21, the function signature includes unused parameters Z, N_Q_CTX, and N_KV_CTX; remove these parameters from the function definition and any related parameter list to eliminate clutter, then update the caller in triton_program (around lines 113-114) to remove the corresponding arguments from the kernel invocation so the call site matches the new signature.

HEAD_DIM: tl.constexpr,
groups: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BANDWIDTH: tl.constexpr,
start_q: tl.constexpr,
):
tl.static_assert(BLOCK_N <= HEAD_DIM)
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
off_z = off_hz // H
off_h = off_hz % H

# load attention sinks
if Sinks is not None: # noqa: SIM108
sink = tl.load(Sinks + off_h).to(tl.float32)
else:
sink = 0

# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) + sink
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
# load scales
qk_scale = sm_scale
q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM])

if BANDWIDTH:
lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M -
BANDWIDTH), start_q + (start_m + 1) * BLOCK_M
else:
lo, hi = 0, start_q + (start_m + 1) * BLOCK_M

for start_n in range(lo, hi, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)

mask = (start_n + offs_n)[None, :] > (start_q + offs_m)[:, None]

if BANDWIDTH:
too_old = (start_n + offs_n[None, :]) < (start_q + offs_m[:, None] - BANDWIDTH + 1)
mask = mask | too_old

k = K.load([off_z, off_h // groups, start_n, 0]).reshape([BLOCK_N, HEAD_DIM]).T
qk = tl.dot(q, k, allow_tf32=False)

qk = qk * qk_scale + tl.where(mask, -1.0e6, 0.0)
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk -= m_ij[:, None]

p = tl.math.exp(qk)
alpha = tl.math.exp(m_i - m_ij)
l_ij = tl.sum(p, 1)
acc = acc * alpha[:, None]

v = V.load([off_z, off_h // groups, start_n, 0]).reshape([BLOCK_N, HEAD_DIM])
# v = v.to(tl.float32)
p = p.to(v.dtype) # We perform fp16 gemm to utilize tensor core
acc = tl.dot(p, v, acc, allow_tf32=False)

l_i = l_i * alpha + l_ij
m_i = m_ij

sink = tl.math.exp(sink - m_i)
z = l_i + sink
acc = acc / z[:, None]
# m_i += tl.math.log(l_i)
# m_ptrs = M + off_hz * N_Q_CTX + offs_m
# tl.store(m_ptrs, m_i)
acc = acc.to(Out.dtype)[None, None, :, :]
Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc)


def triton_program(Q, K, V, Sinks, window_size: int | None = None) -> torch.Tensor:
bs, n_heads, seq_q, head_dim = Q.shape
_, n_heads_kv, seq_kv, _ = K.shape
BLOCK_M = 64
BLOCK_N = 64
groups = n_heads // n_heads_kv

o = torch.empty_like(Q)
grid = (triton.cdiv(seq_q, BLOCK_M), bs * n_heads, 1)
triton_kernel[grid](
TensorDescriptor.from_tensor(Q, [1, 1, BLOCK_M, head_dim]),
TensorDescriptor.from_tensor(K, [1, 1, BLOCK_N, head_dim]),
TensorDescriptor.from_tensor(V, [1, 1, BLOCK_N, head_dim]),
Sinks,
1.0 / head_dim**0.5,
TensorDescriptor.from_tensor(o, [1, 1, BLOCK_M, head_dim]),
bs,
n_heads,
N_Q_CTX=seq_q,
N_KV_CTX=seq_kv,
HEAD_DIM=head_dim,
groups=groups,
BANDWIDTH=window_size,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
start_q=seq_kv - seq_q)
return o


def main(
batch: int = 1,
heads: int = 32,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 128,
groups: int = 8,
window_size: int | None = None,
dtype: str = "float16",
tune: bool = False,
):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
if window_size is not None:
print('Using sliding window attention.')
assert window_size <= seq_q
flops_per_matmul = 2.0 * batch * heads * min(
window_size, seq_kv // 2) * seq_q * dim # just a rough estimation
else:
print('Using full attention.')
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5
total_flops = 2 * flops_per_matmul

if tune:
kernel = flashattn(batch, heads, seq_q, seq_kv, dim, groups, window_size, dtype=dtype)
print(f"Best latency: {kernel.latency}")
print(f"Best TFlops: {total_flops / kernel.latency * 1e-9}")
print(f"Best config: {kernel.config}")
else:
block_M = 128
block_N = 128
num_stages = 2
threads = 256
print(f"{block_M=}, {block_N=}, {num_stages=}, {threads=}")

kernel = flashattn(
batch,
heads,
seq_q,
seq_kv,
dim,
groups,
window_size,
block_M=block_M,
block_N=block_N,
num_stages=num_stages,
threads=threads,
dtype=dtype)

Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups, dtype=torch_dtype)

if torch.allclose(
triton_program(Q, K, V, sinks, window_size),
ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
rtol=1e-2,
atol=1e-2):
print("Checks for triton passed.✅")
else:
print("Checks for triton failed.❌")

# Benchmark triton
latency_triton = do_bench(lambda: triton_program(Q, K, V, sinks, window_size), warmup=500)
print("Triton: {:.2f} ms".format(latency_triton))
print("Triton: {:.2f} TFlops".format(total_flops / latency_triton * 1e-9))

# Benchmark tilelang
latency_tilelang = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500)
print("Tilelang: {:.2f} ms".format(latency_tilelang))
print("Tilelang: {:.2f} TFlops".format(total_flops / latency_tilelang * 1e-9))

print("Speedup: {:.2f}x".format(latency_triton / latency_tilelang))
Comment on lines +174 to +195

⚠️ Potential issue | 🟠 Major

Validate both implementations before benchmarking.

The current code only validates the Triton implementation against the reference (lines 176-183), but never validates the Tilelang kernel. Both implementations should be validated for correctness before benchmarking.

Additionally, use torch.testing.assert_close for consistency with the codebase instead of torch.allclose combined with manual print statements.

Apply this diff to validate both implementations:

         Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups, dtype=torch_dtype)
 
-        if torch.allclose(
-                triton_program(Q, K, V, sinks, window_size),
-                ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
-                rtol=1e-2,
-                atol=1e-2):
-            print("Checks for triton passed.✅")
-        else:
-            print("Checks for triton failed.❌")
+        # Validate Tilelang implementation
+        torch.testing.assert_close(
+            kernel(Q, K, V, sinks),
+            ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
+            rtol=1e-2,
+            atol=1e-2)
+        
+        # Validate Triton implementation
+        torch.testing.assert_close(
+            triton_program(Q, K, V, sinks, window_size),
+            ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
+            rtol=1e-2,
+            atol=1e-2)
+        print("All checks passed.✅")
 
         # Benchmark triton



if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='batch size')
parser.add_argument('--heads', type=int, default=64, help='heads')
parser.add_argument('--seq_q', type=int, default=2048, help='sequence length of query')
parser.add_argument('--seq_kv', type=int, default=2048, help='sequence length of key/value')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument('--groups', type=int, default=8, help='groups')
parser.add_argument(
'--window_size',
type=int,
default=None,
help='window size (default: None, which means full attention)')
parser.add_argument(
'--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
parser.add_argument('--tune', action='store_true', help='tune configs')
args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size,
args.dtype, args.tune)
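
For reference, a minimal sketch of driving the benchmark directly from Python rather than through the CLI (argument values are illustrative, mirroring the parser defaults with an example sliding window; not part of this PR):

    # Roughly equivalent to:
    #   python examples/attention_sink/benchmark_gqa_sink_fwd.py --heads 64 --groups 8 --window_size 128
    main(batch=1, heads=64, seq_q=2048, seq_kv=2048, dim=128,
         groups=8, window_size=128, dtype="float16", tune=False)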