[Major Change][Undecided yet] Move to FlashDecoding instead of PagedAttention kernel. #1940
Conversation
server/Makefile-flash-att-v2 (outdated)
@@ -1,11 +1,11 @@
flash_att_v2_commit_cuda := 23e8fa5a263d1c7122bc46a86ef32030ee7130f9
flash_att_v2_commit_cuda := v2.5.8
We can actually pip install now.
Looks like a pretty straightforward change, added some comments.
num_seqs, num_heads, head_size = query.shape
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
input_lengths = cu_seqlen_k

# NOTE(woosuk): We use a simple heuristic to decide whether to use
NIT: this comment should move down to the paged attention version condition.
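For context on that NOTE: it refers to the kind of partition-count heuristic used to pick between the V1 and V2 paged attention kernels. The sketch below is illustrative only (the threshold and the use_v1_kernel name are assumptions in the spirit of vLLM's heuristic, not necessarily what this PR ships):

_PARTITION_SIZE = 512

def use_v1_kernel(max_s: int, num_seqs: int, num_heads: int) -> bool:
    # Ceil-divide the longest sequence into partitions of _PARTITION_SIZE tokens.
    max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
    # V1 is preferred when every sequence fits in one partition or when there is
    # already enough parallelism across (sequence, head) pairs; otherwise V2
    # splits long sequences into partitions and reduces across them.
    return max_num_partitions == 1 or num_seqs * num_heads > 512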
cu_seqlen_k = torch.cat(
    [
        torch.zeros(
            (1,), device=input_lengths.device, dtype=input_lengths.dtype
        ),
        input_lengths.cumsum(dim=-1),
    ]
).to(dtype=torch.int32)
Not sure if this is premature optimization, but saves two allocations:
cu_seqlen_k = torch.empty(input_lengths.size(-1) + 1, device=input_lengths.device, dtype=torch.int32)
cu_seqlen_k[0] = 0
torch.cumsum(input_lengths, -1, out=cu_seqlen_k[1:])
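A quick sanity check (illustrative only, with made-up lengths) that the preallocated version yields the same cumulative-lengths tensor as the torch.cat construction above:

import torch

input_lengths = torch.tensor([3, 5, 2], dtype=torch.int32)

# Original construction: prepend a zero, cumulative-sum, then cast.
ref = torch.cat(
    [
        torch.zeros((1,), device=input_lengths.device, dtype=input_lengths.dtype),
        input_lengths.cumsum(dim=-1),
    ]
).to(dtype=torch.int32)

# Suggested construction: write the cumulative sum into a preallocated buffer.
cu_seqlen_k = torch.empty(
    input_lengths.size(-1) + 1, device=input_lengths.device, dtype=torch.int32
)
cu_seqlen_k[0] = 0
torch.cumsum(input_lengths, -1, out=cu_seqlen_k[1:])

assert torch.equal(ref, cu_seqlen_k)  # both give [0, 3, 8, 10]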
@@ -368,6 +373,23 @@ def forward(
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
            position_ids, max_s, hidden_states.dtype
        )
        if cu_seqlen_prefill is None and FLASH_DECODING:
Maybe we could do this in the paged_attention function? Then it has a non-ambiguous signature and we don't have to add this to all the models.
The problem is that the tensor creation adds too much overhead.
I did it that way initially, and the performance was worse than raw paged attention just because of that.
We could also maybe do it all the way up in flash_causal_lm. That was my next best idea, but I don't like obfuscating tensor content, since then the tensors might be either cu_seqlen_q and cu_seqlen_k, or None and input_lengths (we could dataclass stuff and do all sorts of shenanigans, but it still feels like obfuscation).
Given the totally optional nature of flash decoding for now, I'm OK if this lives in this particular modeling code while we test, and we either roll back or finish the work and put everything into causal_lm once there's only one format (the biggest drawback will be AMD and Intel, which do not support FA2 with paged attention AFAIK).
@@ -32,7 +40,8 @@ def paged_attention(
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    block_tables: torch.Tensor,
    input_lengths: torch.Tensor,
    cu_seqlen_q: torch.Tensor,
Suggested change:
-    cu_seqlen_q: torch.Tensor,
+    cu_seqlen_q: Optional[torch.Tensor],
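For illustration, a minimal sketch of what accepting an Optional cu_seqlen_q could look like inside paged_attention. The argument list is abbreviated and hypothetical (not the PR's actual signature), and note that this is exactly the per-call tensor creation the author found too costly above:

from typing import Optional

import torch

def paged_attention(
    query: torch.Tensor,
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    block_tables: torch.Tensor,
    cu_seqlen_q: Optional[torch.Tensor],
    cu_seqlen_k: torch.Tensor,
    max_s: int,
):
    if cu_seqlen_q is None:
        # Decode-only batch: exactly one query token per sequence, so the
        # cumulative query lengths are simply 0, 1, ..., batch_size.
        batch_size = cu_seqlen_k.shape[0] - 1
        cu_seqlen_q = torch.arange(
            batch_size + 1, device=query.device, dtype=torch.int32
        )
    # ... dispatch to the flash decoding / paged attention kernel here ...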
@@ -253,6 +253,7 @@ def forward(
        self.kv_head_mapping,
        self.softmax_scale,
        block_tables,
        None,
Breaks when flash decoding is enabled?
…ttention kernel. (#1940)

* Using flash decoding: conditional flashdecoding; fix max_q; working kvcache; working version with flash decoding; make it work for Mistral; fix after rebase; less intrusive; revert changes in modeling; speed up flashdecoding; hack to make other models work; fix the non-flash-decoding llama path; router logic knows about page size; missing 2 models; missing Cohere; fix Cohere flash decoding; revamped all this architecture; fix Cohere; fix Falcon; enable custom block size schedule; update router/src/infer.rs; not sending preallocated output.
* Making it work on non flash decoding.
* Fix Cohere.
* Fix non decoding paths.
* Rebased.
* No need for cache_manager anymore.
* Update?
* "ipex" -> "cpu"
* These do not belong.
* Factoring cu_seqlen_qk for better abstracting over every model.
* Fixing non flash tests/imports.
* Changing return everywhere.
* Update mistral past.
* Fixing Mi{s,x}tral (non functional in Flash Decoding mode though).
* Fixup mistral clamping (had issues with cuda graphs).
* No need to recreate anything actually.
What does this PR do?
This PR proposes a long-standing change, which is to move towards using FlashDecoding instead of PagedAttention.
FlashDecoding defines its signature as (query, cu_seqlen_q) + (kv, cu_seqlen_kv) + block_tables (to simplify).
This means we can merge prefill and decodes in a single attention pass, but most importantly we can have huge query lengths at query time. With the current paged kernels,
there is a hard assumption that Q length = 1. For Medusa speculation, we're currently faking it by duplicating "queries" in the query slots and adjusting input_lengths and slots.
The longer the query, the more wasteful this is (which is OK for small sizes).
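To make that signature concrete, here is a small standalone sketch using flash-attn's varlen entry point. It is illustrative only: the block_tables / paged-KV part is omitted, the lengths and shapes are made up, and it assumes the flash-attn (FA2) package plus a CUDA device.

import torch
import torch.nn.functional as F
from flash_attn import flash_attn_varlen_func  # assumes flash-attn (FA2) is installed

num_heads, head_size = 8, 64
q_lens = torch.tensor([4, 1, 2], dtype=torch.int32)    # e.g. speculative tokens per sequence
kv_lens = torch.tensor([128, 37, 9], dtype=torch.int32)

# Cumulative lengths [0, l0, l0+l1, ...], int32 on the GPU.
cu_seqlen_q = F.pad(q_lens.cumsum(0, dtype=torch.int32), (1, 0)).cuda()
cu_seqlen_kv = F.pad(kv_lens.cumsum(0, dtype=torch.int32), (1, 0)).cuda()

# All sequences packed along the first ("total tokens") dimension.
q = torch.randn(int(q_lens.sum()), num_heads, head_size, dtype=torch.float16, device="cuda")
k = torch.randn(int(kv_lens.sum()), num_heads, head_size, dtype=torch.float16, device="cuda")
v = torch.randn_like(k)

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlen_q,
    cu_seqlens_k=cu_seqlen_kv,
    max_seqlen_q=int(q_lens.max()),
    max_seqlen_k=int(kv_lens.max()),
    causal=True,  # mask is aligned bottom-right, i.e. decode-with-cache semantics
)

The point is simply that the per-sequence query length is arbitrary here, which is what makes merged prefill + decode and Medusa-style speculation natural.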
With FlashDecoding the expected upsides are:
Current takeaways:
Why not FlashInfer (or others)?
The KV-cache API is quite different from FD's: https://docs.flashinfer.ai/api/python/prefill.html#batch-prefill-append-attention
It requires some scratch buffers (of unclear size) and keeping hold of them.
It requires very different bookkeeping, splitting the pages and last_page_indices separately (meaning a lot more changes to get it to work).
Since the layout is the same as FD's, it can be explored if the performance is there (the API would allow similar features).
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.