
Confused about the four attention mechanisms and their performance mentioned in the paper #33

Closed
weizhenhuan opened this issue Oct 10, 2023 · 7 comments

Comments

@weizhenhuan

Nice idea, and it really works well! Thanks for your nice work. But I have some questions. The paper mentions four attention mechanisms: dense attention fails because the sequence length no longer matches the training-phase length once the output grows longer than the training length, and window attention fails because it evicts the initial tokens' KV cache. But for sliding window with re-computation and streaming attention, I have some questions.

  • Sliding window with re-computation just recomputes the KV states from the L most recent tokens. Theoretically, it should have the same PPL as window attention, because both use the same tokens' KV, as is clear in the picture; it just saves KV cache space. However, sliding window with re-computation shows better PPL in the picture. Do I misunderstand sliding window with re-computation?
  • Streaming attention's idea is to use the initial tokens' KV together with the L most recent tokens' KV, and the reason it keeps the initial tokens' KV is clear in the paper. But dense attention also uses those tokens' KV, so it should not have the softmax-shift problem either. Dense attention even uses more tokens, so it should have better PPL thanks to the longer context, while streaming attention should have higher inference speed and a longer output length because it uses fewer tokens. Yet in the picture, streaming attention has better PPL than dense attention.

Thanks for your nice work again. Hope to get a reply.
[Figure: comparison of the four attention mechanisms and their perplexity, from the paper]

@Guangxuan-Xiao
Collaborator

Guangxuan-Xiao commented Oct 11, 2023

Thank you for your interest in our paper; I appreciate your insightful questions. Here are my clarifications:

  1. Sliding Window with Re-computation vs. Window Attention:

    • Assume we have a token sequence [a, b, c, d, e, f, g], and the model's window size is 4. For predicting token 'g', the sliding window with re-computation truncates the text sequence to [d, e, f], treating it as a whole sequence before inputting it into the language model, predicting only token 'g'. Here, token 'd' is at position 0, 'e' at 1 (seeing only 'd'), and 'f' at 2 (seeing 'd' and 'e').
    • In contrast, window attention reuses the previously computed KV states. So, while predicting 'g', the reused tokens [d, e, f]'s KV come from prior computations: 'd' was computed at pos 3 (seeing a, b, c), 'e' at pos 3 (seeing b, c, d), and 'f' at pos 3 (seeing c, d, e), each as the last token of its own length-4 window. The critical distinction is that in sliding window with re-computation, some key states are treated as initial tokens, whereas in window attention, all previous tokens' KV are computed as if they were middle tokens (see the short sketch after this list).
  2. StreamingLLM vs. Dense Attention:
    The superior performance of StreamingLLM over dense attention is attributed to the fact that dense attention struggles to generalize to sequences exceeding its training length. In the figure, we showed the language modeling perplexity on a book containing 65K tokens. The perplexity of dense attention becomes problematic because the Llama-2-13B model we used was pre-trained on a chunk size of 4096, causing its perplexity to deteriorate for sequences surpassing 4K. For a deeper dive, you might find length extrapolation works, such as the ALiBi paper (https://arxiv.org/abs/2108.12409), insightful.
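A toy sketch (plain Python; my own illustration, not code from this repository) of the position and visible context each reused key gets under the two schemes in the [a, b, c, d, e, f, g] example:

```python
# Toy illustration of the keys reused to predict 'g' with window size 4.
tokens = list("abcdefg")
L = 4                    # window size
reused = tokens[3:6]     # ['d', 'e', 'f']

# Sliding window with re-computation: [d, e, f] is re-encoded from scratch,
# so 'd' becomes an initial token at position 0 with nothing before it.
for pos, tok in enumerate(reused):
    print(f"re-compute : key '{tok}' at pos {pos}, sees {reused[:pos]}")

# Window attention: each key is reused from when its token was generated,
# i.e. as the last token of its own length-4 window (a "middle" token).
for tok in reused:
    i = tokens.index(tok)
    window = tokens[max(0, i - L + 1): i + 1]
    print(f"window attn: key '{tok}' at pos {len(window) - 1}, sees {window[:-1]}")
```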

I hope this addresses your confusion.

Thanks,
Guangxuan

@Guangxuan-Xiao
Collaborator

Guangxuan-Xiao commented Oct 12, 2023

I'd like to clarify that the issue isn't related to absolute or relative positional encoding. Our current results were obtained with models that use relative positional encodings, such as Llama and MPT. The core of the matter lies in whether the context keys have been computed with or without previous tokens. Attention sinks refer to keys computed without prior tokens. Hence, such keys are present in the sliding-window-with-re-computation baseline. However, in the window-attention baseline, all context keys are computed from numerous preceding tokens, and the model is trained to recognize that these aren't attention sinks.
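For concreteness, a minimal sketch of the cache policy this implies (assumed [batch, heads, seq, head_dim] layout and a hypothetical helper name, not the repository's actual kv_cache implementation): keep the first few keys, which were computed with no preceding tokens and therefore act as attention sinks, plus the most recent window of keys.

```python
import torch

def evict_kv(k, v, n_sink=4, recent=1020):
    # k, v: [batch, n_heads, seq_len, head_dim]
    seq_len = k.size(2)
    if seq_len <= n_sink + recent:
        return k, v
    # Keep the sink keys (computed without prior tokens) and the recent window.
    keep = lambda x: torch.cat([x[:, :, :n_sink], x[:, :, seq_len - recent:]], dim=2)
    return keep(k), keep(v)
```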

@weizhenhuan
Author

Got it! Thanks for your kind reply.

@BitCalSaul

@weizhenhuan Hey, zhenhuan, I am also interested in Figure 1 and spent some time trying to figure it out. I put my thoughts in #42, and I'd appreciate it if you could correct me if I've made any mistakes. Thank you!

@yongshenglian

Hi Guangxuan,

Could you give some insights on the attention sink for the 1st token?

If we have 3 existing tokens (a, b, c), the probability of the 4th token d is p(d | a, b, c), with the joint factorizing as p(a, b, c, d) = p(a) p(b|a) p(c|a, b) p(d|a, b, c). Token d will therefore depend on the first token. However, I do not see why the attention score of a should be larger than the rest, as shown in your paper.

I did a quick test with the GPT-2 decoder, and the attention score of the first token was not the highest either.

The question is: will the attention score of the 1st token always be the highest among all?

If not, why would removing the 1st token be a problem?
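In case it helps reproduce that check, one way to inspect this with Hugging Face Transformers (the prompt and the layer choice are my own; the paper reports the sink pattern mainly beyond the bottom two layers, so an early layer may well not show it):

```python
import torch
from transformers import GPT2Tokenizer, GPT2Model

tok = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2Model.from_pretrained("gpt2")
inputs = tok("The quick brown fox jumps over the lazy dog.", return_tensors="pt")

with torch.no_grad():
    out = model(**inputs, output_attentions=True)

# out.attentions is a tuple with one [batch, heads, seq, seq] tensor per layer.
attn = out.attentions[6].mean(dim=1)[0]  # layer 7, averaged over heads
last_row = attn[-1]                      # how the final token spreads its attention
print(last_row)
print("mass on the first token:", last_row[0].item())
```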

@Kylin9511

> (quoting @Guangxuan-Xiao's clarification above on sliding window with re-computation vs. window attention, and StreamingLLM vs. dense attention)

This is very interesting. In fact, I suppose the main difference between the "initial tokens" and the "middle tokens" is the positional embedding. For RoPE-style positional embeddings, an attenuation effect acts on the attention scores, and the information almost vanishes when the context is super long and the positional index is very large.

From this perspective, maybe state-of-the-art positional embedding schemes like RoPE-NTK, YaRN, etc. would bridge the gap between SWA and SWA-with-re-computation to a certain extent.
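To make the RoPE point concrete, a quick check (my own illustration using the half-split rotation as in Llama, not code from the paper) that the rotated query-key dot product depends only on the relative offset, so a key looks "early" or "late" only relative to where the current query sits:

```python
import torch

def rope(x, pos, base=10000.0):
    half = x.size(-1) // 2
    freqs = base ** (-torch.arange(half, dtype=torch.float32) / half)
    angles = pos * freqs
    cos, sin = torch.cos(angles), torch.sin(angles)
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

torch.manual_seed(0)
q, k = torch.randn(64), torch.randn(64)
# Same relative offset (8), very different absolute positions: identical scores.
print(torch.dot(rope(q, 10), rope(k, 2)))
print(torch.dot(rope(q, 1000), rope(k, 992)))
```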

@SimonSongg

> (quoting @Kylin9511's comment above, including @Guangxuan-Xiao's clarification)

I have a similar feeling: how does the model tell which token is the initial token? The positional embedding is a very likely answer. So I am wondering whether they tried re-assigning the first token in the sliding window new positional information to make it an "initial token".
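For what it's worth, the paper describes exactly this kind of re-assignment: StreamingLLM adds positional information based on a token's position within the cache rather than its position in the original text. A toy contrast (hypothetical helper, not the repository's API):

```python
import torch

def kept_positions(text_len, n_sink=4, recent=8):
    # Positions of the kept tokens in the original text vs. inside the cache.
    in_text = torch.tensor(list(range(n_sink)) + list(range(text_len - recent, text_len)))
    in_cache = torch.arange(n_sink + recent)
    return in_text, in_cache

print(kept_positions(1000))
# in_text : [0, 1, 2, 3, 992, ..., 999]   (original indices)
# in_cache: [0, 1, 2, ..., 11]            (what the model actually sees)
```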
