
server: update cache_prompt documentation [no ci] #7745

Merged

Conversation

JohannesGaessler
Collaborator

Fixes #7594 .

Generally speaking, it is not possible to guarantee bit-for-bit identical results when the batch size is varied. As a consequence, cache_prompt, which caches not only the actual prompt (batch size >> 1) but also the generated tokens (batch size 1), cannot guarantee bit-for-bit identical results. This PR simply updates the documentation to inform users of this.
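
For context on why the batch size matters at all: floating-point addition is not associative, so grouping the same values differently (which is effectively what a different batch size does to the internal reductions) can change the low-order bits of the result. A minimal standalone illustration, unrelated to the llama.cpp code itself:

#include <cstdio>

int main() {
    // the same three values summed with different grouping
    float a = 1e8f, b = -1e8f, c = 1e-3f;
    float left  = (a + b) + c; // 0.001
    float right = a + (b + c); // 0.0, because c is absorbed when added to -1e8 first
    printf("%.10f vs %.10f\n", left, right);
    return 0;
}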

@steampunque

I think there is a very subtle thing going on here which needs to be clarified. With no prompt cache, the prompt goes through in chunks of the batch size, with the final chunk potentially smaller than the batch size, and the first token can be generated simultaneously with that final prompt chunk. When the identical prompt is resubmitted, the same results are always achieved. When the prompt cache is turned on, the first generation is identical to the one without the prompt cache, but a following run with the identical prompt starts generating immediately with a batch size of 1 (rather than 1 plus the residual prompt chunk, as in the first generation). So it seems it should be possible to not start generation on the final partial prompt batch, so that token generation is always done with batch size 1 whether the prompt cache is on or off, which would make the results identical with or without the prompt cache.
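
A rough sketch of that idea (illustrative only, not the actual server code; the 1300-token prompt length is made up for the example):

#include <algorithm>
#include <cstdio>

int main() {
    const int n_prompt = 1300; // example prompt length (made up)
    const int n_batch  = 512;  // default physical batch size

    // process all prompt tokens except the last one in chunks of n_batch
    int n_past = 0;
    while (n_past < n_prompt - 1) {
        const int n_eval = std::min(n_batch, n_prompt - 1 - n_past);
        printf("prompt chunk: %d tokens\n", n_eval);
        n_past += n_eval;
    }

    // the held-back final prompt token is then decoded at batch size 1,
    // exactly like every generated token, so the generation path is the
    // same whether or not the prompt was served from the cache
    printf("final prompt token + all generated tokens: batch size 1\n");
    return 0;
}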

@steampunque

steampunque commented Jun 5, 2024

Here is an example illustrating the point I made in my last comment, using Mistral Instruct v0.3 with the simple prompt "hello how are you". The prompt and output themselves don't matter; what matters is the LOGPROB value, which is the sum of the logs of the completion token probabilities. If it is identical across runs, the results were bit-exact.
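
For clarity, LOGPROB is computed along these lines (the probability values below are placeholders; in practice they come from the server's per-token probability output):

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    // placeholder probabilities of the sampled completion tokens
    std::vector<double> token_probs = {0.91, 0.87, 0.99};

    double logprob = 0.0;
    for (double p : token_probs) {
        logprob += std::log(p);
    }
    // if two runs are bit-exact, this sum matches exactly
    printf("LOGPROB=%.16f\n", logprob);
    return 0;
}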

Start with no prompt cache and get the reference LOGPROB:

PROBS=1 CACHE=0 lm hello how are you?
Hello! I'm an AI, so I don't have feelings, but I'm here and ready to help you. How can I assist you today?
LOGPROB=-2.4457244873046875

Now turn on the cache. The result is expected to be identical, since the prompt has not been cached yet.

PROBS=1 CACHE=1 lm hello how are you?
Hello! I'm an AI, so I don't have feelings, but I'm here and ready to help you. How can I assist you today?
LOGPROB=-2.4457244873046875

As expected, the result was identical, indicating a bit-exact result. Now submit the same prompt again; this time the result is expected to be different, since not all token generations are done with batch size 1:

PROBS=1 CACHE=1 lm hello how are you?
Hello! I'm an AI, so I don't have feelings, but I'm here and ready to help you. How can I assist you today?
LOGPROB=-2.4836981296539307

As expected, the result changed. Now do another run with the cache on to see if it is consistent; it should be.

PROBS=1 CACHE=1 lm hello how are you?
Hello! I'm an AI, so I don't have feelings, but I'm here and ready to help you. How can I assist you today?
LOGPROB=-2.4836981296539307

As expected, the result was consistent. Now shut off the prompt cache; the result should immediately return to the original:

PROBS=1 CACHE=0 lm hello how are you?
Hello! I'm an AI, so I don't have feelings, but I'm here and ready to help you. How can I assist you today?
LOGPROB=-2.4457244873046875

As expected, the result returned to the original no-cache result.

To fix this problem I tested a patch that forces token generation to use batch size 1 whether the prompt cache is on or off. The unified diff is as follows:

diff -rbEu llama.cpp.old/examples/server/server.cpp llama.cpp/examples/server/server.cpp
--- llama.cpp.old/examples/server/server.cpp    2024-06-03 17:40:16.678938858 -0400
+++ llama.cpp/examples/server/server.cpp        2024-06-04 19:52:34.893251578 -0400
@@ -268,7 +268,9 @@
     std::string *stopping_word;

     // sampling
-    llama_token sampled;
+    llama_token sampled,start_token;
+    size_t start_pos;
+    bool have_start_token=false;
     struct llama_sampling_params sparams;
     llama_sampling_context ** ctx_sampling = nullptr;
     int ctx_sampling_beams;
@@ -1851,6 +1853,7 @@

         slot.command = SLOT_COMMAND_LOAD_PROMPT;
         slot.prompt_tokens.clear();
+	slot.have_start_token=false;

         LOG_INFO("slot is processing task", {
             {"id_slot", slot.id},
@@ -3430,11 +3433,24 @@

                     // entire prompt has been processed - start decoding new tokens
                     if (slot.n_past == slot.n_prompt_tokens) {
-                        slot.state   = SLOT_STATE_PROCESSING;
-                        slot.command = SLOT_COMMAND_NONE;
+                       if (slot.have_start_token) {
+                          llama_batch_add(batch, slot.start_token, slot.start_pos, { slot.id + 1 }, true);
+                          slot.have_start_token=false;
+                          }

                         GGML_ASSERT(batch.n_tokens > 0);

+                       if (batch.n_tokens > 1) {
+                          slot.start_token = batch.token[batch.n_tokens-1];
+                          slot.start_pos = batch.pos[batch.n_tokens-1];
+                          slot.have_start_token=true;
+                          batch.n_tokens--;
+                          }
+                       else {
+                          slot.state   = SLOT_STATE_PROCESSING;
+                          slot.command = SLOT_COMMAND_NONE;
+                          }
+
                         // extract the logits only for the last token
                         batch.logits[batch.n_tokens - 1] = true;

Running with this patch, the same prompt and cache on/off sequence produces the following:

PROBS=1 CACHE=0 lm hello how are you?
Hello! I'm an AI, so I don't have feelings, but I'm here and ready to help you. How can I assist you today?
LOGPROB=-2.5299837589263916

PROBS=1 CACHE=1 lm hello how are you?
Hello! I'm an AI, so I don't have feelings, but I'm here and ready to help you. How can I assist you today?
LOGPROB=-2.5299837589263916

PROBS=1 CACHE=1 lm hello how are you?
Hello! I'm an AI, so I don't have feelings, but I'm here and ready to help you. How can I assist you today?
LOGPROB=-2.5299837589263916

PROBS=1 CACHE=1 lm hello how are you?
Hello! I'm an AI, so I don't have feelings, but I'm here and ready to help you. How can I assist you today?
LOGPROB=-2.5299837589263916

PROBS=1 CACHE=0 lm hello how are you?
Hello! I'm an AI, so I don't have feelings, but I'm here and ready to help you. How can I assist you today?
LOGPROB=-2.5299837589263916

Now the result is identical regardless of whether CACHE is on or off, but still different from the first two cases, since the processing sequence used to fill the KV cache is not identical to either the CACHE-on or CACHE-off processing in the unpatched server. However, with the patch all token generations are guaranteed to be done with batch size 1 (as long as nslots is configured to be 1), so processing is the same whether the prompt cache is on or off.

@mofosyne added the "Review Complexity : Low" and "documentation" labels on Jun 5, 2024
@JohannesGaessler
Collaborator Author

JohannesGaessler commented Jun 5, 2024

My perspective is this: cache_prompt is an option for better performance, so by default the implementation should be done in such a way that performance is as high as possible. You can make the results bit-for-bit identical for single-turn conversations with the changes you described, but then you would only be able to cache the prompt in chunks of the physical batch size. The physical batch size is 512 by default, so any prompt shorter than that could not be cached at all, and for longer prompts you can estimate that you would still need to reprocess ~256 tokens on average (but with an awkward batch size and a filled context, so these would take disproportionately longer).
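
To make that arithmetic concrete, this is the kind of calculation behind those numbers (illustrative only, not server code; the prompt lengths are made up):

#include <cstdio>

int main() {
    const int n_batch = 512; // default physical batch size

    // if only whole batches could be cached, the reusable prefix would be
    // floor(n_prompt / n_batch) * n_batch and the remainder would have to be
    // reprocessed; averaged over prompt lengths that is ~n_batch / 2 = 256 tokens
    for (int n_prompt : {300, 700, 1500}) {
        const int cached      = (n_prompt / n_batch) * n_batch;
        const int reprocessed = n_prompt - cached;
        printf("prompt %4d -> cached %4d, reprocessed %3d\n",
               n_prompt, cached, reprocessed);
    }
    return 0;
}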

And this would still not work for multi-turn generations, where there are also generated tokens mixed in. You would not be able to cache those either if you want to guarantee bit-for-bit identical results, so in those scenarios you would need to reprocess the entire generation from previous turns as well.

So when I said that the issue is "fundamentally unfixable" I was specifically referring to the performance-optimal implementation. What I should have said is: "It's fundamentally unfixable unless you were to reimplement it in a way that is suboptimal for performance."

@steampunque

> You can make the results bit-for-bit identical for single-turn conversations with the changes you described, but then you would only be able to cache the prompt in chunks of the physical batch size.

That is not how my patch works. It processes the whole prompt in chunks of the physical batch size, except for the last partial chunk, where it holds back one token so that it doesn't simultaneously generate on the last partial chunk.

> The physical batch size is 512 by default, so any prompt shorter than that could not be cached at all, and for longer prompts you can estimate that you would still need to reprocess ~256 tokens on average (but with an awkward batch size and a filled context, so these would take disproportionately longer).

No, no tokens are reprocessed. The processing goes B B B ... P-1 1 1 1 1 1, where B is the batch size and P is the size of the final, possibly partial, batch (P <= B).
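
For example, with a batch size of 512 and a 1200-token prompt, the unpatched server processes chunks of 512, 512, 176 and samples the first token together with that final 176-token chunk, whereas the patched server processes 512, 512, 175 and then decodes the held-back token, and every generated token after it, at batch size 1.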

> And this would still not work for multi-turn generations

This is a valid point, because processing 1 1 1 1 1 ... 1 during token generation is going to give a different result than processing those same tokens in chunks of B B B ... P-1 during the next turn's prompt processing, but this should also be fixable if desired.

> So when I said that the issue is "fundamentally unfixable" I was specifically referring to the performance-optimal implementation. What I should have said is: "It's fundamentally unfixable unless you were to reimplement it in a way that is suboptimal for performance."

The single-prompt case is both fixable and performance-optimal, while multi-turn cannot be fixed without suboptimal performance, though in theory it could also be fixed by processing/caching the tokens that follow the initial prompt with a batch size of 1 1 1 ... 1. I will take a quick look into supporting that option as well, since it may be desirable to at least have the option of deterministic results for both single-turn and multi-turn prompting.

@JohannesGaessler
Collaborator Author

> That is not how my patch works. It processes the whole prompt in chunks of the physical batch size, except for the last partial chunk, where it holds back one token so that it doesn't simultaneously generate on the last partial chunk.

Okay, it seems I misunderstood. Yes, something like that should work, and the performance impact of generating one token more or less should be negligible.

But consider that for single-turn generations the user can also add or remove arbitrary tokens at the end of the prompt. So I think it will still be necessary to re-process if two prompts are different but share the same prefix.
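
A sketch of what that reuse rule amounts to (my assumption of the intended behavior, not the actual server code; the token values and the common_prefix helper are made up for illustration):

#include <cstddef>
#include <cstdio>
#include <vector>

// hypothetical helper: length of the longest shared prefix of two token lists
static size_t common_prefix(const std::vector<int> & a, const std::vector<int> & b) {
    size_t i = 0;
    while (i < a.size() && i < b.size() && a[i] == b[i]) {
        i++;
    }
    return i;
}

int main() {
    // made-up token ids: a cached prompt and a new prompt that differs
    // only in its last token
    const std::vector<int> cached     = {1, 15043, 920, 526, 366, 29973};
    const std::vector<int> new_prompt = {1, 15043, 920, 526, 366, 1213};

    const size_t n_reuse = common_prefix(cached, new_prompt);
    printf("reuse %zu cached tokens, re-process %zu new tokens\n",
           n_reuse, new_prompt.size() - n_reuse);
    return 0;
}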

@steampunque

> That is not how my patch works. It processes the whole prompt in chunks of the physical batch size, except for the last partial chunk, where it holds back one token so that it doesn't simultaneously generate on the last partial chunk.
>
> Okay, it seems I misunderstood. Yes, something like that should work, and the performance impact of generating one token more or less should be negligible.
>
> But consider that for single-turn generations the user can also add or remove arbitrary tokens at the end of the prompt. So I think it will still be necessary to re-process if two prompts are different but share the same prefix.

Agreed, the last partial batch of the prompt would need reprocessing.

@JohannesGaessler
Collaborator Author

I'm merging this PR as-is because, while it seems that you could in principle fix some of the nondeterminism issues with negligible performance penalties, there is (as of yet) no implementation for this. So the documentation change accurately reflects the behavior on master until and unless the code changes.

@JohannesGaessler merged commit 7027b27 into ggerganov:master on Jun 7, 2024. 6 checks passed.
@steampunque

I updated the patch to handle multi-turn correctly, so the cache function can be made fully deterministic if desired. The approach I used was to disable caching of generated tokens completely (about a 10% performance hit, assuming prompt processing is 10x the speed of token generation) and to quantize the dynamic cache length to a whole multiple of the batch size, so the last partial batch of the prompt is always re-processed. For long multi-turn prompts, where the cache gives the most benefit, this still provides the majority of the performance boost. The result is identical, deterministic output with the prompt cache turned on or off, with equal or better performance compared to not using the cache, so there is now no reason not to use it. I recommend a block size of 128 based on benchmarks I ran, as it gives a good tradeoff between speedup and block size (you want the smallest block size possible while still being fast).

The patch can be found here: https://github.com/steampunque/llama.cpp-patches/blob/main/server_prompt_cache.diff and can be applied with patch -p0 -l <prompt_cache.diff.
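
To put the ~10% figure in perspective: under the stated assumption that prompt processing runs about 10x faster than token generation, re-processing, say, 300 generated tokens from a previous turn as prompt input costs roughly the time of generating 30 tokens, i.e. about a tenth of the time originally spent generating them; the exact overhead depends on hardware and model.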

Successfully merging this pull request may close the following issue: Bug: [Server] Prompt caching causes subsequent identical requests to return different token probabilities