speculative : add tree-based sampling example #3624
Conversation
This is specifically to mess with pedantic people like me and see who cracks under the pressure? You got me. It actually took several "work in progress in progress" pulls, but there is a limit to my willpower! edit: On a more serious note, looking at the changes and how much duplicated code there is for batch handling, it really feels like a few convenience functions would help make that stuff easier to work with and maintain. There are some special cases, but a lot of the batch handling for prompt input and token generation is going to be about the same. Client/sequence handling also.
Ah finally - someone cracked! 😄 True, will probably refactor this, but there are still some API changes popping up, so I'm not in a hurry to do it. For example, for the tree-based decoding functionality to be supported, we need to be able to assign multiple sequence ids to each input token, and so the
This is ready for review and testing. Need to make sure I didn't break some of the examples as there are changes impacting all of them. Let me know if you spot any regressions. The performance of the tree-based approach is similar to the single-batch speculative decoding on M2 Ultra, which is disappointing. Switching from a 7B codellama draft to a 1B tinyllama draft also does not improve the performance - could be due to low model quality still. Metal seems to behave weirdly when it switches between the 2 contexts (target and draft) - I'm thinking there is some overhead there, but I haven't confirmed or measured that yet. Would be interesting to do similar tests on CUDA GPUs. But overall, I was expecting more from this tree-based speculative approach. Maybe I'm missing something.
I can try with ROCM if that helps. Can you point me at specific models to test with? (I only have an 8GB GPU, so I will probably only be able to run the draft model fully on GPU.)
Hm, not sure if there is a meaningful test to run with 8GB. Maybe try using a Q8_0 LLaMA 7B target model and an F16 LLaMA 160M (https://huggingface.co/JackFram/llama-160m) draft model, similar to the tests in #3649.
seems to result in nothing getting accepted. Also, this is a little weird after those changes:
Whoa, 94.73 accepted! That's amaz... oh wait, we just didn't draft. Can't reject my tokens if I refuse to draft any. *taps forehead* This massive brain is always working. (Just to be clear, that was with the temperature change reverted.)
What is the command that you are using? The temperature fix is needed to avoid using greedy sampling for the draft, because we rely on
I can't reproduce the nothing-getting-accepted issue anymore. Sorry, we can probably just chalk it up to me being a moron.
Command was
Anyway, I guess you probably don't have to do anything here, but it does seem weird how with temperature 1 it basically refuses to draft. Also, random question if you feel like answering (and this probably won't make much of a difference as the num draft limit seems like it is rarely hit), but wouldn't sorting the items when trying to split, so you look at the highest probability ones first, make splitting more effective? Otherwise you could split a sequence with a lower probability and not have slots open for other ones that might be more promising. Something like:
// Who loves pairs? Well, I can tell you who doesn't...
std::vector<std::pair<int, float>> hmm;
for (int f = 1; f < 8; f++) hmm.push_back({ f, cur_p[f].p });
std::sort(hmm.begin(), hmm.end(), [](const std::pair<int, float> & a, const std::pair<int, float> & b) {
    return a.second > b.second;
});
// attempt to split the branch if the probability is high enough
for (int f_ = 1; f_ < 8; ++f_) {
    const int f = hmm[f_ - 1].first;
    // ... existing split logic for candidate f ...
}
The drafting strategy needs some work. With
Yes, I recently realized that as well and the workaround is to use
How feasible would it be to use a priority queue for this: https://en.cppreference.com/w/cpp/container/priority_queue and then just expand the next most likely node each time? Is there some efficiency the draft model is taking advantage of that this would give up, or is the draft generating this tree sequentially anyway?
It looks like @KerfuffleV2 asked this already:
The current approach expands all branches of the draft by 1 token each, effectively doing batched decoding for the draft model as well. But it makes a lot of sense to expand/split only the most probable one, 1 token at a time. It's something to try.
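A rough, self-contained sketch of what such best-first expansion could look like with std::priority_queue (this is not llama.cpp code; the probabilities are dummies, and in the real example they would come from the draft model):

// toy sketch: always expand the globally most probable draft node first
#include <cstdio>
#include <queue>
#include <vector>

struct draft_node {
    float cum_p; // cumulative probability of the branch ending at this node
    int   depth; // how many draft tokens deep this node is
};

struct by_cum_p {
    bool operator()(const draft_node & a, const draft_node & b) const {
        return a.cum_p < b.cum_p; // max-heap: most probable node on top
    }
};

int main() {
    std::priority_queue<draft_node, std::vector<draft_node>, by_cum_p> frontier;
    frontier.push({ 1.0f, 0 }); // root: the last accepted target token

    const int n_draft_budget = 8; // total draft tokens we are willing to spend

    for (int i = 0; i < n_draft_budget && !frontier.empty(); ++i) {
        const draft_node node = frontier.top();
        frontier.pop();

        printf("expanding node at depth %d, cum_p = %.3f\n", node.depth, node.cum_p);

        // the real version would evaluate the draft model here and push the
        // top-k continuations; dummy children with made-up probabilities instead:
        frontier.push({ node.cum_p * 0.6f, node.depth + 1 });
        frontier.push({ node.cum_p * 0.3f, node.depth + 1 });
    }

    return 0;
}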
ref #3137
This PR demonstrates speculative decoding using tree-based sampling of the drafts.
It also improves the llama_batch API to support assigning multiple sequence ids to the tokens in the batch. The common/sampling.h API has also been reworked.
Tree-based sampling
In the standard speculative decoding, the draft model samples a single sequence of n_draft tokens.
A further improvement on top of this strategy is to instead sample n_parallel draft sequences together using batched decoding and evaluate all of them in a single pass on the target model.
The drawback of this approach is that the different sequences can share the same prefix, in which case we will be evaluating the same tokens multiple times. We can take advantage of llama.cpp's support for custom attention mask (#3228) and evaluate the same set of n_parallel sequences by evaluating just the unique tokens.
There are 2 free parameters which need tuning:
llama.cpp/examples/speculative/speculative.cpp
Lines 275 to 281 in 1c626e2
llama.cpp/examples/speculative/speculative.cpp
Lines 284 to 289 in 1c626e2
These parameters probably need some fine-tuning, or some better strategy is required.
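For reference, the gist of those two thresholds (names follow the example source; the values below are placeholders, see the linked lines for the real ones):

// illustrative only - not the actual values from speculative.cpp
const float p_accept = 0.80f; // keep drafting along a branch only while the top
                              // draft token has at least this probability
const float p_split  = 0.10f; // split off a new draft branch for a runner-up token
                              // only if its probability is at least this high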
llama_batch changes
The seq_id member is now an array for each token. The number of elements in the array is specified via the n_seq_id member:
llama_batch_init() now requires specifying the maximum number of sequences that an input token can belong to. This is needed to pre-allocate the seq_id arrays with enough size:
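A minimal sketch of filling the reworked batch by hand (the token id is a placeholder; field names as I read them from llama.h after this change):

// allocate a batch of up to 512 tokens, no embeddings, at most 2 seq ids per token
llama_batch batch = llama_batch_init(512, 0, 2);

// add one token that belongs to both sequence 0 and sequence 1
// (e.g. a shared prefix token of the draft tree)
const int i = batch.n_tokens;
batch.token   [i] = some_token; // placeholder token id
batch.pos     [i] = 0;
batch.n_seq_id[i] = 2;
batch.seq_id  [i][0] = 0;
batch.seq_id  [i][1] = 1;
batch.logits  [i] = false;
batch.n_tokens++;

// ... llama_decode(ctx, batch); ...

llama_batch_free(batch);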
common/sampling.h changes
The llama_sampling_context now contains information for a single sequence. The last_tokens and candidates structures are now part of the context and are called prev and cur. The plan is to merge llama_sampling_context in llama.h in the future, probably after the grammar functionality is merged first.
Here is the new way to sample tokens:
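A rough sketch of the flow, assuming the llama_sampling_init / llama_sampling_sample / llama_sampling_accept / llama_sampling_free helpers from common/sampling.h (treat the exact argument lists as approximate):

// per-sequence sampling state now lives in a llama_sampling_context
llama_sampling_context * ctx_sampling = llama_sampling_init(params);

// sample a token from the logits of batch entry idx of the target context
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, idx);

// record the sampled token so that prev/cur (and grammar state) stay in sync
llama_sampling_accept(ctx_sampling, ctx, id);

// when the sequence is finished
llama_sampling_free(ctx_sampling);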
There are also helper functions to construct llama_batch that are recommended for the examples:
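A sketch of building a batch with those helpers (token ids and positions are placeholders):

llama_batch_clear(batch);

// a token shared by draft sequences 0 and 1, no logits needed
llama_batch_add(batch, id_shared, n_past, { 0, 1 }, false);

// a token belonging only to sequence 1, request logits so it can be sampled from
llama_batch_add(batch, id_branch, n_past + 1, { 1 }, true);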
Usage
These are examples of running the new tree-based speculative decoding:
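For instance, something along these lines (hypothetical model paths and values; -md selects the draft model, --draft the number of draft tokens per step, -np the number of parallel draft sequences):

./bin/speculative -m models/codellama-34b/ggml-model-f16.gguf -md models/codellama-7b/ggml-model-q4_0.gguf -p "// Quick-sort implementation in C:" -e -n 256 --draft 16 -np 4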