Separate attention backends #3005
Conversation
@zhuohan123 What do you think about this design? Please note that while I used
In general the refactor LGTM. My only small concern is the learning cost of AttentionFactory, since it does not completely behave like a torch nn.Module. I think this can add difficulty for people adding new models.
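As a purely hypothetical illustration of this concern (the class and function names below are made up and are not the PR's actual AttentionFactory), a factory can look like a module class at the call site while actually selecting and constructing a concrete backend, which is what differs from a plain nn.Module subclass:

import torch
import torch.nn as nn

class XFormersLikeAttention(nn.Module):
    def forward(self, query, key, value):
        return nn.functional.scaled_dot_product_attention(query, key, value)

class FlashLikeAttention(nn.Module):
    def forward(self, query, key, value):
        return nn.functional.scaled_dot_product_attention(query, key, value)

def Attention(use_flash: bool = False) -> nn.Module:
    # Looks like a class when written as `self.attn = Attention(...)` in model code,
    # but is really a factory that returns one of several backend modules.
    return FlashLikeAttention() if use_flash else XFormersLikeAttention()

q = k = v = torch.randn(1, 8, 16, 64)
print(Attention()(q, k, v).shape)  # torch.Size([1, 8, 16, 64])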
                alibi_slopes=self.alibi_slopes,
            )
        else:
            # prefix-enabled attention
The prefix-enabled attention and decoding parts are the same as those in non_flash.py. Could we move them into BaseAttention? Something like:
class BaseAttention(nn.Module):
    def forward(self, ...):
        if input_metadata.is_prompt:
            if ...:
                self._do_prompt_attention()
            else:
                # prefix-enabled attention
                ...
        else:
            # Decoding run.
            ...

    def _do_prompt_attention(self):
        # use xformers or flash_attn here
        ...
Hi @chenxu2048, thanks for your input. I intentionally avoided this design because some attention implementations may not follow this structure. For example, an attention kernel may process the prompt attention and the prefix-enabled attention together. In terms of flexibility, I think the current structure is preferable.
Thanks for your explanation.
@zhuohan123 PTAL. Please note that I intentionally didn't make changes to models other than Llama.
Thanks! In general LGTM. Will you change all other model files when you merge the PR?
if input_metadata.is_prompt:
    # Prompt run.
    if (key_cache is None or value_cache is None
            or input_metadata.block_tables.numel() == 0):
        # normal attention
        query = query.unflatten(0, (batch_size, seq_len))
        key = key.unflatten(0, (batch_size, seq_len))
        value = value.unflatten(0, (batch_size, seq_len))
        output = flash_attn_func(
            query,
            key,
            value,
            softmax_scale=self.scale,
            causal=True,
            window_size=self.sliding_window,
            alibi_slopes=self.alibi_slopes,
        )
    else:
        # prefix-enabled attention
        output = PagedAttentionImpl.forward_prefix(
            query,
            key,
            value,
            key_cache,
            value_cache,
            input_metadata,
            self.num_heads,
            self.num_kv_heads,
            self.alibi_slopes,
        )
else:
    # Decoding run.
    output = PagedAttentionImpl.forward_decode(
        query,
        key_cache,
        value_cache,
        input_metadata,
        self.num_kv_heads,
        self.scale,
        self.alibi_slopes,
    )

# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)
I would still suggest separating this out into private methods (_forward_decode, _forward_prefill, etc.) so that forward can just decide which method to dispatch.
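For reference, a minimal sketch of the kind of dispatch structure being suggested; the method names (_forward_prefill, _forward_prefix, _forward_decode) and the exact condition are assumptions for illustration, not the PR's actual API:

import torch.nn as nn

class SketchAttentionBackend(nn.Module):
    """Hypothetical backend whose forward() only decides which private method to call."""

    def forward(self, query, key, value, key_cache, value_cache, input_metadata):
        if input_metadata.is_prompt:
            if (key_cache is None or value_cache is None
                    or input_metadata.block_tables.numel() == 0):
                return self._forward_prefill(query, key, value)
            return self._forward_prefix(query, key, value, key_cache, value_cache,
                                        input_metadata)
        return self._forward_decode(query, key_cache, value_cache, input_metadata)

    def _forward_prefill(self, query, key, value):
        ...  # e.g. call flash_attn_func or an xformers kernel here

    def _forward_prefix(self, query, key, value, key_cache, value_cache, input_metadata):
        ...  # e.g. call PagedAttentionImpl.forward_prefix here

    def _forward_decode(self, query, key_cache, value_cache, input_metadata):
        ...  # e.g. call PagedAttentionImpl.forward_decode here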
Thanks for your inputs! Actually, I intentionally avoided the design you proposed to ensure flexibility in implementing the attention backends. As you pointed out, an attention backend performs four tasks: 1) storing the input KV tensors into the KV cache, 2) computing prefills, 3) computing prefills with prefixes, and 4) computing decodes. Currently, the two attention backends (FlashAttentionBackend and XFormersBackend) have a kernel for each task. However, this may not necessarily be true in the future. For example, depending on the kernel implementation, one can compute prefills with and without prefixes (2 & 3) at the same time. As another example, an attention kernel in TRT-LLM stores the KV cache while computing decodes (1 & 4). Things can become even more complicated if we implement something like Cascade inference. Hence, I believe we shouldn't fix a particular structure for the attention backends.
@Yard1 What do you think about this?
I agree we should not make them part of the public API, but they can be done as private APIs for the backends that do have that distinction. Basically, we should try to modularize the forward method where possible, as that makes it easier to read and test.
Got it. First, I believe the current implementation is easy to read; XFormersBackend is essentially the same as the current main branch, and FlashAttentionBackend is simpler than that. Particularly for FlashAttentionBackend, I believe the implementation in this PR is very easy to understand.
That being said, I do agree that modularizing the backends will make them easier to test. However, since this PR has already been delayed quite a bit, let's merge it and do the modularization in the next PR.
if use_v1:
    # Run PagedAttention V1.
    ops.paged_attention_v1(
        output,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        input_metadata.block_tables,
        input_metadata.context_lens,
        block_size,
        input_metadata.max_context_len,
        alibi_slopes,
        input_metadata.kv_cache_dtype,
    )
else:
    # Run PagedAttention V2.
    assert _PARTITION_SIZE % block_size == 0
    tmp_output = torch.empty(
        size=(num_seqs, num_heads, max_num_partitions, head_size),
        dtype=output.dtype,
        device=output.device,
    )
    exp_sums = torch.empty(
        size=(num_seqs, num_heads, max_num_partitions),
        dtype=torch.float32,
        device=output.device,
    )
    max_logits = torch.empty_like(exp_sums)
    ops.paged_attention_v2(
        output,
        exp_sums,
        max_logits,
        tmp_output,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        input_metadata.block_tables,
        input_metadata.context_lens,
        block_size,
        input_metadata.max_context_len,
        alibi_slopes,
        input_metadata.kv_cache_dtype,
    )
return output
Ditto as in the previous comment (_forward_decode_v1, _forward_decode_v2).
Ditto. Let's do it in the next PR.
This reverts commit 2daf23a.
            )
        else:
            # Decoding run.
            output = PagedAttentionImpl.forward_decode(
I am just curious: why not use flash_attn_with_kvcache? The kernel is faster than paged_attention_kernel. More benchmark details can be found in #2744.
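For context, a minimal sketch of how flash_attn_with_kvcache is typically called for a decode step with a contiguous (non-paged) KV cache; the shapes and keyword names should be checked against the installed flash-attn version, and an Ampere-or-newer GPU with fp16/bf16 tensors is assumed:

import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim = 2, 8, 128
max_cache_len, new_len = 1024, 1  # decode: one new token per sequence

q = torch.randn(batch, new_len, nheads, headdim, dtype=torch.float16, device="cuda")
k = torch.randn(batch, new_len, nheads, headdim, dtype=torch.float16, device="cuda")
v = torch.randn(batch, new_len, nheads, headdim, dtype=torch.float16, device="cuda")
k_cache = torch.zeros(batch, max_cache_len, nheads, headdim, dtype=torch.float16, device="cuda")
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.tensor([17, 33], dtype=torch.int32, device="cuda")  # tokens already cached

# Appends k/v into the caches at position cache_seqlens (in place) and attends
# over the cached prefix plus the new token.
out = flash_attn_with_kvcache(q, k_cache, v_cache, k=k, v=v,
                              cache_seqlens=cache_seqlens, causal=True)
print(out.shape)  # (batch, new_len, nheads, headdim)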
This PR refactors the attention layer. Specifically, it separates the code paths for Ampere and newer NVIDIA GPUs (which can use FlashAttention directly) from those for other GPUs, so that the code for the former becomes much simpler. This PR also brings some performance improvements for ALiBi models, since we now call FlashAttention directly instead of going through xformers.
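For illustration, a rough sketch of the capability-based selection this description implies; the helper below is an assumption for illustration, not the PR's actual code (FlashAttention requires compute capability 8.0 or higher, i.e. Ampere and newer):

import torch

def _can_use_flash_attention() -> bool:
    # Hypothetical helper: pick the FlashAttention code path only on Ampere+ GPUs.
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 8

backend_name = "FlashAttentionBackend" if _can_use_flash_attention() else "XFormersBackend"
print(f"Selected attention backend: {backend_name}")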