
Commit 34ba9ed

Allow to use flex_attention instead of FSDPA in HPUAttentionImpl (#876)
Co-authored-by: Michał Kuligowski <mkuligowski@habana.ai>
1 parent: 03df014

3 files changed, +30 -17 lines changed

README_GAUDI.md (+1)

@@ -373,6 +373,7 @@ Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM
 - `PT_HPU_LAZY_MODE`: if `0`, PyTorch Eager backend for Gaudi will be used, if `1` PyTorch Lazy backend for Gaudi will be used. `1` is the default.
 - `PT_HPU_ENABLE_LAZY_COLLECTIVES` must be set to `true` for tensor parallel inference with HPU Graphs.
 - `PT_HPUGRAPH_DISABLE_TENSOR_CACHE` must be set to `false` for llava model.
+- `VLLM_PROMPT_USE_FLEX_ATTENTION` is enabled only for the llama model, and allows using torch.nn.attention.flex_attention instead of FusedSDPA. Note that this requires `VLLM_PROMPT_USE_FUSEDSDPA=0`.
 
 # Quantization, FP8 Inference and Model Calibration Process
 
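As context for the README note above, here is a minimal usage sketch showing how the two flags might be combined when running vLLM on Gaudi. The model name, prompt, and sampling settings are illustrative assumptions only; the flags are the ones documented above and must be set before vLLM starts.

import os

# Assumption: the flags are read at startup, so set them before importing vLLM.
os.environ["VLLM_PROMPT_USE_FLEX_ATTENTION"] = "1"  # use torch.nn.attention.flex_attention for prompts
os.environ["VLLM_PROMPT_USE_FUSEDSDPA"] = "0"       # required alongside the flag above, per the README note

from vllm import LLM, SamplingParams

# Illustrative llama checkpoint; the README states the flag applies only to llama models.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")
outputs = llm.generate(["Hello, Gaudi!"], SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)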
requirements-hpu.txt (+1 -1)

@@ -8,4 +8,4 @@ pandas
 tabulate
 setuptools>=61
 setuptools-scm>=8
-vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@3fd0250
+vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@ecb60e4

vllm/attention/backends/hpu_attn.py (+28 -16)

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 ###############################################################################
-# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
+# Copyright (C) 2024-2025 Habana Labs, Ltd. an Intel Company
 ###############################################################################
 
 from dataclasses import dataclass
@@ -159,6 +159,8 @@ def __init__(
             logger().warning("Could not import HPU FusedSDPA kernel. "
                              "vLLM will use native implementation.")
 
+        self.prefill_use_flex_attention = "flex_attention" in enabled_flags()
+
         suppored_head_sizes = HPUPagedAttention.get_supported_head_sizes()
         if head_size not in suppored_head_sizes:
             raise ValueError(
@@ -237,7 +239,8 @@ def forward(
                         self.head_size)
 
         if attn_metadata is None or attn_metadata.block_list is None:
-            if not self.prefill_use_fusedsdpa:
+            if (not self.prefill_use_fusedsdpa
+                    and not self.prefill_use_flex_attention):
                 # TODO: move this outside of model
                 assert attn_metadata.attn_bias is not None, \
                     'attn_bias must be set before calling model.forward'
@@ -252,20 +255,29 @@ def forward(
             else:
                 attn_bias = attn_metadata.attn_bias
 
-            out = ops.prompt_attention(
-                query.view(query_shape),
-                key.view(kv_shape),
-                value.view(kv_shape),
-                attn_bias=attn_bias,
-                p=0.0,
-                scale=self.scale,
-                matmul_qk_op=self.matmul_qk,
-                softmax_op=self.softmax,
-                matmul_av_op=self.matmul_av,
-                valid_seq_lengths=attn_metadata.seq_lens_tensor,
-                fsdpa_op=self.fused_scaled_dot_product_attention
-                if self.prefill_use_fusedsdpa else None,
-            )
+            if not self.prefill_use_flex_attention:
+                out = ops.prompt_attention(
+                    query.view(query_shape),
+                    key.view(kv_shape),
+                    value.view(kv_shape),
+                    attn_bias=attn_bias,
+                    p=0.0,
+                    scale=self.scale,
+                    matmul_qk_op=self.matmul_qk,
+                    softmax_op=self.softmax,
+                    matmul_av_op=self.matmul_av,
+                    valid_seq_lengths=attn_metadata.seq_lens_tensor,
+                    fsdpa_op=self.fused_scaled_dot_product_attention
+                    if self.prefill_use_fusedsdpa else None,
+                )
+            else:
+                out = ops.flex_attention(
+                    query.view(query_shape),
+                    key.view(kv_shape),
+                    value.view(kv_shape),
+                    scale=self.scale,
+                )
+
         else:
             # TODO: enable FusedSDPA
             out = HPUPagedAttention.forward_prefix(
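For reference, below is a small, self-contained sketch of the prefill dispatch this diff introduces. It assumes, without verifying, that the extension's ops.flex_attention delegates to torch.nn.attention.flex_attention; the helper name, tensor shapes, and the simplified fallback (plain scaled_dot_product_attention with bias/mask handling omitted) are illustrative and not the extension's actual implementation.

# Illustrative sketch only: mirrors the new prefill branching, not the HPU kernels.
# Assumes a recent PyTorch (2.5+) for torch.nn.attention.flex_attention.
import torch
from torch.nn.attention.flex_attention import flex_attention


def prompt_attention_sketch(query, key, value, scale, use_flex_attention):
    """query/key/value: (batch, num_heads, seq_len, head_size); masking omitted."""
    if not use_flex_attention:
        # Simplified stand-in for ops.prompt_attention / FusedSDPA.
        return torch.nn.functional.scaled_dot_product_attention(
            query, key, value, scale=scale)
    # New path from this commit: flex_attention with an explicit softmax scale,
    # matching the ops.flex_attention(query, key, value, scale=...) call in the diff above.
    return flex_attention(query, key, value, scale=scale)


q = k = v = torch.randn(1, 8, 128, 64)
out = prompt_attention_sketch(q, k, v, scale=64 ** -0.5, use_flex_attention=True)
print(out.shape)  # torch.Size([1, 8, 128, 64])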