speculative : add tree-based sampling example #3624

Merged 18 commits into master on Oct 18, 2023

Conversation

@ggerganov (Owner) commented Oct 14, 2023

ref #3137

This PR demonstrates speculative decoding using tree-based sampling of the drafts.
It also improves the llama_batch API to support assigning multiple sequence ids to the tokens in the batch.
The common/sampling.h API has also been reworked.

Tree-based sampling

In standard speculative decoding, the draft model samples a single sequence of n_draft tokens:

token: A0 B0 C0 D0 E0 F0
pos:   0  1  2  3  4  5
seq:   0  0  0  0  0  0

A further improvement on top of this strategy is to instead sample n_parallel draft sequences together using batched decoding and evaluate all of them in a single pass on the target model:

token: A0 A1 A2 B0 B1 B2 C0 C1 C2 D0 D1 D2 E0 E1 E2 F0 F1 F2
pos:   0  0  0  1  1  1  2  2  2  3  3  3  4  4  4  5  5  5  
seq:   0  1  2  0  1  2  0  1  2  0  1  2  0  1  2  0  1  2  

The drawback of this approach is that the different sequences can share the same prefix, in which case we would be evaluating the same tokens multiple times. We can take advantage of llama.cpp's support for a custom attention mask (#3228) and evaluate the same set of n_parallel sequences by processing just the unique tokens:

# assume:
#   A0 == A1 == A2
#   B0 == B1 == B2
#   C0 == C2
#   D0 == D2
#   E0 == E2

token: A0 B0 C0 C1 D0 D1 E0 E1 F0 F1 F2
pos:   0  1  2  2  3  3  4  4  5  5  5  
seq:   0  0  0  1  0  1  0  1  0  1  2  
       1  1  2     2     2
       2  2
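
For illustration, a batch with this exact layout could be constructed roughly as follows using the new multi-sequence batch API and the common helpers described further below (a sketch: the tok_* variables are placeholder token ids, the draft starts at position 0, and logits are requested for every drafted token so the target model can score each position):

// sketch: build the shared-prefix draft batch from the diagram above
llama_batch batch = llama_batch_init(32, 0, 3); // up to 32 tokens, up to 3 seq ids per token

llama_batch_clear(batch);

llama_batch_add(batch, tok_A0, 0, { 0, 1, 2 }, true); // shared by all 3 sequences
llama_batch_add(batch, tok_B0, 1, { 0, 1, 2 }, true); // shared by all 3 sequences
llama_batch_add(batch, tok_C0, 2, { 0, 2 },    true); // seq 1 branches off here
llama_batch_add(batch, tok_C1, 2, { 1 },       true);
llama_batch_add(batch, tok_D0, 3, { 0, 2 },    true);
llama_batch_add(batch, tok_D1, 3, { 1 },       true);
llama_batch_add(batch, tok_E0, 4, { 0, 2 },    true);
llama_batch_add(batch, tok_E1, 4, { 1 },       true);
llama_batch_add(batch, tok_F0, 5, { 0 },       true); // seq 2 branches off at the last position
llama_batch_add(batch, tok_F1, 5, { 1 },       true);
llama_batch_add(batch, tok_F2, 5, { 2 },       true);

The seq id sets mirror the seq rows of the diagram, so the shared prefix tokens are evaluated only once while remaining visible to all three draft sequences.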

There are 2 free parameters which need tuning:

  • probability threshold for continuing to sample a draft sequence:

// TODO: make this configurable
if (cur_p[0].p < 0.4) {
    LOG("stopping drafting for seq %3d, probability too low: %.3f < 2*%.3f\n", s, cur_p[0].p, cur_p[1].p);
    drafts[s].drafting = false;
    continue;
}

  • probability threshold for splitting the current draft sequence into a new sequence (i.e. tree branch)

// attempt to split the branch if the probability is high enough
for (int f = 1; f < 8; ++f) {
    // TODO: make this configurable
    if (n_seq_cur < n_seq_dft && cur_p[f].p > 0.3) {
        LOG("splitting seq %3d into %3d\n", s, n_seq_cur);

These parameters probably need some fine-tuning, or a better strategy is required.
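
For reference, pulling the two thresholds into named values could look roughly like this (a sketch: the speculative_params struct is illustrative and not part of the code; the p_accept / p_split names echo the values discussed later in this thread):

// hypothetical named parameters for the two thresholds above
struct speculative_params {
    float p_accept = 0.4f; // keep drafting a sequence only while the top candidate is at least this likely
    float p_split  = 0.3f; // open a new draft sequence (tree branch) when an alternative candidate is at least this likely
};

const speculative_params spec; // defaults shown above

// the two fragments above would then compare against these values:
if (cur_p[0].p < spec.p_accept) {
    drafts[s].drafting = false;
    continue;
}

for (int f = 1; f < 8; ++f) {
    if (n_seq_cur < n_seq_dft && cur_p[f].p > spec.p_split) {
        // split seq s into seq n_seq_cur
    }
}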

llama_batch changes

The seq_id member is now an array for each token. The number of elements in the array is specified via the n_seq_id member:

     typedef struct llama_batch {
         int32_t n_tokens;
 
-        llama_token  * token;
-        float        * embd;
-        llama_pos    * pos;
-        llama_seq_id * seq_id;
-        int8_t       * logits;
+        llama_token  *  token;
+        float        *  embd;
+        llama_pos    *  pos;
+        int32_t      *  n_seq_id;
+        llama_seq_id ** seq_id;
+        int8_t       *  logits;

llama_batch_init() now requires specifying the maximum number of sequences that an input token can belong to. This is needed to pre-allocate the seq_id arrays with sufficient size:

     LLAMA_API struct llama_batch llama_batch_init(
             int32_t n_tokens,
-            int32_t embd);
+            int32_t embd,
+            int32_t n_seq_max);
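
In practice an existing call site just gains one extra argument, e.g. (a sketch: n_ctx and n_seq_max are placeholders for whatever limits an example uses):

// before this PR: llama_batch batch = llama_batch_init(n_ctx, 0);
// now also pass the maximum number of sequences a single token may belong to
llama_batch batch = llama_batch_init(n_ctx, 0, n_seq_max);

// fill and decode the batch as before, then release it
llama_batch_free(batch);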

common/sampling.h changes

The llama_sampling_context now contains information for a single sequence. The last_tokens and candidates structures are now part of the context and are called prev and cur. The plan is to merge llama_sampling_context into llama.h in the future, probably after the grammar functionality is merged.

Here is the new way to sample tokens:

-            llama_token id = llama_sampling_sample(ctx_llm, NULL, ctx_sampling, last_tokens, candidates, i_batch);
+            llama_token id = llama_sampling_sample(ctx_sampling, ctx_llm, NULL, i_batch);
 
-            // remember which tokens were sampled - used for repetition penalties during sampling
-            last_tokens.erase(last_tokens.begin());
-            last_tokens.push_back(id);
+            llama_sampling_accept(ctx_sampling, ctx_llm, id);
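
Put together, the per-token flow in an example now looks roughly like this (a sketch: ctx_sampling is assumed to have been created beforehand with the reworked common/sampling.h API, and i_batch is the index of the token's logits in the last decoded batch):

// sample the next token for this sequence using its sampling context
const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llm, NULL, i_batch);

// record the sampled token in the context - this updates the `prev` history
// that the repetition penalties rely on
llama_sampling_accept(ctx_sampling, ctx_llm, id);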

There are also helper functions to construct llama_batch that are recommended for the examples:

         // prepare the next batch
-        batch.n_tokens = 0;
+        llama_batch_clear(batch);
 
         // add token to the batch
-        batch.token [batch.n_tokens] = new_token_id;
-        batch.pos   [batch.n_tokens] = n_cur;
-        batch.seq_id[batch.n_tokens] = i;
-        batch.logits[batch.n_tokens] = true;
-        batch.n_tokens += 1;
+        llama_batch_add(batch, new_token_id, n_cur, { i }, true);
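
A typical decode iteration with these helpers then looks roughly like this (a sketch: n_parallel, new_token_id and n_cur are placeholders from a batched-decoding style loop):

// prepare the next batch: one new token per active sequence
llama_batch_clear(batch);

for (int32_t i = 0; i < n_parallel; ++i) {
    // token new_token_id[i] at position n_cur, belonging only to sequence i, logits requested
    llama_batch_add(batch, new_token_id[i], n_cur, { i }, true);
}

if (llama_decode(ctx_llm, batch) != 0) {
    fprintf(stderr, "%s : llama_decode failed\n", __func__);
}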

Usage

These are examples of running the new tree-based speculative decoding:

# code completion, n_draft = 32, n_parallel = 8
./bin/speculative -m ../models/codellama-34b/ggml-model-f16.gguf -md ../models/tinyllama-1b/ggml-model-q4_0.gguf -p "# Dijkstra's shortest path algorithm in Python (4 spaces indentation) + complexity analysis:" -e -ngl 1 -t 4 -n 512 -c 4096 -s 20 --draft 32 -np 8 --temp 0.0

# grammar sampling, n_draft = 16, n_parallel = 8
./bin/speculative -m ../models/llama-70b-v2/ggml-model-q8_0.gguf -md ../models/tinyllama-1b/ggml-model-f16.gguf --grammar-file ../grammars/json_arr.gbnf -f assistant.txt -e -ngl 1 -t 4 -n 512 -c 2048 -b 2048 --draft 16 -np 8 --temp 0.0

@ggerganov ggerganov changed the title from "speculative : add tree-based sampling support" to "speculative : add tree-based sampling example" Oct 14, 2023
@KerfuffleV2 (Collaborator) commented Oct 15, 2023

WIP IN PROGRESS

This is specifically to mess with pedantic people like me and see who cracks under the pressure? You got me. It actually took several work in progress in progress pulls but there is a limit to my willpower!

edit: On a more serious note, looking at the changes and how much duplicated code there is for batch handling it really feels like a few convenience functions would help make that stuff easier to work with and maintain. There are some special cases, but a lot of the batch handling for prompt input and token generation is going to be about the same. Client/sequence handling also.

@ggerganov (Owner, Author)

You got me.

Ah finally - someone cracked! 😄

True, I will probably refactor this, but there are still some API changes popping up, so I'm not in a hurry to do it. For example, for the tree-based decoding functionality to be supported, we need to be able to assign multiple sequence ids to each input token, so the llama_batch API was updated in this PR to support this.

@ggerganov ggerganov marked this pull request as ready for review October 16, 2023 10:22
@ggerganov ggerganov added the refactoring and need feedback labels Oct 16, 2023
@ggerganov (Owner, Author) commented Oct 16, 2023

This is ready for review and testing. I need to make sure I didn't break any of the examples, as there are changes impacting all of them. Let me know if you spot any regressions.

The performance of the tree-based approach is similar to single-batch speculative decoding on M2 Ultra, which is disappointing. Switching from a 7B codellama draft to a 1B tinyllama draft also does not improve performance - this could still be due to low draft model quality. Metal seems to behave weirdly when it switches between the 2 contexts (target and draft) - I suspect there is some overhead there, but I haven't confirmed or measured it yet. It would be interesting to do similar tests on CUDA GPUs. But overall, I was expecting more from this tree-based speculative approach. Maybe I'm missing something.

@KerfuffleV2 (Collaborator)

Would be interesting to do similar tests on CUDA GPUs.

I can try with ROCm if that helps. Can you point me at specific models to test with? (I only have an 8GB GPU, so I'll probably only be able to run the draft model fully on the GPU.)

@ggerganov (Owner, Author)

Hm, not sure if there is a meaningful test to run with 8GB. Maybe try using a Q8_0 LLaMA 7B target model and an F16 LLaMA 160M (https://huggingface.co/JackFram/llama-160m) draft model, similar to the tests in #3649.

@ggerganov ggerganov merged commit 0e89203 into master Oct 18, 2023
33 checks passed
@KerfuffleV2 (Collaborator) commented Oct 18, 2023

4e82b2e

params.sampling_params.temp = std::max(0.01f, params.sampling_params.temp);

seems to result in nothing getting accepted.

Also, this is a little weird after those changes:

n_draft   = 12
n_predict = 101
n_drafted = 19
n_accept  = 18
accept    = 94.737%

Whoa, 94.73 accepted! That's amaz... oh wait, we just didn't draft.

Can't reject my tokens if I refuse to draft any. *taps forehead* This massive brain is always working.

(Just to be clear, that was with the temperature change reverted.)

@ggerganov (Owner, Author)

What is the command that you are using?

The temperature fix is needed to avoid using greedy sampling for the draft, because we rely on ctx_sampling->cur being populated and sorted, and when using greedy sampling it is not. But maybe there is some side effect that I missed.

@KerfuffleV2 (Collaborator)

I can't reproduce the nothing-getting-accepted issue anymore. Sorry, we can probably just chalk it up to me being a moron.

* No skipping - old 1.0 temp
n_draft   = 6
n_predict = 31
n_drafted = 15
n_accept  = 15
accept    = 100.000%

* Skipping layers - master temp

n_draft   = 6
n_predict = 31
n_drafted = 48
n_accept  = 22
accept    = 45.833%

* Skipping - old 1.0 temp

n_draft   = 6
n_predict = 31
n_drafted = 8
n_accept  = 8
accept    = 100.000%

* Skipping - old 1.0 temp, p_accept 0.4, p_split 0.3

n_draft   = 6
n_predict = 31
n_drafted = 19
n_accept  = 18
accept    = 94.737%

* Skipping - old master temp, p_accept 0.4, p_split 0.3

No change - same as "Skipping - master temp"

Command was speculative with --seed 123 -c 1024 -n 100 --temp 0 --draft 6 -np 8 -ngld 7 -np 4 -n 30 (-np and -n may have varied from when I originally tried it the other day).

Anyway, I guess you probably don't have to do anything here, but it does seem weird how with temperature 1 it basically refuses to draft.

Also, a random question if you feel like answering (and this probably won't make much of a difference since the num draft limit seems like it rarely gets hit): wouldn't sorting the items when trying to split, so you look at the highest probability ones first, make splitting more effective? Otherwise you could split a sequence with a lower probability and not have slots open for other ones that might be more promising.

Something like:

// Who loves pairs? Well, I can tell you who doesn't...
std::vector<std::pair<int, float>> hmm;
for (int f = 1; f < 8; f++) hmm.push_back({f, cur_p[f].p });
std::sort(hmm.begin(), hmm.end(), [](const std::pair<int, float> & a, const std::pair<int, float> & b) {
    return a.second > b.second;
});
// attempt to split the branch if the probability is high enough
for (int f_ = 1; f_ < 8; ++f_) {
    const int f = hmm[f_ - 1].first;

@ggerganov (Owner, Author) commented Oct 19, 2023

but it does seem weird how with temperature 1 it basically refuses to draft.

The drafting strategy needs some work. With --temp 1.0 it does not draft because there is almost never a token with probability > p_accept (see the log file). I usually add --top_k 3 --top_p 0.95 and this helps, or I play with the p_accept value. But in general, I don't think speculating with high temperature will ever be effective - there are just too many random paths the text can continue along.

but wouldn't sorting the items when trying to split so you look at the highest probability ones first make splitting more effective?

Yes, I recently realized that as well, and the workaround is to use --temp 0.01 - i.e. something close to zero but not zero, so that we don't trigger the greedy sampling. This way we get sorted candidates, although looking at the logs I've noticed some weird probabilities sometimes (i.e. adding up to more than 1.0, not always sorted, etc.), so there could be bugs in the sampling functions.
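
For anyone digging into this, a quick sanity check over the candidates could look like this (a sketch, assuming cur_p is the std::vector<llama_token_data> used in the draft loop above):

// debug check: do the candidate probabilities sum to ~1.0 and come out sorted?
float p_sum     = 0.0f;
bool  sorted_ok = true;

for (size_t i = 0; i < cur_p.size(); ++i) {
    p_sum += cur_p[i].p;
    if (i > 0 && cur_p[i].p > cur_p[i - 1].p) {
        sorted_ok = false;
    }
}

LOG("candidates: p_sum = %.4f, sorted = %s\n", p_sum, sorted_ok ? "yes" : "no");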

@jukofyork (Contributor)

There are 2 free parameters which need tuning:

  • Probability threshold for continuing to sample a draft sequence
  • Probability threshold for splitting the current draft sequence into a new sequence (i.e. tree branch)

How feasible would it be to use a priority-queue for this:

https://en.cppreference.com/w/cpp/container/priority_queue

and then just expand the next most likely node each time? Is there some efficiency the draft model is taking advantage of that this would give up, or is the draft generating this tree sequentially anyway?

@jukofyork (Contributor)

It looks like @KerfuffleV2 asked this already:

but wouldn't sorting the items when trying to split so you look at the highest probability ones first make splitting more effective? Otherwise you could split a sequence with a lower probability and not have slots open for other ones that might be more promising.

@ggerganov (Owner, Author)

How feasible would it be to use a priority-queue

The current approach expands all branches of the draft by 1 token each, effectively doing batched decoding for the draft model as well. But it makes a lot of sense to expand/split only the most probable one, 1 token at a time. It's something to try.
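
A rough sketch of what such a best-first expansion could look like, assuming a hypothetical draft_node type and frontier (none of these names exist in the codebase, and the draft-model evaluation itself is omitted):

#include <queue>
#include <vector>

#include "llama.h"

// hypothetical node: a partial draft branch and its cumulative probability
struct draft_node {
    std::vector<llama_token> tokens;
    float p; // product of the per-token probabilities along this branch
};

struct draft_node_cmp {
    bool operator()(const draft_node & a, const draft_node & b) const {
        return a.p < b.p; // priority_queue is a max-heap, so top() is the most probable branch
    }
};

// best-first expansion: always grow the currently most probable partial draft
static void expand_drafts(const int n_draft) {
    std::priority_queue<draft_node, std::vector<draft_node>, draft_node_cmp> frontier;
    frontier.push({ {}, 1.0f }); // start from the empty draft with probability 1

    int n_drafted = 0;
    while (n_drafted < n_draft && !frontier.empty()) {
        draft_node node = frontier.top();
        frontier.pop();

        // here the draft model would be evaluated on node.tokens and the top-k
        // continuations pushed back with p multiplied by each token's probability
        n_drafted++;
    }
}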
