
Adding LlamaInfinite model which implements LM-Infinite on Llama #26645

Closed
Glaciohound wants to merge 10 commits

Conversation

Glaciohound

What does this PR do?

In this PR, we implement LM-Infinite: Simple On-the-Fly Length Generalization for Large Language Models (proposed in August 2023) on the Llama model. LM-Infinite removes the length limits of large language models and lets them generate to unbounded lengths while keeping performance comparable to what is seen at training-time lengths, without any parameter updates. Results show that LM-Infinite can encode sequences as long as 128k tokens on a single A100 GPU and can generate indefinitely, thanks to its $O(n)$ time and space complexity for encoding and $O(1)$ complexity per decoding step. Interestingly, the later StreamingLLM work recently observed similar results with a similar technique.
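For readers unfamiliar with the method, the core of LM-Infinite is a Λ-shaped attention pattern: every token attends to a handful of starting tokens plus a sliding local window, and relative distances are capped for the rotary embeddings. Below is a minimal sketch of just the mask; the distance ceiling and the actual attention integration are omitted, and `n_global` / `n_local` are illustrative parameter names rather than the ones used in this PR.

```python
import torch

def lambda_attention_mask(seq_len: int, n_global: int, n_local: int) -> torch.Tensor:
    """Boolean mask: entry (i, j) is True if query position i may attend to key position j."""
    i = torch.arange(seq_len).unsqueeze(1)  # query positions (column vector)
    j = torch.arange(seq_len).unsqueeze(0)  # key positions (row vector)
    causal = j <= i                          # standard causal constraint
    global_branch = j < n_global             # always attend to the first n_global tokens
    local_branch = (i - j) < n_local         # and to the most recent n_local tokens
    return causal & (global_branch | local_branch)

# Tiny example: 8 tokens, 2 global tokens, local window of 3.
print(lambda_attention_mask(8, n_global=2, n_local=3).int())
```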

This implementation is related to, and in response to, an issue discussing the integration of LM-Infinite into Hugging Face Transformers.

This LlamaInfinite model allows for seamless adaptation from the original Llama models: simply substitute LlamaForCausalLM.from_pretrained() with LlamaInfiniteForCausalLM.from_pretrained(). All other usage remains the same. The implementation is compatible with all previous Llama model checkpoints without any modifications, so no new model checkpoints are needed.
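Assuming the new class is exported at the top level of transformers (the exact import path depends on how this PR registers the model), usage would look roughly like this:

```python
from transformers import AutoTokenizer, LlamaInfiniteForCausalLM  # class added by this PR

model_id = "meta-llama/Llama-2-7b-hf"  # any existing Llama checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Drop-in replacement for LlamaForCausalLM: same checkpoint, same generate() API.
model = LlamaInfiniteForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Once upon a time", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```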

Fixes # (issue)

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@tomaarsen
Member

Hello!

Thank you for taking the time and care to implement this. I'm doing some benchmarking as I'm writing this now :)
I do have to say that I would prefer a solution along the lines of what is being discussed in #26553, e.g. creating some Cache class and/or compartmentalising the key/query rotation etc. into a specific method that a third party can cleanly override (a rough sketch follows this list). This preference stems from a few reasons:

  • package-wide implementation opportunities: it seems feasible to implement this form of improved long-range generation for all LLM architectures. The technique looks valuable enough that we should aim for that, rather than for llama alone.
  • avoiding code duplication & high maintenance costs: if we indeed support this for all LLM architectures, then we will need to create a FalconInfinite, MistralInfinite, MPTInfinite, etc. etc. This is not maintainable.
  • LlamaInfiniteModel does not correspond to a new architecture, but to llama. In other words, you don't load llama-infinite models with LlamaInfiniteModel, you load llama models. This is unusual for transformers, I believe.
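For concreteness, one hypothetical shape such a compartmentalisation could take (class and method names here are purely illustrative, not the actual design from #26553):

```python
import torch

class SinkCache:
    """Illustrative KV cache that keeps `n_sink` initial tokens plus a sliding window."""

    def __init__(self, n_sink: int = 4, window: int = 1020):
        self.n_sink = n_sink
        self.window = window
        self.keys: list[torch.Tensor] = []
        self.values: list[torch.Tensor] = []

    def update(self, layer_idx: int, key: torch.Tensor, value: torch.Tensor):
        # key/value: (batch, heads, new_tokens, head_dim) for one layer.
        if layer_idx >= len(self.keys):
            self.keys.append(key)
            self.values.append(value)
        else:
            self.keys[layer_idx] = torch.cat([self.keys[layer_idx], key], dim=2)
            self.values[layer_idx] = torch.cat([self.values[layer_idx], value], dim=2)
        k, v = self.keys[layer_idx], self.values[layer_idx]
        if k.shape[2] > self.n_sink + self.window:
            # Evict from the middle: keep the sink tokens and the most recent window.
            k = torch.cat([k[:, :, : self.n_sink], k[:, :, -self.window:]], dim=2)
            v = torch.cat([v[:, :, : self.n_sink], v[:, :, -self.window:]], dim=2)
            self.keys[layer_idx], self.values[layer_idx] = k, v
        return self.keys[layer_idx], self.values[layer_idx]
```

A third-party package could then swap in its own cache (or its own rotation method) without duplicating whole model classes per architecture.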

Preliminary benchmarking results

As I was writing this, my perplexity benchmarking tool from attention_sinks made some progress. For reference, I updated it to use LlamaInfiniteForCausalLM from this PR and fed it the first few thousand tokens of a 65k-token book to measure the perplexity over time. For the unaware: lower perplexity is better, as it is directly tied to the loss over all of the measured tokens.
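Roughly, the measurement amounts to feeding the text token by token with a KV cache and logging each token's negative log-likelihood. A simplified sketch of that loop (not the actual attention_sinks benchmark script; `book.txt` is a placeholder):

```python
import torch
from transformers import AutoTokenizer, LlamaForCausalLM

model_id = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = LlamaForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

text = open("book.txt").read()  # placeholder for the long book
input_ids = tokenizer(text, return_tensors="pt").input_ids.to(model.device)

nlls, past_key_values = [], None
with torch.no_grad():
    for i in range(1, input_ids.shape[1]):
        out = model(input_ids[:, i - 1 : i], past_key_values=past_key_values, use_cache=True)
        past_key_values = out.past_key_values
        # Negative log-likelihood of the true next token; perplexity up to step i is exp(mean(nlls)).
        nll = torch.nn.functional.cross_entropy(out.logits[:, -1], input_ids[:, i])
        nlls.append(nll.item())
```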

I've run this experiment for:

  1. pure transformers
  2. attention_sinks
  3. window attention (in particular, I used attention_sinks but set the number of sinks to 0, so it only does normal window attention; see the configuration sketch after this list)
  4. LlamaInfinite from this PR.
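For reference, setups 2 and 3 differ only in the number of sink tokens. With the attention_sinks package this is configured along the lines below (I'm quoting the keyword argument names from memory, so check the attention_sinks README for the exact spelling):

```python
from attention_sinks import AutoModelForCausalLM  # drop-in AutoModel from the attention_sinks package

model_id = "meta-llama/Llama-2-7b-hf"

# Setup 2: 4 sink tokens + sliding window, 1024 cached positions in total.
sink_model = AutoModelForCausalLM.from_pretrained(
    model_id, attention_sink_size=4, attention_sink_window_size=1020
)

# Setup 3: plain window attention, obtained by setting the number of sinks to 0.
window_model = AutoModelForCausalLM.from_pretrained(
    model_id, attention_sink_size=0, attention_sink_window_size=1024
)
```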

[Figure 1: perplexity as a function of input length for the four setups]

Let's go over the details:

  1. pure transformers: For Llama, this implementation fails for two reasons. VRAM usage is linear in the input length, making the model infeasible for endless prompting (e.g. an assistant LLM wouldn't work well). The perplexity also shoots up once the model goes beyond 4096 tokens, indicating that the model stops giving reasonable predictions there.
  2. attention_sinks: This implementation uses a window size of 1024, of which 4 are sink tokens. The result is a perplexity that stays low and constant memory usage - ideal for assistant-style LLM applications, for example.
  3. window attention: This is the naive approach to getting constant (i.e. feasible) memory usage, and it also uses a window size of 1024. It is clear that this approach fails as soon as the first few tokens are discarded due to the window size.
  4. LlamaInfinite: The results here are very interesting. The approach appears to mirror the perplexity of regular transformers and keeps it low beyond 4096 tokens. However, it only manages to do so because of a linear space complexity that mirrors that of transformers. From my point of view, this shows that LlamaInfinite can likely keep up fluency indefinitely in theory, but it is extremely impractical in practice - nobody could actually scale a chat-style LLM to respond to thousands of prompts sequentially using this approach. And that is something that people can actually do with attention_sinks.

To further support my thoughts here, I've also plotted the latencies of using transformers, LlamaInfinite and attention_sinks as a function of the input length. (Note: this is a log plot)
[Figure 2: latency as a function of input length, log scale]

[Figure 3: the same latencies on a linear (non-log) scale]

As you can see, both LlamaInfinite and transformers are equally impractical for long sequences. Sidenote: even before the memory issues cause latency problems, the LlamaInfinite implementation is a good bit slower than pure transformers or attention_sinks, e.g. 11 samples/s vs 15 samples/s on my device (RTX 3090).

To summarize my results: I'm not very confident in the benefit that LlamaInfinite has over pure transformers. It would work wonders if you happen to have a machine with infinite VRAM, but in the real world the memory issues likely become a problem before the perplexity gains become interesting - especially when attention_sinks is a very practical alternative.

  • Tom Aarsen

@Glaciohound
Author

Glaciohound commented Oct 7, 2023

Hi Tom Aarsen!

Thank you so much for taking the time for a detailed evaluation!

I see the point in implementing this as a plug-in, separately, for long-term maintenance. I am happy to help in that direction as well (e.g., your efforts in attention_sinks), especially to combine the advantages of both implementations (LM-Infinite and attention_sinks).

To be more specific:

  • When encoding (e.g., reading a document as context, or calculating perplexity), if I understand correctly, attention_sinks currently feeds tokens in one by one even when the whole sequence is already available, trading time for space efficiency. This is not the usual approach in most other Transformers implementations. LM-Infinite additionally supports a sequence mode that encodes the whole sequence in one pass, which certainly uses more memory per operation but makes encoding more time efficient. (Regarding hardware: we used an A100, so we could encode up to 128k tokens at once.) It can of course also support token-by-token feeding if we evaluate that way; see the sketch after these bullets. In summary, if we offer both as options and decide smartly when to use which, we can potentially achieve a better balance between time and space for users with various resources and needs.

  • When decoding or generating, in my understanding, the two approaches should theoretically perform identically. We can work together to debug and optimize Figure 2, as I am also trying to understand why the curve is not smooth but rather discrete (step-like). One thing that particularly puzzles me: Llama-2 has a native context window of 4k, so transformers and attention_sinks should behave the same before 4k tokens (provided attention_sinks does not manually reduce the window further, e.g. down to 2k), because both need to attend to all tokens while the sequence is shorter than 4k.
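To make the encoding trade-off in the first bullet concrete, here is a schematic comparison of the two modes (`model` is any causal LM with a KV cache; this is an illustration rather than code from either implementation):

```python
import torch

@torch.no_grad()
def encode_one_shot(model, input_ids):
    # Sequence mode: a single forward pass over the whole prompt.
    # Time efficient, but peak memory grows with the prompt length.
    return model(input_ids, use_cache=True).past_key_values

@torch.no_grad()
def encode_token_by_token(model, input_ids):
    # Incremental mode: feed one token at a time, reusing the KV cache.
    # Per-step compute and activation memory stay constant (and with a bounded
    # cache, as in attention_sinks, total memory is bounded too), at the cost
    # of n sequential forward passes.
    past = None
    for i in range(input_ids.shape[1]):
        out = model(input_ids[:, i : i + 1], past_key_values=past, use_cache=True)
        past = out.past_key_values
    return past
```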

Again, whatever the outcome and final decisions, I see this as a great chance to combine and benefit from both implementations, LM-Infinite and attention_sinks. I am definitely looking forward to working together to finally get this approach integrated into Transformers. If you have any opinions or comments, please feel free to let us know!

Chi Han

@tomaarsen
Member

tomaarsen commented Oct 7, 2023

As a quick comment regarding the decoding section: my experiments with window attention and attention_sinks use a window size of 1024, which explains the difference between window attention & attention_sinks on the one hand and plain transformers on the other.
Also, the VRAM usage curve is likely discrete because it measures the allocated memory, which is a somewhat imprecise measurement of the real usage.

Glaciohound marked this pull request as draft on October 8, 2023, 08:10

github-actions bot commented Nov 6, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this Nov 14, 2023