Skip to content

[BugFix] llama4 fa3 fix - RuntimeError: scheduler_metadata must have shape (metadata_size) #16998

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 23, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 48 additions & 28 deletions vllm/v1/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ class LocalAttentionMetadata:
local_block_table: torch.Tensor
local_max_query_len: int
local_max_seq_len: int
local_scheduler_metadata: Optional[torch.Tensor]

local_attn_metadata: Optional[LocalAttentionMetadata] = None

Expand Down Expand Up @@ -286,7 +287,9 @@ def __init__(self, runner: "GPUModelRunner"):

self.runner = runner
self.aot_schedule = (get_flash_attn_version() == 3)
self.num_heads = model_config.get_num_attention_heads(
self.num_heads_q = model_config.get_num_attention_heads(
runner.parallel_config)
self.num_heads_kv = model_config.get_num_kv_heads(
runner.parallel_config)
self.headdim = model_config.get_head_size()
self.page_size = self.runner.block_size
Expand All @@ -308,6 +311,23 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
self.runner.device, non_blocking=True).long()

def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
max_seq_len, causal):
if self.aot_schedule:
return get_scheduler_metadata(
batch_size=batch_size,
max_seqlen_q=max_query_len,
max_seqlen_k=max_seq_len,
cache_seqlens=seqlens,
num_heads_q=self.num_heads_q,
num_heads_kv=self.num_heads_kv,
headdim=self.headdim,
page_size=self.page_size,
cu_seqlens_q=cu_query_lens,
causal=causal,
)
return None

# for local attention
local_attn_metadata = None
if self.runner.attention_chunk_size is not None:
Expand All @@ -319,36 +339,31 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
block_table,
self.runner.block_size,
)
local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
self.runner.device, non_blocking=True)
local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
self.runner.device, non_blocking=True)
local_max_query_len = seqlens_q_local_np.max()
local_max_seq_len = virt_k_seqlens_np.max()
local_scheduler_metadata = schedule(
batch_size=local_query_start_loc.shape[0] - 1,
cu_query_lens=local_query_start_loc,
max_query_len=local_max_query_len,
seqlens=local_seqused_k,
max_seq_len=local_max_seq_len,
causal=True)

local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
local_query_start_loc=torch.from_numpy(
virt_q_cu_seqlens_np).to(self.runner.device,
non_blocking=True),
local_seqused_k=torch.from_numpy(virt_k_seqlens_np).to(
self.runner.device, non_blocking=True),
local_query_start_loc=local_query_start_loc,
local_seqused_k=local_seqused_k,
local_block_table=virt_block_table,
local_max_query_len=seqlens_q_local_np.max(),
local_max_seq_len=virt_k_seqlens_np.max(),
local_max_query_len=local_max_query_len,
local_max_seq_len=local_max_seq_len,
local_scheduler_metadata=local_scheduler_metadata,
)

use_cascade = common_prefix_len > 0

def schedule(cu_query_lens, max_query_len, seqlens, max_seq_len,
causal):
if self.aot_schedule:
return get_scheduler_metadata(
batch_size=num_reqs,
max_seqlen_q=max_query_len,
max_seqlen_k=max_seq_len,
cache_seqlens=seqlens,
num_heads_q=self.num_heads,
num_heads_kv=self.num_heads,
headdim=self.headdim,
page_size=self.page_size,
cu_seqlens_q=cu_query_lens,
causal=causal,
)
return None

if use_cascade:
cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
dtype=torch.int32,
Expand All @@ -361,12 +376,14 @@ def schedule(cu_query_lens, max_query_len, seqlens, max_seq_len,
suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
self.runner.device)
prefix_scheduler_metadata = schedule(
batch_size=num_reqs,
cu_query_lens=cu_prefix_query_lens,
max_query_len=num_actual_tokens,
seqlens=prefix_kv_lens,
max_seq_len=common_prefix_len,
causal=False)
scheduler_metadata = schedule(cu_query_lens=query_start_loc,
scheduler_metadata = schedule(batch_size=num_reqs,
cu_query_lens=query_start_loc,
max_query_len=max_query_len,
seqlens=suffix_kv_lens,
max_seq_len=max_seq_len -
Expand All @@ -377,7 +394,8 @@ def schedule(cu_query_lens, max_query_len, seqlens, max_seq_len,
prefix_kv_lens = None
suffix_kv_lens = None
prefix_scheduler_metadata = None
scheduler_metadata = schedule(cu_query_lens=query_start_loc,
scheduler_metadata = schedule(batch_size=num_reqs,
cu_query_lens=query_start_loc,
max_query_len=max_query_len,
seqlens=seq_lens,
max_seq_len=max_seq_len,
Expand Down Expand Up @@ -541,12 +559,14 @@ def forward(
max_seqlen_q = local_metadata.local_max_query_len
max_seqlen_k = local_metadata.local_max_seq_len
block_table = local_metadata.local_block_table
scheduler_metadata = local_metadata.local_scheduler_metadata
else:
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
scheduler_metadata = attn_metadata.scheduler_metadata

descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])

Expand All @@ -565,7 +585,7 @@ def forward(
window_size=self.sliding_window,
block_table=block_table,
softcap=self.logits_soft_cap,
scheduler_metadata=attn_metadata.scheduler_metadata,
scheduler_metadata=scheduler_metadata,
fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
Expand Down