Implement StreamingLLM/Windowed Attention with Attention Sinks #26553

Closed
tomaarsen opened this issue Oct 3, 2023 · 14 comments · Fixed by #26681

Comments

@tomaarsen
Member

Feature request

Hello!

I would love to see StreamingLLM/Windowed Attention with Attention Sinks implemented, as proposed in https://arxiv.org/abs/2309.17453.
The primary author (@Guangxuan-Xiao) has also released the code here: https://github.com/mit-han-lab/streaming-llm
And I've adapted that code to a drop-in replacement of transformers to allow people to use it: https://github.com/tomaarsen/attention_sinks
(e.g.

from attention_sinks import AutoModel

model = AutoModel.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="auto")

)


[Figure from the paper: comparison of attention schemes]

The paper shows that adapting windowed attention so that the first 4 tokens of the input sequence are always kept in the window allows any tested LLM (Llama 2, MPT, Falcon, Pythia) to scale to endless inputs without catastrophic perplexity increases. All without any form of retraining.
In other words, scaling any pretrained LLM to infinite sequence length is as simple as:

  1. Converting the attention to windowed attention.
  2. Using a special cache for the windowed attention that always keeps the first 4 (by default) tokens in the cache.

Using this elementary approach, the authors were able to keep various LLMs stable when feeding them 4 million (!) tokens.
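To make that concrete, here is a minimal sketch of the eviction policy (my own illustration, not the attention_sinks implementation; the helper name and the legacy tuple-of-tuples past_key_values format are assumptions):

import torch

def trim_past_key_values(past_key_values, attention_sink_size=4, window_size=1020):
    """Keep the first `attention_sink_size` tokens plus the `window_size` most recent
    tokens in every layer's (key, value) pair.
    Assumes keys/values of shape [batch, num_heads, seq_len, head_dim]."""
    trimmed = []
    for key, value in past_key_values:
        seq_len = key.shape[2]
        if seq_len <= attention_sink_size + window_size:
            trimmed.append((key, value))
            continue
        key = torch.cat([key[:, :, :attention_sink_size], key[:, :, -window_size:]], dim=2)
        value = torch.cat([value[:, :, :attention_sink_size], value[:, :, -window_size:]], dim=2)
        trimmed.append((key, value))
    return tuple(trimmed)

Calling something like this after every forward keeps the cache at a fixed size while never evicting the sink tokens.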

Motivation

Maximum sequence lengths have been an important topic for a while now, with solutions ranging from RoPE to LongLoRA to YaRN, but each of these has its limits, and some also require retraining/additional training. This windowed attention with attention sinks seems to completely solve this problem, and it would be an extremely valuable addition.

I can vouch for the results in the paper. I've gotten these results for Llama 2 7B using my own implementation:
[Plot: perplexity and VRAM usage for Llama 2 7B]

Your contribution

Yes. I would love to help implement this into core transformers rather than in my drop-in implementation. However, I would like to discuss:

  1. Whether this feature is a good fit for transformers.
  2. Where we store the code for converting each model (e.g. Llama, Pythia, Falcon) to windowed attention. See e.g. this file for an example.
  3. Where we store the code for applying the Attention Sink KV Cache after a forward call. See e.g. this file for an example.

The primary author of the paper has also expressed interest in a transformers implementation here.

  • Tom Aarsen
@LysandreJik
Member

Hey @tomaarsen, very cool feature and implementation!

This definitely looks like a good fit for transformers, or at least it should be of very high value for the community to have access to attention sinks very easily.

Keeping a drop-in implementation up to date over the long term is hard to do, so I would recommend we move towards a utility function for now that could eventually be upstreamed into transformers once it has developed a bit more.

So instead of the current

from attention_sinks import AutoModel

model = AutoModel.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="auto")

how about something like

from attention_sinks import convert_model_attention_to_sinks
from transformers import AutoModel

model = AutoModel.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="auto")
model = convert_model_attention_to_sinks(model)

?


Eventually, the API could take two different directions:

  1. We develop it similarly to the existing BetterTransformer support -> it depends on the optimum library being installed in the environment, and offers the method model.to_bettertransformer() to convert the model to the right format.
  2. We add support for it directly in the from_pretrained method, like we do for Flash Attention: AutoModel.from_pretrained("meta-llama/Llama-2-7b-hf", use_flash_attention_2=True)

The first path is likely the most scalable; we would work a bit on the model definition to enable "plugins" from third-party libraries, enabling support for many third-party tools. The second one would offer support in core transformers directly, but for this we would really want validation across many models first.

cc @ArthurZucker @younesbelkada @patrickvonplaten @ydshieh

@tomaarsen
Member Author

tomaarsen commented Oct 3, 2023

Hello!

I'm glad that you're open to the idea of adding this to transformers! I think it would be of enormous value.
First of all, I agree that the implementation of

from attention_sinks import AutoModel

model = AutoModel.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="auto")

is not workable long-term, but I think a solution similar to BetterTransformer in optimum is viable. People can approach the third-party package (e.g. optimum in the case of BetterTransformer, or attention_sinks) and propose the conversion method to add Attention Sinks to whatever architecture isn't supported yet. The goal/scope of the third party would then essentially be to act as a dictionary mapping architectures to conversion functions (rather than also providing AutoModel, AutoModelForCausalLM, LlamaModel, etc.).

On the transformers side we would have a conversion method (e.g. add_attention_sinks) which applies the conversion from the third party, if it exists. This might be preferable from an API perspective to your option 2, as this method can be given args and kwargs, such as attention_sink_size (e.g. the first 4 tokens) and window_size (e.g. 1020 tokens). Adding more args and kwargs is more scalable this way, as we can't just willy-nilly add these kwargs to transformers' AutoModel.from_pretrained. This is important to consider, as the research on this is extremely new, so we might require more arguments in the future as the research expands.
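To illustrate what I mean (a purely hypothetical sketch; neither the function name add_attention_sinks nor its kwargs exist anywhere yet):

from transformers import AutoModelForCausalLM
from transformers import add_attention_sinks  # hypothetical, does not exist

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="auto")
# attention_sink_size / window_size are the kind of kwargs mentioned above; the
# actual conversion function would be looked up in the third-party package.
model = add_attention_sinks(model, attention_sink_size=4, window_size=1020)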

I'm curious to hear your thoughts on this.

For reference, today I will be adding support for MPT, Falcon, Pythia alongside my existing Llama support to attention_sinks.

  • Tom Aarsen

@patrickvonplaten
Contributor

[Brainstorming] I'm wondering whether we could use this issue as a catalyst to improve the cache / past key value design we have in Transformers, as it needs to be updated soon anyway (cc @gante as well).

@tomaarsen do you think we could support StreamingLLM for every model just by defining a "StreamingLLM/AttentionSink" cache that can be passed to the forward method (as past_key_values) and that would then take care of correctly creating the past key values at each step?

Here's a GitHub gist of what I'm thinking of: https://gist.github.com/patrickvonplaten/7411f84b8a2cca3bc8e63df315d7d618

In short, this would entail some more fundamental changes to Transformers (essentially that every attention layer would call cache.update(...) if past_key_values is an object of type Cache), but I think this is something we want to do anyway to allow torch.compile to work better. We would then also give generate a new function argument, generate(..., cache=cache), that can optionally be passed.

I would be curious to hear what you think about this idea! At this stage it's definitely still pure brainstorming, but I think this could be a cool long-term solution that would also be quite easy to implement.
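To make this a bit more tangible, here is a rough, simplified sketch of the kind of cache object meant here (illustrative only; names and shapes are assumptions, not a final API):

from typing import Tuple
import torch

class Cache:
    """Attention layers call `update` instead of concatenating past_key_values themselves."""
    def update(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError

class SinkCache(Cache):
    def __init__(self, window_length: int, num_sink_tokens: int):
        self.window_length = window_length
        self.num_sink_tokens = num_sink_tokens
        self.key_cache = {}    # layer_idx -> [batch, num_heads, seq_len, head_dim]
        self.value_cache = {}

    def update(self, key_states, value_states, layer_idx):
        # Append the new keys/values to this layer's cache.
        if layer_idx in self.key_cache:
            key_states = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            value_states = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
        # Evict everything between the sink tokens and the most recent window.
        if key_states.shape[-2] > self.num_sink_tokens + self.window_length:
            key_states = torch.cat(
                [key_states[..., : self.num_sink_tokens, :], key_states[..., -self.window_length :, :]], dim=-2
            )
            value_states = torch.cat(
                [value_states[..., : self.num_sink_tokens, :], value_states[..., -self.window_length :, :]], dim=-2
            )
        self.key_cache[layer_idx] = key_states
        self.value_cache[layer_idx] = value_states
        return key_states, value_states

An attention layer would then simply do key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) whenever past_key_values is a Cache instance.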

@tomaarsen
Member Author

tomaarsen commented Oct 3, 2023

@patrickvonplaten

I'm afraid that your proposal would not quite be enough to implement the AttentionSink approach in all models. In addition to the cache, the approach requires that the position IDs are shifted within the window. To give a toy example with 4 attention sink tokens, a window size of 6, and a text that is just a space-separated alphabet, the model sees:

A
A B
A B C
A B C D
A B C D E
A B C D E F
A B C D E F G
A B C D E F G H 
A B C D E F G H I
A B C D E F G H I J
A B C D F G H I J K
A B C D G H I J K L
A B C D H I J K L M
...

With these position IDs:

0
0 1
0 1 2
0 1 2 3
0 1 2 3 4
0 1 2 3 4 5
0 1 2 3 4 5 6
0 1 2 3 4 5 6 7
0 1 2 3 4 5 6 7 8
0 1 2 3 4 5 6 7 8 9
0 1 2 3 4 5 6 7 8 9
0 1 2 3 4 5 6 7 8 9
0 1 2 3 4 5 6 7 8 9
...

i.e. the position IDs get shifted (or rather, they don't get shifted) as the window moves.

Or from the paper itself (Section 3.2, page 5):

When determining the relative distance and adding positional information to tokens, StreamingLLM
focuses on positions within the cache rather than those in the original text. This distinction is crucial
for StreamingLLM’s performance. For instance, if the current cache has tokens [0, 1, 2, 3, 6, 7, 8]
and is in the process of decoding the 9th token, the positions assigned are [0, 1, 2, 3, 4, 5, 6, 7], rather
than the positions in the original text, which would be [0, 1, 2, 3, 6, 7, 8, 9].


In practice, this is somewhat simple. For Mistral it requires changing this rotary position embedding application here:

query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

Into one that only updates the query_states, e.g.

query_states = apply_rotary_pos_emb_single(query_states, cos, sin, position_ids)

with

def apply_rotary_pos_emb_single(x, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    x_embed = (x * cos) + (rotate_half(x) * sin)
    return x_embed

Then we extend the key and value states using the cache, followed by an update to the cache. Only after that's done do we rotate the key_states with "faked" position IDs:

key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0)
key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids)

I took these snippets from my attention_sinks here and here. I'd recommend checking out these sources as these snippets might be confusing without their context.
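Putting those pieces together, the order of operations inside the modified attention forward is roughly the following (a simplified sketch based on the snippets above rather than the exact attention_sinks code; cos, sin, past_key_value and apply_rotary_pos_emb_single are assumed to be available as in the surrounding modeling code):

# 1. Rotate only the queries, using their real position_ids.
query_states = apply_rotary_pos_emb_single(query_states, cos, sin, position_ids)

# 2. Extend the (still unrotated) keys/values with the cache; the Attention Sink
#    cache then evicts everything outside the sink tokens + window.
if past_key_value is not None:
    key_states = torch.cat([past_key_value[0], key_states], dim=2)
    value_states = torch.cat([past_key_value[1], value_states], dim=2)
kv_seq_len = key_states.shape[-2]

# 3. Only now rotate the keys, using their positions *within the cache* rather
#    than their positions in the original text.
key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0)
key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids)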


The tl;dr is essentially that we need 2 changes to implement Attention Sink correctly:

  1. Update the model architecture to shift the position IDs.
  2. Update the Attention Sink KV Cache using the past_key_values from every ...Model.forward call.

Your proposal would be a very elegant solution for the second part of the implementation, but not yet the first. I handle the first part in my pos_shift.py files for Mistral, Falcon, GPT-NeoX and Llama.

Sidenote: I added support for Mistral, GPT-NeoX, Falcon and MPT to attention_sinks 🎉
If the model perplexities are anything to go by, then it works great for everything that I've tried:

Perplexity & VRAM plots
[Plots: perplexity & VRAM curves for Llama 2 7B, Falcon 7B, MPT 7B, Pythia 6.9B, and Mistral 7B]
  • Tom Aarsen

@patrickvonplaten
Contributor

patrickvonplaten commented Oct 4, 2023

Great point about the position_ids; I indeed hadn't thought about this enough.

Also super nice to see that the other LLMs also work great with StreamingLLM's approach! Very encouraging!

Taking a step back here, I guess there are different levels of support we could offer for StreamingLLM:


  • 1.) No real native support in Transformers
    This corresponds a bit to what we have now and what is proposed here. The advantage is that we don't need to make any changes to Transformers and that your package can nicely be leveraged. However, keeping it up to date is challenging.

  • 2.) Native support in Transformers, but only at the "model.forward level". At the end of the day, generate is just a method that calls forward multiple times, and many libraries that depend on Transformers implement their own generate method.
    So as a first step it would be great to support StreamingLLM for every model's forward method, so that one only has to change the generate method.

We can achieve this by following the design described here for the cache, i.e. your second point:

  Update the Attention Sink KV Cache using the past_key_values from every ...Model.forward call.

Now for your first point:

  Update the model architecture to shift the position IDs.

it's indeed trickier!

One approach to handle this could be to add an optional key_position_ids function argument here:

position_ids: Optional[torch.LongTensor] = None,

This would then propagate all the way to apply_rotary_pos_emb:

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):

key_position_ids would default to position_ids if not specified. This way the user could, at every forward call during generate, pass the correct (but different) position_ids for the query and key respectively. For the user this could then look as follows:

cache = SinkCache(window_length=256, num_sink_tokens=3)

query_pos_ids = ...
key_pos_ids = ....
model(input_ids, position_ids=query_pos_ids, key_position_ids=key_pos_ids, cache=cache)

  • 3.) The last step would then be to support StreamingLLM natively in Transformers with generate. If we have finished 2.), this could also be done relatively easily, e.g. we could just allow the user to pass:
cache = SinkCache(window_length=256, num_sink_tokens=3)

to generate:

model.generate("....", cache=cache)

and regarding the position_ids, it would also only require a small change to

position_ids = attention_mask.long().cumsum(-1) - 1

to correct the position ids for generation.
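Roughly speaking, the correction could look like this (just a sketch of the kind of change meant here, assuming the cache exposes num_sink_tokens and window_length; not a tested diff):

position_ids = attention_mask.long().cumsum(-1) - 1
if cache is not None:
    # Positions are assigned within the cache, so they never exceed
    # num_sink_tokens + window_length - 1 once the cache starts evicting.
    position_ids = position_ids.clamp(max=cache.num_sink_tokens + cache.window_length - 1)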

Questions:

Changing all the position_ids logic just for StreamingLLM might be a tough sell given that the method is still relatively new, but if it can be done nicely I think it'd be ok to extend the forward and prepare_inputs_for_generation methods.

What do you think here @tomaarsen ?

In that regard some questions:

  • a) It would be great to also give the user a simple way to support Windowed Attention and Sliding Window Attention with recomputation in Transformers - do you know how the position_ids would need to be treated there? In a similar way as is done for StreamingLLM?
  • b) As far as I can see, StreamingLLM is implemented only for models that use RoPE position embeddings. Do you think it should also work with ALiBi (Falcon's default is ALiBi position ids) or GPT2 (vanilla position ids)?
    It would be cool to gauge whether the changes to the position ids needed here could be solved with a similar design as described above.

@tomaarsen
Member Author

tomaarsen commented Oct 4, 2023

Apologies for the delay, it's been a busy day at work today.
I'll go over each of your options:

1.) No real native support in Transformers

Although I'm definitely open to maintaining a third-party package, it is not really feasible with transformers as it stands right now. For each architecture I have to:

  1. ✅ Wrap the forward of the ...Model with a cache update, which I can implement fairly elegantly.
  2. ❌ Completely replace the entire forward method of all ...Attention classes to update the position IDs.

This requires me to completely pin each attention_sinks version to a very specific transformers version, which is not really viable, as much as I think it would be fun to maintain a plugin of transformers.


2.) Native support in Transformers but only for a "model.forward-level".

This is my personal preference intuitively.

Beyond that, generation doesn't work out of the box even with the forward methods correctly updated. I've encountered two noteworthy problems:

  1. The attention_mask in _update_model_kwargs_for_generation grows with a "1" for every token generated. Once the Sink KV Cache starts evicting tokens, this causes a shape mismatch. Easy fix here: https://github.com/tomaarsen/attention_sinks/blob/f46e63101fa74c6095e986c33284217c34a9fd88/attention_sinks/generation/utils.py#L38-L41

  2. The model.generate method does not return the past_key_values, preventing any form of multi-step generation (which is the primary use case of the Attention Sink approach: being able to keep prompting your model over and over without it losing fluency). If we update the cache as discussed earlier, this problem could be resolved by the user passing a Cache instance to model.generate which holds the updated past_key_values. That cache instance can then be reused for future model.generate calls.
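For that multi-step use case, usage could look roughly like this (a hypothetical sketch following the SinkCache / generate(..., cache=cache) idea above; none of this works today, and tokenizer/model are assumed to be loaded already):

cache = SinkCache(window_length=1020, num_sink_tokens=4)

# The same cache object is passed to every generate call, so the model can keep
# being prompted indefinitely without losing fluency.
for prompt in ["Tell me a story.", "Continue the story.", "How does it end?"]:
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(**inputs, cache=cache, max_new_tokens=256)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))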


I think that the key_position_ids idea should work. An alternative is to implement the rotating and caching in a dedicated method, so that only this method has to be overridden by a third party (i.e. "attention_sinks") to provide this functionality.

Edit: Another alternative is a cache_before_rotate parameter on the cache class, which determines whether to cache before rotating (as in Attention Sink) or after (the normal behaviour).


As for your questions, I also invite @Guangxuan-Xiao to answer here, but I'll do my best to answer:

  • a) I've implemented the window attention in the exact same way as attention_sinks, but with 0 sink tokens. That said, I'm not confident that this is the correct approach, as there's a chance that the position IDs should not be shifted for window attention. Perhaps @Guangxuan-Xiao can comment on this.
    Also, for Sliding Window Attention, is that the form where each layer of the model can see a slightly different window?

  • b) Yes, it works for ALiBi. In fact, the MPT implementation is the simplest of all - I don't need to override any forward method at all, I just need to call the cache update after every MPTModel.forward. Some proof:

StreamingLLM's design is versatile and can be seamlessly incorporated into any autoregressive language model that employs relative positional encoding, such as RoPE (Su et al., 2021) and ALiBi (Press et al., 2022).

And some more info on the implementation:

For encoding like RoPE, we cache the Keys of tokens prior to introducing the rotary transformation. Then, we apply position transformation to the keys in the rolling cache at each decoding phase. On the other hand, integrating with ALiBi is more direct. Here, the contiguous linear bias is applied instead of a ’jumping’ bias to the attention scores. This method of assigning positional embedding within the cache is crucial to StreamingLLM’s functionality, ensuring that the model operates efficiently even beyond its pre-training attention window size.

For GPT2, I'd have to have a quick look at the implementation. I see it's a bit different from the modern LLMs, e.g. with GPT2LMHeadModel instead of GPT2ForCausalLM. However, I think the approach relies on rotary (or at least relative) position embeddings, whereas GPT2 uses learned absolute position embeddings.


Quasi-related: I've been pointed to another paper that does something very similar: https://arxiv.org/abs/2308.16137

It involves only a Λ-shaped attention mask (to avoid excessive attended tokens) and a distance limit (to avoid unseen distances) while requiring no parameter updates or learning. We find it applicable to a variety of LLMs using relative-position encoding methods. LM-Infinite is computationally efficient with O(n) time and space, and demonstrates consistent text generation fluency and quality to as long as 128k tokens on ArXiv and OpenWebText2 datasets, with 2.72x decoding speedup. We will make the codes publicly available following publication.

This "Λ-shaped attention mask" is kind of like always attending to the first tokens (i.e. the sink tokens) and "a distance limit" sounds like a window size.
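For intuition, such a Λ-shaped mask boils down to something like the following toy construction (my own sketch of the boolean mask only, not LM-Infinite's kernel; the "distance limit"/position remapping part is not shown):

import torch

def lambda_shaped_mask(seq_len: int, num_global_tokens: int = 4, window_size: int = 6) -> torch.Tensor:
    """Boolean [seq_len, seq_len] mask where position i attends to the first
    `num_global_tokens` positions and to the `window_size` most recent positions."""
    positions = torch.arange(seq_len)
    causal = positions[None, :] <= positions[:, None]          # no attending to the future
    global_branch = positions[None, :] < num_global_tokens     # the "sink"-like leading tokens
    local_branch = positions[:, None] - positions[None, :] < window_size  # the recent window
    return causal & (global_branch | local_branch)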

  • Tom Aarsen

@Glaciohound

Glaciohound commented Oct 6, 2023

Thanks for mentioning our work (https://arxiv.org/abs/2308.16137), "LM-Infinite: Simple On-the-Fly Length Generalization for Large Language Models", a month ago! I also noticed the striking similarities between the two methods: (1) we both use a $\Lambda$-shaped attention mask, which is equivalent to "sink tokens" + nearest tokens, and (2) we both re-arrange the distances, which we referred to as a "distance limit" and which they describe in Section 3.2 as "When determining the relative distance and adding positional information to tokens, StreamingLLM focuses on positions within the cache rather than those in the original text".

We are happy to share an implementation here: https://github.com/Glaciohound/LM-Infinite, which you might be interested in having a check.

Somewhat surprisingly, in StreamingLLM's implementation, even when doing context encoding (such as calculating the perplexity of a sequence), they feed tokens one by one (as can be observed here and here). By contrast, our implementation offers a "sequence"-mode encoding functionality just like normal language models, which avoids looping through the sequence and provides great computational efficiency. This is thanks to our specialized attention kernel implementation.

I am also very interested in helping to integrate these papers into HuggingFace Transformers. If you need any further information or help on the technical side, please do not hesitate to let me know.

@patrickvonplaten
Contributor

Also @gante

@tomaarsen
Member Author

@patrickvonplaten I've created a draft PR in #26681 using the Cache that you envisioned. The implementation for the Attention Sink Cache should be fairly simple then.

Also, I ran more experiments over the weekend:

  • The method does not trivially work for vanilla attention, as the next token still needs to have the "next" position, which means that the entire KV cache would have to be recomputed with the position IDs shifted to the left to make "space" for the next token.
  • I have very strong suspicions, though I haven't tested it yet, that Windowed Attention with Attention Sinks/StreamingLLM should work with Flash Attention 2 for Llama 🎉

I have a Hugging Face blogpost with more experiments on Attention Sinks coming out soon.

  • Tom Aarsen

@Glaciohound

Glaciohound commented Oct 9, 2023

@tomaarsen @patrickvonplaten

This is awesome! In this way, the PR provides a general cache module reusable across models, which is of great help to the whole community and to future developers.

What is left to be done is compatibility with backward passes and sequence forwarding/classification support for long sequences, which I am more than happy to help with! The current implementation here is optimized for generation. To also let users forward and backward long sequences (such as encoding long contexts or classifying long documents, an inevitable need when users do large-scale pre-training or deployment) without token-by-token forwards, our code snippet used in LM-Switch can serve as a starting point for encoding (happy to merge our code!). After that, Sinks/StreamingLLM can continue using the cached features (theoretically compatible) for generation.

  • Chi Han

@tomaarsen
Member Author

I'd love to continue working on this.

@ydshieh
Collaborator

ydshieh commented Nov 3, 2023

Of course @tomaarsen .

You can delete the bot comment (I guess you know it 😄 ) - and welcome to the team!

@tomaarsen
Member Author

Thank you! ❤️

@woominsong

@tomaarsen @Glaciohound
Hi! Thanks for all the efforts you have put into making this work.
I was wondering if there have been any updates regarding this issue, particularly about forwarding long sequences.

Thanks in advance!
