Implement StreamingLLM/Windowed Attention with Attention Sinks #26553

Closed
@tomaarsen

Description


Feature request

Hello!

I would love to see StreamingLLM / Windowed Attention with Attention Sinks implemented, as proposed in https://arxiv.org/abs/2309.17453.
The primary author (@Guangxuan-Xiao) has also released the code here: https://github.com/mit-han-lab/streaming-llm
And I've adapted that code into a drop-in replacement for transformers so that people can use it: https://github.com/tomaarsen/attention_sinks
For example:

from attention_sinks import AutoModel

model = AutoModel.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="auto")


[Figure from the paper: comparison of attention schemes]

The paper shows that adapting windowed attention so that the first 4 tokens of the input sequence always remain in the window allows every tested LLM (Llama 2, MPT, Falcon, Pythia) to scale to effectively endless inputs without catastrophic perplexity increases, all without any form of retraining.
In other words, scaling any pretrained LLM to an infinite sequence length is as simple as:

  1. Converting the attention to windowed attention.
  2. Using a special cache for the windowed attention that always keeps the first 4 (by default) tokens in the cache (a sketch of this cache policy follows below).
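To make step 2 concrete, here is a minimal sketch of the eviction policy in plain PyTorch, assuming the legacy `past_key_values` layout (one `(key, value)` pair per layer, each of shape `(batch, num_heads, seq_len, head_dim)`). The names `apply_attention_sink`, `attention_sink_size`, and `window_size` are illustrative only, not the actual attention_sinks API:

```python
import torch
from typing import Tuple

KVCache = Tuple[Tuple[torch.Tensor, torch.Tensor], ...]

def apply_attention_sink(
    past_key_values: KVCache,
    attention_sink_size: int = 4,   # number of initial "sink" tokens that are never evicted
    window_size: int = 1024,        # total number of cached tokens per layer
) -> KVCache:
    """Keep the first `attention_sink_size` tokens plus the most recent tokens,
    dropping everything in between, for every layer of the KV cache."""
    trimmed = []
    for key, value in past_key_values:
        # Legacy cache layout: (batch, num_heads, seq_len, head_dim)
        seq_len = key.shape[2]
        if seq_len <= window_size:
            trimmed.append((key, value))
            continue
        recent = window_size - attention_sink_size
        key = torch.cat((key[:, :, :attention_sink_size], key[:, :, -recent:]), dim=2)
        value = torch.cat((value[:, :, :attention_sink_size], value[:, :, -recent:]), dim=2)
        trimmed.append((key, value))
    return tuple(trimmed)
```

Note that for rotary-embedding models such as Llama 2 the paper also re-assigns the positions inside the cache so that they stay within the trained range; the sketch above only covers the eviction policy.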

Using this elementary approach, the authors were able to keep various LLMs stable while feeding them as many as 4 million tokens.

Motivation

Maximum sequence lengths have been an important topic for a while now, with solutions ranging from RoPE scaling to LongLoRA to YaRN, but each of these has its limits, and some also require retraining or additional training. Windowed attention with attention sinks seems to solve this problem completely, and it would be an extremely valuable addition.

I can vouch for the results in the paper. I've gotten these results for Llama 2 7B using my own implementation:
[Figure: perplexity and VRAM usage for Llama 2 7B]

Your contribution

Yes. I would love to help implement this in core transformers rather than keep it in my drop-in implementation. However, I would like to discuss:

  1. Whether this feature is a good fit for transformers.
  2. Where we store the code for converting each model (e.g. Llama, Pythia, Falcon) to windowed attention. See e.g. this file for an example.
  3. Where we store the code for applying the Attention Sink KV Cache after a forward call. See e.g. this file for an example; a sketch of what this could look like follows after this list.
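On point 3, a rough sketch of what "applying the cache after a forward call" could look like in a manual greedy decoding loop, reusing the hypothetical `apply_attention_sink` helper from the earlier sketch; this is not the actual attention_sinks or transformers API, and it skips the windowed-attention conversion from point 2:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="auto")

input_ids = tokenizer("Hello!", return_tensors="pt").input_ids.to(model.device)
past_key_values = None

for _ in range(50):
    with torch.no_grad():
        outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)
    # Apply the attention-sink eviction right after the forward call, so the next
    # decoding step only ever sees the sink tokens plus the recent window.
    past_key_values = apply_attention_sink(outputs.past_key_values)
    # Greedy next-token selection; only the new token is fed in the next step.
    input_ids = outputs.logits[:, -1:].argmax(dim=-1)
```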

The primary author of the paper has also expressed interest in a transformers implementation here.

  • Tom Aarsen
