
Mistral & Sliding Window Attention - GGUF ROPE accommodation #3867

Closed
SabinStargem opened this issue Oct 31, 2023 · 15 comments
Labels: enhancement (New feature or request), stale

Comments

@SabinStargem

I have been trying to use Mistral with extended context. SWA supposedly allows for up to 32k context, but in practice I get garbage. However, someone on Reddit mentioned that using a ROPE of 45,000 makes 24k coherent. So I compared 24k with the default ROPE in KoboldCPP, then with the custom ROPE. The latter worked; the former was gibberish.

My guess is that the current GGUF wasn't built with SWA in mind.

SabinStargem added the enhancement (New feature or request) label Oct 31, 2023
@shibe2
Contributor

shibe2 commented Oct 31, 2023

Does the PyTorch model work with long prompts? How does its output compare to the outputs of llama.cpp and KoboldCPP? I tried to test it, but ran out of system RAM.

@staviq
Contributor

staviq commented Oct 31, 2023

I've been wondering the same thing recently.

Mistral models get an n_ctx_train of 32k during conversion, but they are effectively 4k-context models (a 4k window over a 32k total context).

So processing them as if they were trained with a 32k context makes no sense, because from the point of view of the current implementation the model should be treated as having a 4k trained context.

Any thoughts?

@SabinStargem
Author

I don't know anything about PyTorch, being exclusively a KoboldCPP user. Anyhow, the person who suggested the usable ROPE for 24k thinks the model configuration might be at fault, but that will be for smarter minds to figure out. Here is their speculation.

mll_59
Thanks for your reaction. In this case I think it's not a bug in llama.cpp but in the parameters of the Mistral models. The original Mistral models have been trained on an 8K context size; see Product | Mistral AI | Open source models.

But when I load a Mistral model, or a finetune of a Mistral model, koboldcpp always reports a trained context size of 32768, like this:

llm_load_print_meta: n_ctx_train = 32768

So llama.cpp (or koboldcpp) just assumes that up to a 32768 context size no NTK scaling is needed and leaves the rope freq base at 10000, which I think is correct. I don't know why the model has its n_ctx_train parameter at 32768 instead of 8192; maybe a mistake?

@KerfuffleV2
Collaborator

Maybe look at #2268 (YaRN/NTK-type scaling). It may work better than flat scaling.

@Dampfinchen

I don't think llama.cpp supports Sliding Window Attention yet... :(

@ggerganov
Owner

It actually might already support it - see the discussion in #3581.
The main problem is that I am not 100% certain how SWA works (see my comments there), so we haven't added it as an option to the examples.

@studiotatsu

studiotatsu commented Nov 1, 2023

Using Mistral models, I set --rope-freq-base and I no longer get garbage output after 8k context.
From my tests:

  • --rope-freq-base 10000 (default): 8k context
  • --rope-freq-base 20000 (works, lower perplexity): 16k context
  • --rope-freq-base 40000 (works, lower perplexity): 32k context

I may be completely using it wrong, but I'm still learning all of the ins and outs of this repo.
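For reference, here is a minimal sketch of setting the same value through the llama.cpp C API instead of the CLI flag. This is a hedged example: the API names reflect llama.cpp around the time of this thread and may have changed since, and the model path and values are illustrative only.

```cpp
#include "llama.h"

int main() {
    llama_backend_init(false);

    // Load a Mistral GGUF model (the path is just an example).
    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_load_model_from_file("mistral-7b-v0.1.Q4_K_M.gguf", mparams);

    // Override the RoPE frequency base, mirroring --rope-freq-base 40000 for 32k context.
    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx          = 32768;
    cparams.rope_freq_base = 40000.0f; // default is 10000.0f

    llama_context * ctx = llama_new_context_with_model(model, cparams);

    // ... tokenize the prompt and call llama_decode here ...

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}
```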

@foldl
Contributor

foldl commented Jan 22, 2024

FYI: here is an implementation of SWA using a full cache. Tests show that the output looks OK after 9k+ tokens are generated.

@ggerganov
Owner

@foldl Can you explain your approach? Let's say you have a KV cache of 8192 tokens and it is now full, and you want to process the 8193rd token at position 8192 - what do you do?

@foldl
Contributor

foldl commented Jan 22, 2024

@ggerganov

Some terminology: let's call the items in the KV cache k_vectors and v_vectors.
The length of the sliding window is w_len.

It is simple from an implementation perspective. The required modifications,
besides extending the cache length to 32k:

  • Before: all k_vectors and v_vectors are used.

  • After: only the last (up to) w_len k_vectors and v_vectors in the cache are used.

Key point: Tokens are processed one by one. (Batch processing is also possible,
which will be discussed later.)

To explain it, here is a secret story that I'd like to share.


A: What's the meaning of the n-th k_vector/v_vector in the cache of each layer?

B: Easy. It's an understanding of the n-th input token.

A: No. It's the output from the previous layer, which has attended to all previous
tokens; therefore, it carries some information about all previous input tokens.

B: Okay. Then what?

A: That is a waste of memory and computation. For each layer, when a new input vector
arrives and the corresponding q_vector is computed, we only need the last w_len k_vectors to
form the K matrix and multiply it with q_vector.

B: Please elaborate.

A: Now, let's think layer by layer.

  • Layer 0: each k_vector/v_vector in the cache corresponds to one input token.

    The n-th output of this layer carries information about the (n - w_len, n]-th vectors
    in the cache, or rather the (n - w_len, n]-th input tokens.

  • Layer 1: As shown above, the n-th k_vector/v_vector in the cache corresponds
    to the (n - w_len, n]-th input tokens.

    The n-th output of this layer carries information about the (n - w_len, n]-th vectors
    in the cache, or rather the (n - 2 * w_len + 1, n]-th input tokens.

    Note that the total attention window is now 2 * w_len - 1 (~ 2 * w_len).

and so on.

B: Oh, I get it. Because only w_len vectors need to be saved in the cache, we
can use a ring buffer.

A: Stop. "Ring buffer" is a well-known term. We need something new: Rolling Cache.
Rolling in the cache.

B: Nice.


Obviously, if the cache stores only w_len k/v vectors, tokens must be processed one by one.
If a larger cache is used, batch processing is also possible; just keep in mind that
the n-th input attends to the last (up to) w_len k/v vectors.
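A minimal standalone sketch of the rolling-cache bookkeeping described above (my own illustration, not code from llama.cpp or the linked implementation; names such as RollingKV are made up): the k/v vectors for position n land in slot n % w_len, and the query at position n only attends to positions (n - w_len, n].

```cpp
#include <algorithm>
#include <utility>
#include <vector>

// One layer's rolling KV cache: only the last w_len k/v vectors are kept.
struct RollingKV {
    int w_len;
    std::vector<std::vector<float>> k, v; // w_len slots of d_head floats each

    RollingKV(int w_len, int d_head)
        : w_len(w_len),
          k(w_len, std::vector<float>(d_head)),
          v(w_len, std::vector<float>(d_head)) {}

    // Store the k/v vectors for the token at absolute position n,
    // overwriting the slot of position n - w_len, which is no longer needed.
    void put(int n, std::vector<float> kn, std::vector<float> vn) {
        k[n % w_len] = std::move(kn);
        v[n % w_len] = std::move(vn);
    }

    // Positions the query at position n attends to: (n - w_len, n], clamped at 0.
    // All of their k/v vectors are still resident in the ring.
    std::pair<int, int> window(int n) const {
        return { std::max(0, n - w_len + 1), n };
    }
};
```

This matches the note above: with exactly w_len slots per layer, tokens must be decoded one at a time, while a larger ring leaves room for batching.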

@ggerganov
Owner

Thanks for the explanation. I guess this corresponds to calling:

llama_kv_cache_seq_rm(ctx, 0, n_past % w_len, n_past % w_len);

before each decode in order to free a KV slot.

But I still can't say I have an intuition of why this works.
This part is not obvious:

The n-th output of this layer carries information about the (n - w_len, n]-th vectors
in the cache, or rather the (n - 2 * w_len + 1, n]-th input tokens.

@foldl
Contributor

foldl commented Jan 22, 2024

For Layer 1: as shown above, the n-th k_vector/v_vector in the cache corresponds
to the (n - w_len, n]-th input tokens.

The n-th output of this layer carries information about (n - w_len, n]-th vectors
in the cache, while:

  • the (n - w_len + 1)-th k/v vector in the cache, coming from
    the (n - w_len + 1)-th output of the previous layer, carries information about
    the (n - w_len + 1 - w_len, n - w_len + 1]-th input tokens, as shown above;

  • the (n - w_len + 2)-th k/v vector in the cache, coming from
    the (n - w_len + 2)-th output of the previous layer, carries information about
    the (n - w_len + 2 - w_len, n - w_len + 2]-th input tokens;

  • ...

  • the n-th k/v vector in the cache, coming from
    the n-th output of the previous layer, carries information about
    the (n - w_len, n]-th input tokens;

Therefore, the n-th output of this layer carries information about the
(n - 2 * w_len + 1, n]-th input tokens, i.e. the total attention window length
is 2 * w_len - 1 (~ 2 * w_len).
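To state the pattern compactly (my own summary of the derivation above, not part of the original comment): with a per-layer window of $w_{\text{len}}$ positions, after $\ell$ layers the $n$-th output carries information about the input tokens in

$$\left(\, n - \ell\,(w_{\text{len}} - 1) - 1,\; n \,\right],$$

i.e. the effective receptive field grows by $w_{\text{len}} - 1$ per layer, reaching $\ell\,(w_{\text{len}} - 1) + 1 \approx \ell \cdot w_{\text{len}}$ tokens. For $\ell = 2$ this gives the $2 \cdot w_{\text{len}} - 1$ figure above.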

@foldl
Contributor

foldl commented Jan 23, 2024

@ggerganov

llama_kv_cache_seq_rm(ctx, 0, n_past % w_len, n_past % w_len);

I am not quite sure about the internal logic of llama_kv_cache. Intuitively, I think it should be like "deleting entries older than w_len":

llama_kv_cache_seq_rm(ctx, 0, -1, n_past - w_len);

And then, when calling llama_eval, n_tokens must be 1, which is essential.
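Putting the corrected call together with the one-token-at-a-time requirement, here is a hedged sketch of what a decode loop could look like. It uses the batch-based llama_decode API rather than llama_eval, and it is my own illustration of the suggestion above, not code from a llama.cpp example; API names reflect llama.cpp around this time.

```cpp
#include "llama.h"

// Decode n_steps tokens one by one with a sliding window of w_len positions,
// dropping KV entries that fall out of the window before each step.
static void decode_with_swa(llama_context * ctx, llama_token token, int n_steps, int w_len) {
    for (int n_past = 0; n_past < n_steps; ++n_past) {
        // Remove entries with pos in [0, n_past - w_len), i.e. older than the window
        // (p0 = -1 means "from the beginning of the sequence").
        if (n_past >= w_len) {
            llama_kv_cache_seq_rm(ctx, 0, -1, n_past - w_len);
        }

        // n_tokens must be 1: the cache only ever needs the last w_len positions.
        llama_batch batch = llama_batch_get_one(&token, 1, n_past, 0);
        if (llama_decode(ctx, batch) != 0) {
            return; // decode failed
        }

        // ... sample the next `token` from the logits here ...
    }
}
```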

@ggerganov
Owner

llama_kv_cache_seq_rm(ctx, 0, -1, n_past - w_len);

Correct! I made a mistake.

github-actions bot added the stale label Mar 19, 2024

github-actions bot commented Apr 2, 2024

This issue was closed because it has been inactive for 14 days since being marked as stale.

github-actions bot closed this as completed Apr 2, 2024