top-k sort speedup #5085
Conversation
The problem raised here: ggml-org#5073

This patch speeds up any top-k that is smaller than the entire vocab. For example, a top-k of 10000 runs 29% faster on my i7 CPU: 0.93 ms/token goes down to 0.72 ms. At a top-k of 20000 the speedup is only 9%. At k >= vocab size the full sort is used. The new code should be equivalent to the normal sort.

To really solve "large top-k" I see two ways forward:

1. possibly a lower-precision sort
2. a pre-selector that dynamically reduces top-k down to the potential k candidates and then does the partial sort

After all, in almost all runs a top-k of 20000 logits is likely ignoring the lower 19550 due to temperature/p settings. So something like "dyn-k" might be useful; such large k's likely only play a role when using min-p with high temperature.
How does this work? It seems that the only reason for this to be faster would be a poor implementation of std::partial_sort.
Yeah, partial_sort is slower than dividing + sort. I can't say for sure whether all platforms will show the same effect. (Benchmark outputs without and with this PR, plus an update, were attached as images.)
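For reference, a minimal sketch of the "dividing + sort" idea: std::nth_element partitions the k largest candidates to the front in O(n) (unordered), and only that prefix is then sorted. The candidate struct here is a simplified stand-in, not llama.cpp's actual type, and this is an illustration of the approach rather than the PR's exact code:

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

struct candidate { int id; float logit; };  // simplified stand-in

void top_k_divide_and_sort(std::vector<candidate> & cands, size_t k) {
    auto by_logit_desc = [](const candidate & a, const candidate & b) {
        return a.logit > b.logit;
    };
    if (k >= cands.size()) {
        // k covers the whole vocab: plain full sort
        std::sort(cands.begin(), cands.end(), by_logit_desc);
        return;
    }
    // O(n): move the k largest candidates (in arbitrary order) to the front
    std::nth_element(cands.begin(), cands.begin() + k, cands.end(), by_logit_desc);
    // O(k log k): order only the front segment
    std::sort(cands.begin(), cands.begin() + k, by_logit_desc);
    cands.resize(k);
}
```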
Unfortunately, I see potential degradation in sampling time for Mistral q8_0: ~40% of the time is spent during sampling, for an effective generation speed of 17 t/s due to the sampling operations. EDIT: I was compiling my cmake build with "Debug" instead of "Release". That might have something to do with it...
40% of generation time is spent sampling on Mistral q8_0 for the debug build. Compiling with the "Release" flag instead of Debug in Visual Studio, I compared mainline against this PR at topk=32000, topk=30000, and topk=1 (timing outputs for each run were attached as images).
This is not going to work for top-k 30000. The patch should show its performance improvements at top-k 10000 or 20000; it is disabled below top-k 3001 and at top-k >= vocab_size.
The vocab size is 32,000?
Currently the default sampling order seems to be top_k -> tail_free -> typical -> top_p -> min_p -> temp. More generally, since after softmax the values are bounded to the interval [0, 1], you could do a bucket sort with buckets based on the exponent of the floating-point number: 10 buckets would cover the range [1e-3, 1], 20 buckets would cover the range [1e-6, 1]. Although it may also be a good idea to do sorting and top_k prior to softmax, since calculating exponential functions is expensive.
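A rough sketch of that exponent-bucket idea (this is not the implementation that later landed in #5101; the candidate struct and bucket count are illustrative assumptions). Probabilities in (0, 1] are binned by their floating-point exponent, buckets are drained from the most probable end, and only the buckets visited before reaching the k-th element ever get sorted:

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

struct candidate { int id; float p; };  // hypothetical stand-in type

std::vector<candidate> top_k_exponent_buckets(const std::vector<candidate> & cands, size_t k) {
    constexpr int NBUCKETS = 24;  // 2^-24 ~ 6e-8, so this covers roughly [6e-8, 1]
    std::vector<std::vector<candidate>> buckets(NBUCKETS);
    for (const auto & c : cands) {
        // ilogb(p) = floor(log2(p)): 0 for p == 1, -1 for [0.5, 1), -2 for [0.25, 0.5), ...
        int b = c.p > 0.0f ? -std::ilogb(c.p) : NBUCKETS - 1;
        buckets[std::clamp(b, 0, NBUCKETS - 1)].push_back(c);
    }
    std::vector<candidate> out;
    for (auto & bucket : buckets) {  // lowest bucket index = highest probability
        std::sort(bucket.begin(), bucket.end(),
                  [](const candidate & a, const candidate & b) { return a.p > b.p; });
        for (const auto & c : bucket) {
            out.push_back(c);
            if (out.size() == k) return out;  // later buckets are never touched
        }
    }
    return out;
}
```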
Consider that the number of elements to be sorted here is at most the vocabulary size, i.e. 32000. An RTX 3090 has 82 streaming multiprocessors, each of which should run at least 32 threads. This would result in at least 2624 threads and at most ~12 elements per thread. That is simply not enough work to make a GPU implementation beat a CPU one: the GPU will spend most of its time idling, and there will be significant overhead to coordinate the parallelism. I think sampling on the GPU is only worthwhile if it allows you to avoid CPU<->GPU data transfers or if you sample many tokens in parallel.
I agree on that; that's what I meant by "dyn-k": restricting k to a reasonable value beforehand.
The sampling functions rely on the candidates being sorted after top-k. I'm wondering if we can use nth_element in a way that prefilters a "min_p" approximation, something that cuts away the unlikely candidates.
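One way such a prefilter could look, as a hedged sketch: instead of nth_element, a single std::partition pass drops everything a min_p-style bound already rules out, and the usual (partial) sort then runs on the survivors. The candidate struct and the min_p_approx parameter are assumptions for illustration; the cutoff works on raw logits, since p_i / p_max >= t is equivalent to logit_i >= logit_max + log(t), so no softmax is needed:

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

struct candidate { int id; float logit; };  // simplified stand-in

size_t prefilter_min_p(std::vector<candidate> & cands, float min_p_approx) {
    float max_logit = -INFINITY;
    for (const auto & c : cands) max_logit = std::max(max_logit, c.logit);
    const float cutoff = max_logit + std::log(min_p_approx);
    // O(n), no ordering work: survivors are moved to the front
    auto mid = std::partition(cands.begin(), cands.end(),
                              [cutoff](const candidate & c) { return c.logit >= cutoff; });
    cands.resize(static_cast<size_t>(mid - cands.begin()));
    return cands.size();  // effective k for the subsequent partial sort
}
```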
I did a bucket sort implementation that on my system is faster than both master and this PR: #5101
Obsoleted by #5109 (merged).