[V1] Add sliding window support to Flex Attention backend #24089
Conversation
Also cc @drisspg for visibility
def build_block_mask(self) -> BlockMask:
    if self.causal:
        if self.sliding_window is not None:
I think this would still work with the direct build path. Are your new tests checking this?
Fixed in 181f15d. The new test covers the direct build code path; it's currently disabled for torch 2.8:
use_direct_block_mask = is_torch_equal_or_newer("2.9.0.dev0")
if backend == "FLEX_ATTENTION_SLOW":
    actual_backend = _Backend.FLEX_ATTENTION
    use_direct_block_mask = False
I modified it to the following locally and confirmed it passes on torch 2.8:
use_direct_block_mask = True
if backend == "FLEX_ATTENTION_SLOW":
    actual_backend = _Backend.FLEX_ATTENTION
    use_direct_block_mask = False
But I didn't push that change in the commit, just in case something still needs direct build disabled.
Please test Alibaba-NLP/gte-reranker-modernbert-base and google/embeddinggemma-300m (you need to manually set dtype = float32) to ensure that the results of bidirectional attention + sliding window + FlexAttention are correct: pytest tests/models/language/pooling/test_st_projector.py::test_embed_models_mteb[model_info1] (#24318)
@noooop I have confirmed that both tests pass with fp32 locally now.
num_actual_tokens = attn_metadata.num_actual_tokens
...
if attn_metadata.sliding_window != self.sliding_window:
How does the init of the sliding window work? This line feels a little weird. We ideally want to create one block mask prior to running forward and reuse it. Do we actually hit this case in practice?
Yes, I hit an edge case for models like Alibaba-NLP/gte-reranker-modernbert-base mentioned above. It needs full attention every third layer while the remaining layers use sliding window (https://huggingface.co/Alibaba-NLP/gte-reranker-modernbert-base/blob/main/config.json#L18). Its attention layout looks like this:
{
"layer 1": "full_attention",
"layer 2": "sliding_window",
"layer 3": "sliding_window",
"layer 4": "full_attention",
"layer 5": "sliding_window",
"layer 6": "sliding_window",
...
}
And we're sharing attention metadata within the same attn_group when preparing the model runner's inputs:
vllm/vllm/v1/worker/gpu_model_runner.py
Lines 1084 to 1090 in c1eda61
attn_metadata_i = builder.build(
    common_prefix_len=common_prefix_len,
    common_attn_metadata=common_attn_metadata,
    **extra_attn_metadata_args)
for layer_name in attn_group.layer_names:
    attn_metadata[layer_name] = attn_metadata_i
Therefore, when running this model with FlexAttention, we can prepare the sliding window mask at layer 2 and reuse it at layer 3. But at layer 4, we have to revert to the normal bidirectional mask, then re-prepare the sliding window mask at layer 5, and so on.
layer 1, impl_sliding_window: None, metadata_sliding_window: None
layer 2, impl_sliding_window: 64, metadata_sliding_window: None
layer 3, impl_sliding_window: 64, metadata_sliding_window: 64
layer 4, impl_sliding_window: None, metadata_sliding_window: 64
layer 5, impl_sliding_window: 64, metadata_sliding_window: None
layer 6, impl_sliding_window: 64, metadata_sliding_window: 64
...
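To make the effect of this sharing concrete, here is a tiny standalone sketch (not vLLM code; the per-layer window list just mirrors the hypothetical layout above) that counts how often the shared metadata's block mask gets rebuilt under the "rebuild only when the layer's window differs from the metadata's" rule:

# Toy illustration: metadata is shared per attn_group, so the mask is rebuilt
# whenever the current layer's sliding window differs from what the shared
# metadata last saw.
layer_windows = [None, 64, 64, None, 64, 64]  # interleaved layout as above
metadata_window = None
rebuilds = 0
for window in layer_windows:
    if window != metadata_window:
        metadata_window = window
        rebuilds += 1
print(rebuilds)  # 3 rebuilds across these 6 layers (at layers 2, 4 and 5)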
Hmm, so we end up creating an attn_metadata per group. Can we get the sliding window similar to how _cascade_group does:
vllm/vllm/v1/worker/gpu_model_runner.py
Lines 1175 to 1177 in c1eda61
use_sliding_window = (isinstance(kv_cache_spec, SlidingWindowSpec) or
                      (isinstance(kv_cache_spec, FullAttentionSpec)
                       and kv_cache_spec.sliding_window is not None))
So that we end up having a causal attn-metadata and a sliding-window attn-metadata?
The concept of "Attention Groups" (each gets its own metadata) was introduced so we could decouple the different attention types (determined by their backends) from kv-cache specs, since the hybrid kv-cache manager is not always enabled; in those cases sliding-window, chunked-local-attention, full-attention, etc. all share a kv-cache spec and a kv-cache group (basically the kv-cache manager treats them all as full attention). That PR targets Llama 4, which does:
L0: Full Attn
L1: Chunked Local Attn
L2: Chunked Local Attn
L3: Chunked Local Attn
L4: Full Attn
L5: Chunked Local Attn
L6: Chunked Local Attn
L7: Chunked Local Attn
When the hybrid-kv-cache manager is enabled we get:
kv-cache-group 0, attn-group 0: L0, L4
kv-cache-group 1, attn-group 0: L1, L5
kv-cache-group 2, attn-group 0: L2, L6
kv-cache-group 3, attn-group 0: L3, L7
When the hybrid-kv-cache manager is disabled we get:
kv-cache-group 0, attn-group 0: L0, L4
kv-cache-group 0, attn-group 1: L1, L2, L3, L5, L6, L7
We achieve this by wrapping the attention backends to make a new ChunkLocalAttentionBackend (see:
vllm/vllm/attention/layers/chunked_local_attention.py
Lines 71 to 76 in e42af78
underlying_attn_backend = get_attn_backend(head_size, dtype,
                                           kv_cache_dtype,
                                           block_size)
attn_backend = create_chunked_local_attention_backend(
    underlying_attn_backend, attention_chunk_size, block_size)
I'm wondering if it makes sense to do something similar for sliding window, so backends can just build either sliding-window attention metadata or full-attention metadata, and the GPU model runner will make sure both get built and associated with the correct layers (via Attention Groups) if needed.
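If it helps, here is a self-contained toy (not vLLM code; AttnMetadata, FullAttentionBackend and create_sliding_window_backend are hypothetical stand-ins) showing the wrapping idea: the wrapper reuses the underlying backend but forces its metadata to carry the sliding window, so sliding-window layers would end up in their own attention group with their own metadata:

from dataclasses import dataclass
from typing import Optional

@dataclass
class AttnMetadata:
    sliding_window: Optional[int] = None  # None => full/bidirectional mask

class FullAttentionBackend:
    def build_metadata(self) -> AttnMetadata:
        return AttnMetadata(sliding_window=None)

def create_sliding_window_backend(underlying, window: int):
    # Hypothetical helper, analogous in spirit to
    # create_chunked_local_attention_backend: subclass the underlying backend
    # so its metadata always carries the sliding window.
    class SlidingWindowBackend(type(underlying)):
        def build_metadata(self) -> AttnMetadata:
            md = super().build_metadata()
            md.sliding_window = window
            return md
    return SlidingWindowBackend()

full_backend = FullAttentionBackend()
swa_backend = create_sliding_window_backend(full_backend, 64)
print(full_backend.build_metadata())  # AttnMetadata(sliding_window=None)
print(swa_backend.build_metadata())   # AttnMetadata(sliding_window=64)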
Thoughts @heheda12345 ?
I'm wondering if it makes sense to do something similar for sliding window, so backends can just build either sliding-window attention metadata or full-attention metadata
BTW, Alibaba-NLP/gte-reranker-modernbert-base mentioned above is an encoder-only model, which only uses EncoderOnlyAttentionSpec. Perhaps we also need to refactor the encoder-only attention interface to allow splitting attention groups?
Also cc @maxdebayser for any thoughts.
This makes me rethink whether it is reasonable to split out the Encoder-only in #23154. This would result in all Attention modules having corresponding Encoder-only modules.
I think this discussion belongs in another PR.
Many new models have used sliding window recently. Since the result of FlexAttention + sliding window is correct, why can't we merge this PR first so that users can use v1 + float32 + FlexAttention + sliding window? We can polish it in follow-up PRs.
Yeah, I agree. This PR is clean and concise as is. Let's verify the encoder-only support in a follow-up PR.
Agree that we can have different attention groups for encoder + full attention and encoder + sliding window attention. And yes we can do it in a follow-up PR.
Thank you for adding tests!
@drisspg, as the person most familiar with the FlexAttention backend, can you please do a final review? This looks good to me assuming you are OK with it 👍
I find these lines in the FlexAttention backend:
self.direct_build: bool = is_torch_equal_or_newer("2.9.0.dev0")
and
if self.direct_build and self.causal:
    self.block_mask = self._build_block_mask_direct()
else:
    self.block_mask = self.build_block_mask()
Does it mean this PR is not correct for torch >= 2.9.0.dev0?
I think we can raise NotImplementedError for this code path and implement _build_block_mask_direct in a future PR. _build_block_mask_direct should be important for sliding window attention, as build_block_mask iterates over all tokens while _build_block_mask_direct only needs to iterate over the tokens inside the sliding window.
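To illustrate why that matters, here is a small standalone sketch (not vLLM code) using PyTorch's FlexAttention API directly (torch >= 2.5); the window of 64 and sequence length of 1024 are arbitrary example values:

from torch.nn.attention.flex_attention import create_block_mask

SLIDING_WINDOW = 64

def causal_sliding_window(b, h, q_idx, kv_idx):
    # standard FlexAttention mask_mod signature: (batch, head, q_idx, kv_idx)
    causal = q_idx >= kv_idx
    in_window = (q_idx - kv_idx) < SLIDING_WINDOW
    return causal & in_window

# A slow build considers every (q, kv) block pair, while a direct build could
# restrict itself to the blocks that can intersect the window.
block_mask = create_block_mask(causal_sliding_window, B=None, H=None,
                               Q_LEN=1024, KV_LEN=1024, device="cpu")
print(block_mask.sparsity())  # most blocks fall outside the window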
if self.direct_build and self.causal:
    self.block_mask = self._build_block_mask_direct()
else:
    self.block_mask = self.build_block_mask()

This runs when FlexAttentionMetadata is first built. When calling the attn_impl, if the layer uses sliding window attention, we update the block mask for the initial run through fast/slow building if necessary, and reuse/update it for the remaining layers:

if attn_metadata.sliding_window != self.sliding_window:
    attn_metadata.sliding_window = self.sliding_window
    if attn_metadata.direct_build:
        # update mask mod in attention metadata
        attn_metadata.mask_mod = attn_metadata.get_mask_mod()
        attn_metadata.block_mask = (
            attn_metadata._build_block_mask_direct())
    else:
        attn_metadata.block_mask = attn_metadata.build_block_mask()
But this is different from
vllm/vllm/v1/attention/backends/flex_attention.py
Lines 459 to 467 in a8c0f59
In fact, the key is:

# update mask mod in attention metadata
attn_metadata.mask_mod = attn_metadata.get_mask_mod()
attn_metadata.block_mask = (
    attn_metadata._build_block_mask_direct())

at the switching stage for ...
I think this is fine as is, but can you create some issues for the follow-up work? The dynamic creation of the block_mask is not ideal.
Did you try some e2e model? I tried ...
The logs are attached.
I tried ... I just ran ...
Full logs: ... I suspect the dynamic block_mask creation caused this issue when using the hybrid allocator; let me investigate.
+1 PTAL #24872 (comment) (I only saw the keyword ‘compile’, maybe it’s not related.)
Oh, it seems we have to disable the hybrid allocator when using FlexAttention 😢: ...
By the way, e.g. ...
Welcome to use and fix what you need!
The direct build path should skip non-intra-window blocks if the page table correctly evicts those blocks.
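As a rough illustration of that point (a toy calculation, not vLLM's paging logic; block size 16 and window 64 are made-up values):

# For a query at position q_pos, only KV blocks overlapping
# [q_pos - WINDOW + 1, q_pos] can contribute, so a direct build backed by a
# page table that evicts out-of-window blocks can skip everything earlier.
BLOCK_SIZE = 16
WINDOW = 64

def kv_blocks_in_window(q_pos: int) -> range:
    lo_block = max(0, q_pos - WINDOW + 1) // BLOCK_SIZE
    hi_block = q_pos // BLOCK_SIZE
    return range(lo_block, hi_block + 1)

print(list(kv_blocks_in_window(200)))  # [8, 9, 10, 11, 12]; blocks 0-7 skipped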
LGTM!
Tracking some follow-ups:
- https://github.com/vllm-project/vllm/pull/24089/files#r2341783626
- support hybrid allocator
- support real sliding window when using direct build
Purpose
Test Plan
Test Result
Test should still pass.
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.