
Commit fa88b14

yuguo68 authored and facebook-github-bot committed
fix AiterFlashAttentionImpl init (#20103)
Summary:
Signed-off-by: Yu Guo <yuguo@meta.com>

Serving llama4 fails with ```TypeError: AiterFlashAttentionImpl.__init__() got multiple values for argument 'use_irope'``` because `AiterFlashAttentionImpl.__init__()` is missing the `kv_sharing_target_layer_name` arg expected by the caller at https://github.com/vllm-project/vllm/blob/296ce95d8e72f4c6680bda539058f48dbe0f340a/vllm/attention/layer.py#L54

Test Plan: launch a llama4 server with this fix

Rollback Plan:

Differential Revision: D77340637
1 parent 296ce95 commit fa88b14

File tree

1 file changed (+4, -0 lines)


vllm/v1/attention/backends/rocm_aiter_fa.py

Lines changed: 4 additions & 0 deletions
@@ -387,11 +387,15 @@ def __init__(
         blocksparse_params: Optional[dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
         attn_type: AttentionType = AttentionType.DECODER,
+        kv_sharing_target_layer_name: Optional[str] = None,
         use_irope: bool = False,
     ) -> None:
         if blocksparse_params is not None:
             raise ValueError(
                 "AiterFlashAttention does not support block-sparse attention.")
+        if kv_sharing_target_layer_name is not None:
+            raise NotImplementedError(
+                "KV sharing is not supported in AiterFlashAttention.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
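To illustrate the failure mode this commit fixes, here is a minimal, self-contained sketch. `BrokenImpl` and `FixedImpl` are hypothetical stand-ins, not vLLM classes, and the call pattern assumes the attention layer forwards backend arguments positionally plus `use_irope` as a keyword, as the error message suggests; it is not a copy of the real `vllm/attention/layer.py` call.

```python
from typing import Optional


class BrokenImpl:
    # Stand-in for the pre-fix signature: no kv_sharing_target_layer_name
    # parameter, so a positionally forwarded value for it lands in
    # use_irope's slot.
    def __init__(self, num_heads: int, use_irope: bool = False) -> None:
        self.num_heads = num_heads
        self.use_irope = use_irope


class FixedImpl:
    # Stand-in for the post-fix signature: the extra keyword parameter is
    # accepted, and explicitly rejected if a value is actually passed.
    def __init__(self,
                 num_heads: int,
                 kv_sharing_target_layer_name: Optional[str] = None,
                 use_irope: bool = False) -> None:
        if kv_sharing_target_layer_name is not None:
            raise NotImplementedError(
                "KV sharing is not supported in AiterFlashAttention.")
        self.num_heads = num_heads
        self.use_irope = use_irope


# Hypothetical caller: forwards (num_heads, kv_sharing_target_layer_name)
# positionally and use_irope as a keyword.
args = (8, None)

try:
    BrokenImpl(*args, use_irope=True)
except TypeError as e:
    print(e)  # ... got multiple values for argument 'use_irope'

impl = FixedImpl(*args, use_irope=True)  # works once the parameter exists
print(impl.use_irope)
```

Accepting the keyword while raising `NotImplementedError` when it is set keeps the backend constructor interface uniform with the caller without silently claiming KV-sharing support.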

0 commit comments