How to generate the n-grams - which to keep, which to discard? #14
Hi, thanks for your interest! Sorry for not presenting it clearly in our blog and code; we will refactor the code for better readability. We maintain a lookahead branch (a 2D window) of size (W, N - 1). An example of the lookahead branch is shown in the figure below (N = 4 and W = 5). There are two points:
At each step, we generate one token from each window slot (the yellow ones in the figure below). Then we do two things: collect the n-grams formed in the window, and update the lookahead branch for the next step.
About your question: the collected n-grams are stored in a hash map, with the starting token as the key and a set of all following (n-1)-grams as the value. Please feel free to ask if you have other questions.
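As a concrete sketch of that data structure (using hypothetical names such as `ngram_pool` and `add_ngram`, not the repository's actual variables):

```python
# Minimal sketch of the n-gram pool described above: key = starting token,
# value = set of the following (n-1)-grams. Names and types are assumptions.
from collections import defaultdict

ngram_pool: dict[int, set[tuple[int, ...]]] = defaultdict(set)

def add_ngram(ngram: list[int]) -> None:
    """Store an n-gram keyed by its starting token."""
    ngram_pool[ngram[0]].add(tuple(ngram[1:]))

def candidates_for(token: int) -> set[tuple[int, ...]]:
    """Return all (n-1)-gram continuations collected for this starting token."""
    return ngram_pool.get(token, set())

# Example: after collecting the 4-gram [7, 12, 5, 9], starting token 7
# yields the candidate continuation (12, 5, 9) for verification.
add_ngram([7, 12, 5, 9])
assert (12, 5, 9) in candidates_for(7)
```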
For a given starting token and a pool of …
@ggerganov Hi, thanks for your interest. Sorry for my confusing wording. In my implementation, it is OK to set …
@Viol2000 I remain confused about the outputs of the lookahead branch. You say
What happens to the output tokens that correspond to the input tokens green 1-5 and orange 1-4? If they are not used in the next iteration of the lookahead algorithm, what are they used for? Are they used as speculation drafts alongside the cache, or are they just not computed?
@tdene Thanks for your interest! You raised some excellent points! These tokens' outputs are discarded. However, I do not think computing them is a waste: although their outputs are discarded, the tokens themselves are used to build yellow tokens with stronger local information. For example, yellow 7 attends to orange 1-4, green 5, and red 6 and then gives its output. Another way to do this is to put these tokens in the kv-cache, so we can remove the rows for orange 1-4 and green 1-5 and save FLOPs; it seems llama.cpp's implementation of lookahead decoding uses this method (correct me if I am wrong). However, if you put them in the kv-cache, they miss the information from newly arrived tokens. For example, if you put orange 1 in the kv-cache, it would have been built two steps earlier and remained unchanged since then. At that time, blue 0 had not been generated yet, so the kv-cache entry of orange 1 carries information from an incorrect token rather than from blue 0. I think this reduces the local information and lowers the acceptance rate of your guess tokens.
That is not the case in llama.cpp. In the case from the figure, we submit a full batch containing:
None of the lookahead tokens are stored in the KV cache, so they will "see" the blue token from the current iteration.
Hi @ggerganov, thank you for the clarification! I noticed the following graph, and it appears not to include several rows; my initial assumption was that there might be an innovation in kv-caching past tokens. Nevertheless, I appreciate your substantial effort in implementing Lookahead Decoding in llama.cpp.
This diagram represents the input batch containing 21 tokens. Later in the code, based on the sequence ids of the tokens in the batch, we construct the attention mask, which is 21 x 21.
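As a rough Python illustration (not the llama.cpp implementation) of deriving such a mask from per-token positions and sequence ids, assuming the rule that a query may attend to a key only if the two tokens share a sequence id and the key is not at a later position:

```python
# Illustration only: build a boolean attention mask (True = may attend) from
# per-token positions and sequence-id sets, under the rule stated above.
import torch

def build_mask(positions: list[int], seq_ids: list[set[int]]) -> torch.Tensor:
    n = len(positions)
    mask = torch.zeros(n, n, dtype=torch.bool)
    for q in range(n):       # query token (mask row)
        for k in range(n):   # key token (mask column)
            shares_sequence = bool(seq_ids[q] & seq_ids[k])
            not_in_future = positions[k] <= positions[q]
            mask[q, k] = shares_sequence and not_in_future
    return mask

# Tiny example: token 0 is a shared prompt token (visible to both sequences),
# token 2 is a draft token on its own sequence, so token 3 cannot attend to it.
mask = build_mask(positions=[0, 1, 2, 2], seq_ids=[{0, 1}, {0}, {1}, {0}])
print(mask)
```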
So the output of the red 5 line in the diagram (which attends to blue 0, orange 1, orange 2, orange 3, green 4, red 5), which you name "yellow 6", becomes the new red 5 in the next time-step's calculation? And the output of the green 3 line in the diagram (which attends to blue 0, orange 1, orange 2, green 3), which you name "red 4", becomes the new green 3 in the next time-step's calculation? And the output of the orange 2 line in the diagram (which attends to blue 0, orange 1, orange 2), which you name "green 3", becomes the next orange 2 in the next time-step's calculation? I was confused because the first quote made it sound like you use the same "green tokens with number 2-5, all red tokens and all yellow tokens" shown in the diagrams as inputs. But since "these tokens themselves are used to build yellow tokens with stronger local information", I now understand that you don't use green tokens 2-5 from the diagram; you use the outputs of their lines?
Looking into the code, I think:
Yes.
The red in the diagram is generated by green, not in the current decode iteration (when green and red already exist), but in the previous iteration, when red did not yet exist.
Hi @Viol2000 -- regarding this comment: if the green and orange rows of output are not used, why not just trim them out of the attention mask? This won't have any effect on the output yellow 3, 4, 5, 6, 7 tokens, and it also won't affect the collected 4-grams. Are the green and orange outputs somehow used to refine past states (…)?
I checked the related code: LookaheadDecoding/lade/models/llama.py Lines 445 to 451 in b756db3
LookaheadDecoding/lade/decoding.py Lines 274 to 275 in b756db3
In the higher-level call, the full …
@learning-chip Sorry for the late reply; I did not see your message. Even though the green and orange rows of output are not used, the green and orange tokens are useful for generating the yellow 3, 4, 5, 6, 7 tokens. In the current case, if the current step is t, the red tokens come from step t-1, the green tokens come from step t-2, and the orange tokens come from step t-3. If you do not input these tokens, you are actually using the kv-cache of the green and orange tokens, and the information from the current input (deep blue 0) cannot reach them (they sit in the kv-cache and are static), so they are out of date. In short, if you do not input the green and orange rows, these tokens cannot 'see' deep blue 0; only the red tokens can. If you do input them, then in the shallow layers these tokens can pick up the information from deep blue 0 and thus affect the output of the red tokens in the deep layers. I have some initial implementations showing that without the green and orange rows as inputs, the generated n-grams have lower quality and are accepted less often. I believe there are more trade-offs to explore: lower computational cost with a lower acceptance ratio, or higher computational cost with a higher acceptance ratio.
My naming is bad here; I will keep refactoring the code in the coming weeks.
Hi @Viol2000, thanks for the explanation. I totally understand this part. By "trimming the mask" I mean shortening the y-axis (the output, or q, dimension), not the x-axis (the input, or kv, dimension). The output red tokens can still see the full inputs.
@learning-chip I think every input token corresponds to one row (query), and each column corresponds to one input token or a kv-cache entry (key). Is that correct? So the green and orange tokens must either be inputs or sit in the kv-cache. If they are in the kv-cache, the situation I explained above will happen. If they are inputs, we need to have them as rows. Is my understanding correct?
This is only true if you restrict yourself to the stock attention implementation. If you directly use FlashAttention or PyTorch's scaled_dot_product_attention, the query does not have to be the same length as the keys and values.
I know what you mean, but I do not think it is an implementation problem here. The problem is that if you do not input the orange and green tokens, they will not be included in hidden_states; the orange and green tokens would then need to be stored in the kv-cache.
Now, …
So you mean we still input all orange and green tokens? |
If so, the orange and green tokens need to go through the MLP but do not participate in the attention computation. Is my understanding correct?
Yes, so K and V still have the complete input information; just make Q shorter. Rewriting the …
But I doubt whether they will keep the input information if they are not fed to attention as queries. As far as I understand, attention's outputs correspond to the queries. If we do not include them as queries but only as keys, they will somehow 'skip' the attention and not carry the information we want. It is an inspiring idea, but it seems strange?
Right, the MLP is applied independently to each token (no cross-token information flow) and will stay as it is. This is just to save FLOPs in the attention computation.
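To make the "shorter Q, full K/V" idea concrete, here is a minimal sketch using PyTorch's `scaled_dot_product_attention`, which accepts a query that is shorter than the keys/values; the shapes and the token split are assumptions for illustration, not the repository's code:

```python
# Sketch: keep all 21 tokens as keys/values but only compute outputs for the
# rows that are actually used (assumed to be 11 here). The mask is rectangular.
import torch
import torch.nn.functional as F

batch, heads, head_dim = 1, 8, 64
kv_len = 21   # full input batch (blue + orange + green + red + guess tokens)
q_len = 11    # only the rows whose outputs are used downstream (assumption)

q = torch.randn(batch, heads, q_len, head_dim)
k = torch.randn(batch, heads, kv_len, head_dim)
v = torch.randn(batch, heads, kv_len, head_dim)

# Rectangular mask (q_len x kv_len): each kept query row can still attend over
# all kv_len columns, so K and V retain the complete input information.
attn_mask = torch.ones(batch, 1, q_len, kv_len, dtype=torch.bool)  # fill in the real lookahead pattern here

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.shape)  # torch.Size([1, 8, 11, 64]): outputs only for the shortened query rows
```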
I got your point and did some more experiments. Shortening Q for the last decoding layer is perfectly fine: it has no impact on the algorithm's correctness and just saves FLOPs. Shortening Q for other decoding layers, however, affects the inputs to their next layers, which in the end affects the lookahead guesses and thus the speed-up ratio. Testing with different numbers of shortened layers:
So the speed-up roughly (but not strictly) drops monotonically as more layers have a shortened query/output dimension.

To try the above test, modify the self-attention section of the decoder layer forward as follows:

```python
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
    ...
)  # unchanged
used_tokens = len(past_tokens[fill_level]) + lguess  # obtained from jforward_multilevel()
hidden_states[:, 1:-used_tokens, :] = 0.0  # as if those output positions were not computed
hidden_states = residual + hidden_states  # just pass the inputs through for those positions

# Fully Connected
# (unchanged)
```

And only branch to this modified forward: LookaheadDecoding/lade/models/llama.py Lines 234 to 236 in b756db3
This modification does not actually save FLOPs; it is just to see what the guess accuracy looks like when those output positions are effectively not computed.

Further, after obtaining the combined attention mask from LookaheadDecoding/lade/models/llama.py Lines 201 to 203 in b756db3,
you can also set the unused mask rows to an arbitrary value:

```python
used_tokens = len(past_tokens[fill_level]) + lguess  # obtained from jforward_multilevel()
maskout_value = torch.finfo(attention_mask.dtype).min  # any value, doesn't matter
attention_mask[:, :, 1:-used_tokens, :] = maskout_value
```

With the previous modification to …
@learning-chip Thanks for your valuable discussion and experiments! I think the mask trimming does show potential for saving FLOPs, and when we only prune a few layers the step compression ratio does not drop. But I doubt that reducing these FLOPs will turn into real speedups, since we cannot prune too many attention layers and the attention computation accounts for only part of the FLOPs in the entire model. As the compression ratio does not perfectly reflect the overall latency, there may be a trade-off between the ratio and the FLOPs. I guess it is worth more exploration.
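As a back-of-the-envelope illustration of that point (all numbers below are assumptions, not measurements): if attention accounts for roughly 30% of the model's FLOPs, Q is shortened in 25% of the layers, and the shortening removes about half of those layers' attention FLOPs, the total saving is only a few percent.

```python
# Toy arithmetic with assumed numbers; replace with measured values for a real estimate.
attention_fraction = 0.30       # share of total model FLOPs spent in attention (assumption)
layers_pruned_fraction = 0.25   # share of layers where Q is shortened (assumption)
attention_flops_saved = 0.50    # share of attention FLOPs saved in those layers (assumption)

total_saving = attention_fraction * layers_pruned_fraction * attention_flops_saved
print(f"Total FLOPs saved: {total_saving:.1%}")  # ~3.8%
```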
This is actually not my own question: llama.cpp wants to implement this but encountered some problems.
ggerganov/llama.cpp#4207