
[Bugfix][ROCm] fix the power of 2 exception from triton_unified_attention.py when running llama4 models and unit test fix #18100

Merged: 8 commits, merged May 28, 2025. The diff below shows the changes from all commits.
4 changes: 3 additions & 1 deletion tests/kernels/attention/test_triton_unified_attention.py
@@ -13,7 +13,9 @@
BLOCK_SIZES = [16, 32]

DTYPES = [torch.float16, torch.bfloat16]
QDTYPES = [None, torch.float8_e4m3fn]
QDTYPES = [None, torch.float8_e4m3fn] if not current_platform.is_rocm() else [
None, torch.float8_e4m3fnuz
]
Comment on lines +16 to +18 (Collaborator):

This is incorrect, it should use current_platform.fp8_dtype():

QDTYPES = [None, current_platform.fp8_dtype()]

# one value large enough to test overflow in index calculation.
# one value small enough to test the schema op check
NUM_BLOCKS = [32768, 2048]
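
For reference, here is a minimal sketch of the reviewer's suggestion above, which delegates the fp8 dtype choice to the platform abstraction instead of branching on is_rocm() in the test. It assumes current_platform.fp8_dtype() returns the torch fp8 dtype suited to the detected device (for example, the fnuz variant on ROCm GPUs that require it); check vllm.platforms for the exact behavior.

from vllm.platforms import current_platform

# Instead of hard-coding torch.float8_e4m3fn vs. torch.float8_e4m3fnuz behind an
# is_rocm() branch, let the platform helper pick the fp8 dtype the device supports:
QDTYPES = [None, current_platform.fp8_dtype()]
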
106 changes: 51 additions & 55 deletions vllm/attention/ops/triton_unified_attention.py
@@ -29,41 +29,42 @@ def apply_softcap(S, x):

@triton.jit
def kernel_unified_attention_2d(
output_ptr, # [num_tokens, num_query_heads, head_size]
query_ptr, # [num_tokens, num_query_heads, head_size]
key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size]
value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size]
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
seq_lens_ptr, # [num_seqs]
alibi_slopes_ptr, # [num_query_heads]
scale, # float32
k_scale, # float32
v_scale, # float32
softcap, # float32
num_query_heads: tl.constexpr, # int
num_queries_per_kv: tl.constexpr, # int
block_table_stride: tl.int64, # int
query_stride_0: tl.int64, # int
query_stride_1: tl.int64, # int, should be equal to head_size
output_stride_0: tl.int64, # int
output_stride_1: tl.int64, # int, should be equal to head_size
BLOCK_SIZE: tl.constexpr, # int
HEAD_SIZE: tl.constexpr, # int
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
USE_ALIBI_SLOPES: tl.constexpr, # bool
USE_SOFTCAP: tl.constexpr, # bool
SLIDING_WINDOW: tl.constexpr, # int
stride_k_cache_0: tl.int64, # int
stride_k_cache_1: tl.int64, # int
stride_k_cache_2: tl.int64, # int
stride_k_cache_3: tl.constexpr, # int
stride_v_cache_0: tl.int64, # int
stride_v_cache_1: tl.int64, # int
stride_v_cache_2: tl.int64, # int
stride_v_cache_3: tl.constexpr, # int
query_start_len_ptr, # [num_seqs+1]
BLOCK_Q: tl.constexpr, # int
num_seqs: tl.int32,
output_ptr, # [num_tokens, num_query_heads, head_size]
query_ptr, # [num_tokens, num_query_heads, head_size]
key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size]
value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size]
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
seq_lens_ptr, # [num_seqs]
alibi_slopes_ptr, # [num_query_heads]
scale, # float32
k_scale, # float32
v_scale, # float32
softcap, # float32
num_query_heads: tl.constexpr, # int
num_queries_per_kv: tl.constexpr, # int
block_table_stride: tl.int64, # int
query_stride_0: tl.int64, # int
query_stride_1: tl.int64, # int, should be equal to head_size
output_stride_0: tl.int64, # int
output_stride_1: tl.int64, # int, should be equal to head_size
BLOCK_SIZE: tl.constexpr, # int
HEAD_SIZE: tl.constexpr, # int
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
USE_ALIBI_SLOPES: tl.constexpr, # bool
USE_SOFTCAP: tl.constexpr, # bool
SLIDING_WINDOW: tl.constexpr, # int
stride_k_cache_0: tl.int64, # int
stride_k_cache_1: tl.int64, # int
stride_k_cache_2: tl.int64, # int
stride_k_cache_3: tl.constexpr, # int
stride_v_cache_0: tl.int64, # int
stride_v_cache_1: tl.int64, # int
stride_v_cache_2: tl.int64, # int
stride_v_cache_3: tl.constexpr, # int
query_start_len_ptr, # [num_seqs+1]
BLOCK_Q: tl.constexpr, # int
num_seqs: tl.int32,
BLOCK_M: tl.constexpr, # int
):

q_block_global_idx = tl.program_id(0)
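
Note on the signature change above: the kernel now takes a BLOCK_M constexpr and uses it below in place of BLOCK_Q * num_queries_per_kv as the query-tile height. Triton only accepts power-of-2 extents for tl.arange and for the block shapes built from it, and BLOCK_Q * num_queries_per_kv need not be a power of 2 when the number of query heads per KV head is itself not a power of 2, as with the Llama 4 models named in the PR title; that is the "power of 2" exception this PR fixes. A toy kernel illustrating the constraint (for illustration only, not part of this PR; requires triton and a GPU):

import torch
import triton
import triton.language as tl

@triton.jit
def _arange_demo(out_ptr, BLOCK: tl.constexpr):
    # tl.arange(0, BLOCK) only compiles when BLOCK is a power of 2.
    offs = tl.arange(0, BLOCK)
    tl.store(out_ptr + offs, offs)

out = torch.empty(64, dtype=torch.int32, device="cuda")
_arange_demo[(1,)](out, BLOCK=64)    # 64 is a power of 2: compiles and runs
# _arange_demo[(1,)](out, BLOCK=80)  # 80 = 16 * 5 is not: Triton rejects it at compile time
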
@@ -94,23 +95,21 @@ def kernel_unified_attention_2d(
if q_block_local_idx * BLOCK_Q >= cur_batch_query_len:
return

offs_m = tl.arange(0, BLOCK_Q * num_queries_per_kv)
offs_m = tl.arange(0, BLOCK_M)
offs_d = tl.arange(0, HEAD_SIZE_PADDED)

query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv

query_offset_0 = cur_batch_in_all_start_index + query_pos
query_offset_1 = kv_head_idx * num_queries_per_kv + \
offs_m % num_queries_per_kv

query_offset = (query_offset_0[:, None] * query_stride_0 +
query_offset_1[:, None] * query_stride_1 + offs_d[None, :])

dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1)
query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1)
query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1)

# Q : (BLOCK_Q * num_queries_per_kv, HEAD_SIZE,)
# Q : (BLOCK_M, HEAD_SIZE_PADDED)
Q = tl.load(
query_ptr + query_offset,
mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
@@ -119,12 +118,9 @@

block_table_offset = seq_idx * block_table_stride

M = tl.full([BLOCK_Q * num_queries_per_kv],
float("-inf"),
dtype=tl.float32)
L = tl.full([BLOCK_Q * num_queries_per_kv], 1.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_Q * num_queries_per_kv, HEAD_SIZE_PADDED],
dtype=tl.float32)
M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)

# sequence len for this particular sequence
seq_len = tl.load(seq_lens_ptr + seq_idx)
@@ -183,13 +179,12 @@ def kernel_unified_attention_2d(
else:
V = V_load

seq_offset = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
seq_offset = j * BLOCK_SIZE + offs_n

seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1

# S : (BLOCK_Q * num_queries_per_kv, BLOCK_SIZE,)
S = tl.zeros(shape=(BLOCK_Q * num_queries_per_kv, BLOCK_SIZE),
dtype=tl.float32)
# S : (BLOCK_M, BLOCK_SIZE)
S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32)

S += scale * tl.dot(Q, K)

@@ -207,29 +202,29 @@
S += alibi_slope[:, None] * (seq_offset - context_len)

# compute running maximum
# m_j : (BLOCK_Q * num_queries_per_kv,)
# m_j : (BLOCK_M,)
m_j = tl.maximum(M, tl.max(S, axis=1))
# For sliding window there's a chance the max is -inf due to masking of
# the entire row. In this case we need to set m_j to 0 to avoid NaN
m_j = tl.where(m_j > float("-inf"), m_j, 0.0)

# P : (BLOCK_Q * num_queries_per_kv, BLOCK_SIZE,)
# P : (BLOCK_M, BLOCK_SIZE)
P = tl.exp(S - m_j[:, None])

# l_j : (BLOCK_Q * num_queries_per_kv,)
# l_j : (BLOCK_M,)
l_j = tl.sum(P, axis=1)

# alpha : (BLOCK_Q * num_queries_per_kv, )
# alpha : (BLOCK_M, )
alpha = tl.exp(M - m_j)

# acc : (BLOCK_Q * num_queries_per_kv, BLOCK_SIZE,)
# acc : (BLOCK_M, HEAD_SIZE_PADDED)
acc = acc * alpha[:, None]

# update constants
L = L * alpha + l_j
M = m_j

# acc : (BLOCK_Q * num_queries_per_kv, BLOCK_SIZE,)
# acc : (BLOCK_M, HEAD_SIZE_PADDED)
acc += tl.dot(P.to(V.dtype), V)

# epilogue
@@ -334,4 +329,5 @@ def unified_attention(
query_start_len_ptr=cu_seqlens_q,
BLOCK_Q=BLOCK_Q,
num_seqs=num_seqs,
BLOCK_M=BLOCK_M,
)
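
The host-side derivation of BLOCK_M sits outside the lines shown in this diff, so the sketch below describes the idea rather than the file's actual code: the launcher rounds the query-tile height up to the next power of 2 before passing it to the kernel as the BLOCK_M constexpr, and the padded rows are discarded inside the kernel by the existing query_mask_0 / query_mask_1 masks.

# Hypothetical helper, named here only for illustration; the real launcher in
# triton_unified_attention.py may compute BLOCK_M and BLOCK_Q differently.
import triton

def pick_block_m(block_q: int, num_queries_per_kv: int) -> int:
    # e.g. block_q=16 with 5 queries per KV head gives 80 rows, padded up to 128
    return triton.next_power_of_2(block_q * num_queries_per_kv)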