Implement StreamingLLM/Windowed Attention with Attention Sinks #26553
Hey @tomaarsen, very cool feature and implementation! This definitely looks like a good fit for `transformers`. Keeping a drop-in implementation up to date over the long term is hard to do, so I would recommend we move towards a utility function for now that could eventually be upstreamed into `transformers`. So instead of the current

```python
from attention_sinks import AutoModel

model = AutoModel.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="auto")
```

how about something like the following?

```python
from attention_sinks import convert_model_attention_to_sinks
from transformers import AutoModel

model = AutoModel.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="auto")
model = convert_model_attention_to_sinks(model)
```

Eventually, the API could take two different directions. The first path is likely the most scalable; we would work a bit on the model definition to enable "plugins" from third-party libraries, enabling support for many third-party tools. The second one would offer support in core `transformers` directly, but for this we would really want validation across many models first. |
Hello! I'm glad that you're open to the idea of adding this to `transformers`! I agree that the current

```python
from attention_sinks import AutoModel

model = AutoModel.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="auto")
```

is not workable long-term, but I think a solution similar to BetterTransformer in `optimum` could work well. I'm curious to hear your thoughts on this. For reference, today I will be adding support for MPT, Falcon and Pythia alongside my existing Llama support to `attention_sinks`.
|
[Brainstorming] I'm wondering whether we could use this issue as a catalyst to improve the cache / past key value design we have in Transformers, as it needs to be updated anyway soon (cc @gante as well). @tomaarsen, do you think we could support StreamingLLM for every model just by defining a "StreamingLLM/AttentionSink" cache that can be passed to the forward method (the way `past_key_values` is today)? Here is a GitHub gist of what I'm thinking of: https://gist.github.com/patrickvonplaten/7411f84b8a2cca3bc8e63df315d7d618 In short, this would entail some more fundamental changes to Transformers (essentially, every attention layer would call the cache's update method). Would be curious to hear what you think about this idea! At this stage it is definitely still pure brainstorming, but I think this could be a cool long-term solution that would also be quite easy to implement.
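To make the idea concrete, here is a minimal sketch of what such a cache could look like. Only the names `SinkCache`, `window_length` and `num_sink_tokens` come from this discussion; the method signature and eviction details below are assumptions for illustration, not the actual gist or `transformers` API.

```python
import torch


class SinkCache:
    """Toy KV cache keeping the first `num_sink_tokens` ("attention sinks")
    plus a sliding window of the `window_length` most recent tokens, per layer."""

    def __init__(self, window_length: int, num_sink_tokens: int):
        self.window_length = window_length
        self.num_sink_tokens = num_sink_tokens
        self.key_cache = []    # one tensor per layer: [batch, heads, seq_len, head_dim]
        self.value_cache = []

    def update(self, key_states, value_states, layer_idx):
        # Append the new key/value states for this layer.
        if layer_idx == len(self.key_cache):
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        keys, values = self.key_cache[layer_idx], self.value_cache[layer_idx]
        max_length = self.num_sink_tokens + self.window_length
        if keys.shape[-2] > max_length:
            # Evict from the middle: keep the sink tokens and the most recent window.
            keys = torch.cat([keys[..., : self.num_sink_tokens, :], keys[..., -self.window_length :, :]], dim=-2)
            values = torch.cat([values[..., : self.num_sink_tokens, :], values[..., -self.window_length :, :]], dim=-2)
            self.key_cache[layer_idx], self.value_cache[layer_idx] = keys, values
        return keys, values
```

Every attention layer would then call `cache.update(key_states, value_states, self.layer_idx)` and attend over whatever the cache returns.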
|
I'm afraid that your proposal would not quite be enough to implement the AttentionSink approach in all models. In addition to the cache, the approach requires that the position IDs are shifted within the window. To give a toy example with 4 attention sink tokens, a window size of 6, and a text that is just the space-separated alphabet, the model sees:

```
A B C D E F G H I J
A B C D F G H I J K
A B C D G H I J K L
```

With these position IDs:

```
0 1 2 3 4 5 6 7 8 9
0 1 2 3 4 5 6 7 8 9
0 1 2 3 4 5 6 7 8 9
```

i.e. the position IDs get shifted (or rather, they don't get shifted) as the window moves: tokens are numbered by their position in the cache rather than by their position in the original text. The paper makes the same point in Section 3.2 (page 5).
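As a quick illustration of that numbering (a hypothetical snippet, not from the paper's code): after reading, say, 110 tokens, the cache holds the 4 sink tokens plus the 6 most recent tokens, yet they are treated as if they sat at positions 0–9:

```python
import torch

num_sink_tokens, window_length = 4, 6

# The cache holds the 4 sink tokens plus the 6 most recent tokens,
# i.e. tokens that originally sat at these positions in the text:
text_positions = torch.cat([torch.arange(num_sink_tokens), torch.arange(104, 110)])
print(text_positions.tolist())   # [0, 1, 2, 3, 104, 105, 106, 107, 108, 109]

# StreamingLLM instead numbers tokens by their slot inside the cache:
cache_positions = torch.arange(num_sink_tokens + window_length)
print(cache_positions.tolist())  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```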
In practice, this is somewhat simple. For Mistral, it requires changing the rotary position embedding application in the attention forward (the `apply_rotary_pos_emb` call that rotates both the query and key states) into one that only updates the `query_states`, e.g.

```python
query_states = apply_rotary_pos_emb_single(query_states, cos, sin, position_ids)
```

with

```python
def apply_rotary_pos_emb_single(x, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    x_embed = (x * cos) + (rotate_half(x) * sin)
    return x_embed
```

Then, we update the key and value states using the cache, followed by an update of the cache itself. Only after that is done do we apply the rotary embedding to the key states, using their positions within the cache:

```python
key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0)
key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids)
```

I took these snippets from my `attention_sinks` implementation.

The tl;dr is essentially that we need two changes to implement Attention Sink correctly:

1. Shift the position IDs so that the rotary embeddings are applied based on positions within the cache, rather than positions within the original text.
2. Use a cache that always keeps the first few "sink" tokens plus a sliding window of the most recent tokens.

(A small, self-contained sketch of the resulting ordering is included at the end of this comment.)

Your proposal would be a very elegant solution for the second part of the implementation, but not yet for the first. I do the first in `attention_sinks` by overriding the attention forward methods of the supported architectures.

Sidenote: I added support for Mistral, GPT-NeoX, Falcon and MPT to `attention_sinks`.
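To make the ordering concrete, here is a hedged, self-contained sketch with dummy tensors. The rotary table construction and shapes mimic the Hugging Face layout but are assumptions for illustration; none of this is the actual Mistral code.

```python
import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb_single(x, cos, sin, position_ids):
    cos = cos.squeeze(1).squeeze(0)            # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)            # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)       # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)       # [bs, 1, seq_len, dim]
    return (x * cos) + (rotate_half(x) * sin)

bs, num_heads, head_dim = 1, 8, 64
cached_len, new_len = 9, 1                     # 9 tokens already cached, 1 new token
kv_seq_len = cached_len + new_len

# Toy rotary tables covering every cache slot, shaped like the HF rotary embedding output.
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(kv_seq_len).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)[None, None]   # [1, 1, kv_seq_len, head_dim]
cos, sin = emb.cos(), emb.sin()

query_states = torch.randn(bs, num_heads, new_len, head_dim)
key_states = torch.randn(bs, num_heads, kv_seq_len, head_dim)  # cached + new keys, still unrotated

# Change 1: the query is rotated with its position *inside the cache* (here, the last slot).
query_position_ids = torch.tensor([[kv_seq_len - 1]])
query_states = apply_rotary_pos_emb_single(query_states, cos, sin, query_position_ids)

# (Change 2, the cache update with the unrotated keys/values, would happen here.)

# Change 1, continued: all cached keys are re-rotated each step with positions 0..kv_seq_len-1,
# so relative distances stay bounded even after old tokens are evicted.
key_position_ids = torch.arange(kv_seq_len).unsqueeze(0)
key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids)
```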
|
Great point about the position_ids, I indeed didn't think about this enough. Also super nice to see that the other LLMs work great with StreamingLLM's approach! Very encouraging!

Taking a step back here, I guess there are different levels of support we could offer for StreamingLLM. The cache part we can achieve by following the design described in the gist, i.e. a `SinkCache` that can be passed to the model and takes care of evicting everything except the sink tokens and the most recent window.

Now for the position ID shifting, it's indeed trickier! One approach to handle this could be to add an optional `key_position_ids` argument. This would then propagate all the way down to the rotary embedding application and would default to the regular `position_ids` when not passed. At the lowest level, usage would then look something like:

```python
cache = SinkCache(window_length=256, num_sink_tokens=3)

query_pos_ids = ...
key_pos_ids = ....
model(input_ids, position_ids=query_pos_ids, key_position_ids=key_pos_ids, cache=cache)
```

A higher level of support would be to just pass a `cache = SinkCache(window_length=256, num_sink_tokens=3)` to generate:

```python
model.generate("....", cache=cache)
```

and regarding the position_ids it would also require only a small change there.

Changing all the position_ids logic just for StreamingLLM might be a tough sell given that the method is still relatively new, but if it can be done nicely, I think it'd be OK to extend the forward and `generate` signatures. What do you think here @tomaarsen? In that regard, I have some questions.
|
Apologies for the delay, it was a busy day at work today.

Although I'm definitely open to maintaining a third-party package, it is not feasible for `attention_sinks` to keep overriding the modeling code of every supported architecture in the long run. This requires me to completely pin each `attention_sinks` release to a specific `transformers` version.

Supporting this in `transformers` itself is intuitively my personal preference. Beyond that, generation doesn't work out of the box even with the forward methods correctly updated; I've encountered two noteworthy problems there.

I think that the `key_position_ids` idea should work. An alternative is to implement the rotating and caching in a dedicated method, so that only this method has to be overridden by a third party (i.e. `attention_sinks`) to provide this functionality. Edit: another alternative is a parameter on the model configuration.

As for your questions, I also invite @Guangxuan-Xiao to answer here, but I'll do my best to answer. Some more notes on the implementation: for GPT-2, I'd have to have a quick look at the modeling code; I see it's a bit different from the more recent LLMs.

Quasi-related: I've been pointed to a paper that does something very similar: https://arxiv.org/abs/2308.16137
Its "Λ-shaped attention mask" is kind of like always attending to the first tokens (i.e. the sink tokens), and its "distance limit" sounds like a window size.
|
Thanks for mentioning our work (https://arxiv.org/abs/2308.16137), "LM-Infinite: Simple On-the-Fly Length Generalization for Large Language Models", from a month ago! I also noticed the striking similarities between the two methods: (1) we both keep the initial tokens (the "attention sinks") plus a sliding window of the most recent tokens, and (2) we both bound the positional information to distances within that window. We are happy to share an implementation here: https://github.com/Glaciohound/LM-Infinite, which you might be interested in checking out. Somewhat surprisingly, in StreamingLLM's implementation, even when doing context encoding (such as calculating the perplexity of a sequence), they feed tokens one by one (as can be observed in their code). On the contrary, our implementation offers a "sequence" mode encoding functionality just like normal language models, which avoids looping through the sequence and provides great computational efficiency. This is thanks to our specialized attention kernel implementation. I am also very interested in helping to integrate these papers into HuggingFace Transformers. If you need any further information or help from the technical side, please do not hesitate to let me know. |
Also @gante |
@patrickvonplaten I've created a draft PR in #26681 using the cache design discussed above. I also ran more experiments over the weekend.
I have a Hugging Face blog post with more experiments on Attention Sinks coming out soon.
|
This is awesome! In this way, the PR provides a general cache module that is reusable for other models as well, which is of great help to the whole community and to future developers. What is left to be done is the remaining compatibility work.
|
I'd love to continue working on this. |
Of course @tomaarsen. You can delete the bot comment (I guess you already know 😄) - and welcome to the team! |
Thank you! ❤️ |
@tomaarsen @Glaciohound Thanks in advance! |
Feature request
Hello!
I would love to see StreamingLLM/Windowed Attention with Attention Sinks implemented, as proposed in https://arxiv.org/abs/2309.17453.
The primary author (@Guangxuan-Xiao) has also released the code here: https://github.com/mit-han-lab/streaming-llm
And I've adapted that code into a drop-in replacement of `transformers` to allow people to use it: https://github.com/tomaarsen/attention_sinks
The paper shows that adapting windowed attention such that the first 4 tokens of the input sequence are always kept in the window allows every tested LLM (Llama 2, MPT, Falcon, Pythia) to scale to endless inputs without catastrophic perplexity increases, all without any form of retraining.
In other words, scaling any pretrained LLM to infinite sequence length is as simple as the snippet below.
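Using the `attention_sinks` drop-in replacement of `transformers`:

```python
from attention_sinks import AutoModel

model = AutoModel.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="auto")
```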
Using this elementary approach, the authors were able to keep various LLMs stable when feeding them up to 4 million tokens (!).
Motivation
Maximum sequence lengths have been an important topic for a while now, with solutions ranging from RoPE scaling to LongLoRA to YaRN, but each of these has its limits, and some also require retraining or additional training. Windowed attention with attention sinks seems to completely solve this problem, and it would be an extremely valuable addition.
I can vouch for the results in the paper: I've reproduced them for Llama 2 7B using my own implementation.
Your contribution
Yes. I would love to help implement this into core `transformers` rather than in my drop-in implementation. However, I would like to discuss how best to approach the integration into `transformers`. The primary author of the paper has also expressed interest in a `transformers` implementation here.