[Major Change][Undecided yet] Move to FlashDecoding instead of PagedAttention kernel. #1940

Merged: 16 commits into main from flashdecoding on Jul 1, 2024

Conversation

@Narsil (Collaborator) commented May 23, 2024

What does this PR do?

This PR proposes a long-standing change: moving toward FlashDecoding instead of PagedAttention.

  • FlashDecoding is supposedly faster than Paged (on par at best in early testing, but needs more thorough testing).
  • More importantly, it will unlock lots of new use cases with much larger speedups (not in the current PR).

FlashDecoding defines its signature as (query, cu_seqlen_q) + (kv, cu_seqlen_kv) + block_tables (to simplify).
This means we can merge prefill and decode in a single attention pass and, most importantly, have very long query lengths at decode time. With the current paged kernels,
there is a hard assumption that the query length is 1. For Medusa speculation, we currently fake it by duplicating "queries" in the query slots and adjusting input_lengths and slots.
The longer the query, the more wasteful this is (which is OK for small sizes). The layout is sketched below.
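
For intuition, here is a minimal, self-contained sketch (made-up numbers, not code from this PR) of the cu_seqlen layout that makes mixed prefill/decode batches and multi-token queries expressible:

import torch

# Illustrative example: sequence 0 is a prefill contributing 5 query tokens,
# sequences 1 and 2 are decodes contributing 1 query token each.
query_lengths = torch.tensor([5, 1, 1])   # new (query) tokens per sequence
kv_lengths = torch.tensor([5, 12, 7])     # total KV length per sequence

def cu_seqlen(lengths: torch.Tensor) -> torch.Tensor:
    # Exclusive prefix sum: entries [i, i+1) delimit sequence i inside the
    # flattened, ragged query / key-value buffers.
    zero = torch.zeros(1, dtype=torch.int64)
    return torch.cat([zero, lengths.cumsum(dim=-1)]).to(torch.int32)

print(cu_seqlen(query_lengths))  # tensor([0, 5, 6, 7], dtype=torch.int32)
print(cu_seqlen(kv_lengths))     # tensor([0, 5, 17, 24], dtype=torch.int32)

The flattened query tensor then has shape (cu_seqlen_q[-1], num_heads, head_size); the kernel slices each sequence's queries via cu_seqlen_q and reads its paged KV cache via block_tables, so query lengths greater than 1 fit the same bookkeeping.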

With FlashDecoding the expected upsides are:

  • Overall speedups, since FlashDecoding is supposed to be faster.
  • KV-cache hits (avoiding recomputation of the KV cache for common prefixes like system messages or assistant prompts; this could also have been done with Paged alone).
  • Cleaner and faster speculation methods.
  • Very large speedups for tightly constrained grammars (like JSON or fill-in-the-hole style prompting).

Current takeaways:

  • The speed improvement is not really there: it is on par with Paged at best, and scales slightly worse with sequence length or batch size (5-10% slower), potentially because the FlashDecoding kernel from FA2 only implements block_size=256, which we may be able to change.
  • Not all the code has been adapted, meaning there are still a few optimizations left (at every call we recreate cu_seqlen_{q,kv}, which means more kernel launches and more overhead, independent of FA's performance).
  • The API is much cleaner for varlen queries (even with a KV cache).

Why not FlashInfer (or others)?

The KV-cache API is quite different from FD's: https://docs.flashinfer.ai/api/python/prefill.html#batch-prefill-append-attention
It requires some scratch buffers (of unclear size) and keeping hold of them.
It requires very different bookkeeping, splitting the pages and last_page_indices separately (meaning a lot more changes to get it working).

Since the layout is the same as FD's, it can be explored later if the performance is there (the API would allow similar features).

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

router/src/infer.rs (outdated review thread, resolved)
@@ -1,11 +1,11 @@
flash_att_v2_commit_cuda := 23e8fa5a263d1c7122bc46a86ef32030ee7130f9
flash_att_v2_commit_cuda := v2.5.8
Narsil (author):
We can actually pip install now.

@danieldk (Member) left a comment:
Looks like a pretty straightforward change, added some comments.

num_seqs, num_heads, head_size = query.shape
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
input_lengths = cu_seqlen_k

# NOTE(woosuk): We use a simple heuristic to decide whether to use
danieldk (Member):

NIT: this comment should move down to the paged attention version condition.

Comment on lines 478 to 485
cu_seqlen_k = torch.cat(
[
torch.zeros(
(1,), device=input_lengths.device, dtype=input_lengths.dtype
),
input_lengths.cumsum(dim=-1),
]
).to(dtype=torch.int32)
danieldk (Member):

Not sure if this is premature optimization, but saves two allocations:

cu_seqlen_k = torch.empty(input_lengths.size(-1) + 1, device=input_lengths.device, dtype=torch.int32)
cu_seqlen_k[0] = 0
torch.cumsum(input_lengths, -1, out=cu_seqlen_k[1:])

@@ -368,6 +373,23 @@ def forward(
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
if cu_seqlen_prefill is None and FLASH_DECODING:
danieldk (Member):

Maybe we could do this in the paged_attention function? Then it has a non-ambiguous signature and we don't have to add this to all the models.

Narsil (author):

The problem is that the tensor creation adds too much overhead.

I did it that way initially and the performance was worse than raw Paged just because of that.

Narsil (author):

We could also maybe do it all the way up in flash_causal_lm. That was my next best idea, but I don't like obfuscating tensor content, since then the tensors might be either cu_seqlen_q and cu_seqlen_k, or None and input_lengths (we could use a dataclass and all sorts of shenanigans, but it still feels like obfuscation to me).

Given the totally optional nature of flash decoding for now, I'm OK if this lives in this particular modeling code while we test, and we either roll back or finish the work and put everything into causal_lm once there's only one format (the biggest drawback will be AMD and Intel, which do not support FA2 with Paged, AFAIK).
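
To illustrate the "dataclass" option mentioned above, here is a rough, hypothetical sketch (Seqlens and build_seqlens are made-up names, not the PR's code) of building the seqlen tensors once per batch at the flash_causal_lm level and passing them down as one object, so per-layer attention calls avoid the tensor-creation overhead discussed above:

import torch
from dataclasses import dataclass
from typing import Optional

@dataclass
class Seqlens:
    # Hypothetical container: bundles what the attention call needs so that
    # individual models do not need to know whether flash decoding is enabled.
    input_lengths: torch.Tensor
    cu_seqlen_q: Optional[torch.Tensor] = None
    cu_seqlen_k: Optional[torch.Tensor] = None

def build_seqlens(input_lengths: torch.Tensor, flash_decoding: bool) -> Seqlens:
    # Built once per batch rather than once per layer/model forward.
    if not flash_decoding:
        return Seqlens(input_lengths=input_lengths)
    device = input_lengths.device
    num_seqs = input_lengths.shape[0]
    # Plain decode step: one query token per sequence.
    cu_seqlen_q = torch.arange(num_seqs + 1, device=device, dtype=torch.int32)
    # Exclusive prefix sum of the KV lengths, preallocated in int32.
    cu_seqlen_k = torch.zeros(num_seqs + 1, device=device, dtype=torch.int32)
    cu_seqlen_k[1:] = torch.cumsum(input_lengths, dim=-1)
    return Seqlens(input_lengths, cu_seqlen_q, cu_seqlen_k)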

@@ -32,7 +40,8 @@ def paged_attention(
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
input_lengths: torch.Tensor,
cu_seqlen_q: torch.Tensor,
danieldk (Member):

Suggested change
cu_seqlen_q: torch.Tensor,
cu_seqlen_q: Optional[torch.Tensor],
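
To make the Optional annotation concrete, here is a hypothetical helper (illustrative only, not the PR's paged_attention) showing how a call site that still passes None, as the review comment below flags, can fall back to the one-query-token-per-sequence decode case instead of breaking:

import torch
from typing import Optional

def decode_cu_seqlen_q(
    cu_seqlen_q: Optional[torch.Tensor],
    input_lengths: torch.Tensor,
) -> torch.Tensor:
    # During a plain decode step every sequence contributes exactly one query
    # token, so cu_seqlen_q degenerates to [0, 1, 2, ..., num_seqs].
    if cu_seqlen_q is not None:
        return cu_seqlen_q
    num_seqs = input_lengths.shape[0]
    return torch.arange(num_seqs + 1, device=input_lengths.device, dtype=torch.int32)

lengths = torch.tensor([12, 7, 30], dtype=torch.int32)
print(decode_cu_seqlen_q(None, lengths))  # tensor([0, 1, 2, 3], dtype=torch.int32)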

@@ -253,6 +253,7 @@ def forward(
self.kv_head_mapping,
self.softmax_scale,
block_tables,
None,
danieldk (Member):

Breaks when flash decoding is enabled?

Narsil added 13 commits July 1, 2024 13:42
Conditional flashdecoding.

Fix max_q.

Working kvcache

Working version with flash decoding.

Make it work for mistral.

Fix after rebase..

Less intrusive.

REvert changes in modeling.

Speedup flashdecoding.

HHachweew
Hack to make other models work.

Fixing non flash decoding llama path.

Router logic knows about page size.

Missing 2 models.

Missing cohere.

Fixing cohere flash decoding.

Revamped all this architecture.

Fix cohere.

Fixing falcon.

Enabling custom block size schedule.

Update router/src/infer.rs

Not sending preallocated output.
Narsil merged commit 4327210 into main on Jul 1, 2024
9 checks passed
Narsil deleted the flashdecoding branch on July 1, 2024 at 21:28
glegendre01 pushed a commit that referenced this pull request on Jul 2, 2024
[Major Change][Undecided yet] Move to FlashDecoding instead of PagedAttention kernel. (#1940)

* Using flash decoding (squashed; the individual commit messages are listed above)

* Making it work on non flash decoding.

* Fix Cohere.

* Fix non decoding paths.

* Rebased.

* No need for cache_manager anymore.

* Update?

* "ipex" -> "cpu"

* These do not belong.

* Factoring cu_seqlen_qk for better abstracting over every model.

* Fixing non flash tests/imports.

* Changing return everywhere.

* Update mistral past.

* Fixing Mi{s,x}tral (non functional in Flash Decoding mode though).

* Fixup mistral clamping (had issues with cuda graphs).

* No need to recreate anything actually.
ErikKaum pushed a commit that referenced this pull request on Jul 26, 2024 (same squashed commit message as above).
yuanwu2017 pushed a commit to yuanwu2017/tgi-gaudi that referenced this pull request on Sep 26, 2024 (same squashed commit message as above).