top-k sort speedup #5085

Closed

Conversation

@cmp-nct (Contributor) commented Jan 22, 2024

The problem was raised here: #5073

This patch speeds up any top-k that is smaller than the full vocabulary; for example, a top-k of 10000 runs 29% faster on my i7 CPU (0.93 ms/token goes down to 0.72 ms).

At a top-k of 20000 the speedup is only 9%.
At k >= vocab size the full sort is used.

The new code should be equivalent to the normal sort.
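
For illustration, here is a minimal sketch of the general idea, in the spirit of this patch but not its actual code (the struct and function names are made up): std::nth_element first moves the k largest logits to the front, then only that prefix is fully sorted.

    // Sketch only: names are illustrative, not the PR's actual code.
    #include <algorithm>
    #include <cstddef>
    #include <vector>

    struct token_logit { int id; float logit; };

    // assumes k < cand.size(); k >= vocab keeps the old full-sort path
    static void top_k_nth_then_sort(std::vector<token_logit> & cand, size_t k) {
        auto cmp = [](const token_logit & a, const token_logit & b) {
            return a.logit > b.logit; // descending by logit
        };
        // move the k largest elements (in arbitrary order) into the first k slots...
        std::nth_element(cand.begin(), cand.begin() + k, cand.end(), cmp);
        // ...then order only that prefix; the end result matches std::partial_sort
        std::sort(cand.begin(), cand.begin() + k, cmp);
    }

std::nth_element runs in linear time on average and the follow-up std::sort only touches k elements; whether this beats std::partial_sort in practice depends on the standard library implementation, as discussed below.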

To really solve "large top-k" I see two ways forward:

  1. possibly lower-precision sorting
  2. a pre-selector that dynamically reduces top-k down to the plausible candidates and then does the partial sort.
    After all, in almost all runs a top-k of 20000 logits likely ignores the lower 19550 due to the temperature/p settings.
    So something like "dyn-k" might be useful.

Such large k's likely only play a role when using min-p with high temperature.

@slaren (Member) commented Jan 22, 2024

How does this work? It seems that the only reason for this to be faster would be for a poor implementation of std::partial_sort. As such, the result would be dependent on the specific version of the standard library being used.

@cmp-nct (Contributor, Author) commented Jan 22, 2024

How does this work? It seems that the only reason for this to be faster would be for a poor implementation of std::partial_sort. As such, the result would be dependent on the specific version of the standard library being used.

Yeah, partial_sort is slower than partitioning + sort, though I can't say for sure whether all platforms will show the same effect.
I ran quite a few tests and the speedup is significant; it would be interesting to see how it performs for @kalomaze, given he saw performance 10 times slower than on my box.
If the same ~30% improvement applies to him, that would take the sampling time down from 15 sec to 10 sec.

Without this PR:
small top-k: ~0.1 ms/token
full top-k: 1.6 ms/token
10k top-k: 0.93 ms/token

With this PR:
10k top-k: 0.72 ms/token

Update:
@slaren I ran more tests; my PR was a bit premature.
Below a top-k of 2500, partial_sort is faster; between 2500 and 3500 my variant takes over.
So I added a condition that switches from partial_sort to the nth_element-based sort at 3000; for smaller k it behaves exactly as before.
This would benefit from a few tests on different CPUs and platforms.
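
For reference, a rough sketch of the resulting dispatch (my illustration, not the PR's code; the 3000 cutoff is the empirical value from these tests and may land elsewhere on other CPUs or standard libraries):

    // Illustrative dispatch only; which code path handles which k.
    #include <algorithm>
    #include <cstddef>
    #include <functional>
    #include <vector>

    static void sort_top_k(std::vector<float> & logits, size_t k) {
        auto gt = std::greater<float>();
        if (k >= logits.size()) {
            std::sort(logits.begin(), logits.end(), gt);                             // >= vocab: full sort, unchanged
        } else if (k <= 3000) {
            std::partial_sort(logits.begin(), logits.begin() + k, logits.end(), gt); // small k: partial_sort wins
        } else {
            std::nth_element(logits.begin(), logits.begin() + k, logits.end(), gt);  // large k: partition first,
            std::sort(logits.begin(), logits.begin() + k, gt);                       // then sort only the top k
        }
    }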

@kalomaze (Contributor) commented Jan 23, 2024

Unfortunately, I see potential degradation in sampling time for Mistral q8_0 with --top_k 0 or --top_k 32000, fully offloaded to an RTX 3060:

llama_print_timings:        load time =    3261.37 ms
llama_print_timings:      sample time =    5838.35 ms /   256 runs   (   22.81 ms per token,    43.85 tokens per second)
llama_print_timings: prompt eval time =      90.03 ms /    12 tokens (    7.50 ms per token,   133.29 tokens per second)
llama_print_timings:        eval time =    8590.56 ms /   255 runs   (   33.69 ms per token,    29.68 tokens per second)
llama_print_timings:       total time =   15087.90 ms /   267 tokens

~40% of the time is spent in sampling, which brings the effective generation speed down to roughly 17 t/s.
I think the ideal implementation of the sampling operations would parallelize them on the GPU to minimize the time spent outside the GPU.

EDIT: I was compiling my cmake build with "Debug" instead of "Release". That might have something to do with it...

@kalomaze (Contributor) commented Jan 23, 2024

40% of generation time is spent sampling on Mistral q8_0 for the debug build.
6% of generation time is spent on sampling for the release build.

Compiling with the "Release" flag instead of debug in Visual Studio, mainline, topk=32000:

llama_print_timings:        load time =    3074.75 ms
llama_print_timings:      sample time =    3849.22 ms /  2048 runs   (    1.88 ms per token,   532.06 tokens per second)
llama_print_timings: prompt eval time =      87.80 ms /    12 tokens (    7.32 ms per token,   136.68 tokens per second)
llama_print_timings:        eval time =   66332.97 ms /  2047 runs   (   32.40 ms per token,    30.86 tokens per second)
llama_print_timings:       total time =   70838.14 ms /  2059 tokens
Log end

This PR, topk=32000:

llama_print_timings:        load time =    3110.09 ms
llama_print_timings:      sample time =     256.02 ms /   128 runs   (    2.00 ms per token,   499.96 tokens per second)
llama_print_timings: prompt eval time =      87.40 ms /    12 tokens (    7.28 ms per token,   137.30 tokens per second)
llama_print_timings:        eval time =    4049.94 ms /   127 runs   (   31.89 ms per token,    31.36 tokens per second)
llama_print_timings:       total time =    4427.30 ms /   139 tokens
Log end

Mainline, topk=30000:

llama_print_timings:        load time =    3113.99 ms
llama_print_timings:      sample time =     214.04 ms /   128 runs   (    1.67 ms per token,   598.01 tokens per second)
llama_print_timings: prompt eval time =      87.64 ms /    12 tokens (    7.30 ms per token,   136.93 tokens per second)
llama_print_timings:        eval time =    4044.56 ms /   127 runs   (   31.85 ms per token,    31.40 tokens per second)
llama_print_timings:       total time =    4383.51 ms /   139 tokens
Log end

This PR, topk=30000:

llama_print_timings:        load time =    3065.78 ms
llama_print_timings:      sample time =     255.24 ms /   128 runs   (    1.99 ms per token,   501.49 tokens per second)
llama_print_timings: prompt eval time =      87.86 ms /    12 tokens (    7.32 ms per token,   136.57 tokens per second)
llama_print_timings:        eval time =    4052.87 ms /   127 runs   (   31.91 ms per token,    31.34 tokens per second)
llama_print_timings:       total time =    4432.78 ms /   139 tokens
Log end

Mainline, topk=1:

llama_print_timings:        load time =    3153.41 ms
llama_print_timings:      sample time =      13.95 ms /   128 runs   (    0.11 ms per token,  9174.97 tokens per second)
llama_print_timings: prompt eval time =      87.71 ms /    12 tokens (    7.31 ms per token,   136.81 tokens per second)
llama_print_timings:        eval time =    4051.76 ms /   127 runs   (   31.90 ms per token,    31.34 tokens per second)
llama_print_timings:       total time =    4189.45 ms /   139 tokens
Log end

This PR, topk=1:

llama_print_timings:        load time =    3091.83 ms
llama_print_timings:      sample time =      14.19 ms /   128 runs   (    0.11 ms per token,  9020.44 tokens per second)
llama_print_timings: prompt eval time =      88.03 ms /    12 tokens (    7.34 ms per token,   136.31 tokens per second)
llama_print_timings:        eval time =    4037.71 ms /   127 runs   (   31.79 ms per token,    31.45 tokens per second)
llama_print_timings:       total time =    4171.30 ms /   139 tokens
Log end

@cmp-nct (Contributor, Author) commented Jan 23, 2024

40% of generation time is spent sampling on Mistral q8_0 for the debug build. 6% of generation time is spent on sampling for the release build.

Compiling with the "Release" flag instead of debug in Visual Studio, mainline, topk=32000:

llama_print_timings:        load time =    3074.75 ms

......

This is not going to work for top-k 30000: top-k 30000 will use the normal sorting path (exactly as before) because it's larger than vocab size.

This patch should show the performance improvements at top-k 10000 or 20000, i.e. a large top-k, but something like half or a third of the vocab.

The patch is disabled below top-k 3001 and at top-k >= vocab_size, so the differences you have seen at 30k or 1 should be related to something else that happened in the code, maybe another optimization attempt from a previous PR?

@kalomaze (Contributor)

top-k 30000 will use the normal sorting path (exactly as before) because it's larger than vocab size.

The vocab size is 32,000?

@JohannesGaessler (Collaborator)

Such large k's likely only play a role when using min-p with high temperature.

Currently the default sampling order seems to be top_k -> tail_free -> typical -> top_p -> min_p -> temp.
For this particular use case, wouldn't it make sense to do min_p first?
Unless I'm misunderstanding, the order shouldn't matter for min_p, but it should be relatively fast since it's only a single pass over the token probabilities, and it would reduce the number of elements that you then have to sort.

More generally, since after softmax the values are bounded to the interval [0, 1], you could use a bucket sort algorithm with buckets based on the exponent of the floating-point number: 10 buckets would cover the range [1e-3, 1], 20 buckets would cover the range [1e-6, 1]. Although it may also be a good idea to do the sorting and top_k prior to softmax, since calculating exponential functions is expensive.
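
For illustration, a rough sketch of the exponent-bucket idea (my own sketch under the stated assumption of post-softmax probabilities in (0, 1]; not code from this repository): std::frexp yields the binary exponent, each bucket spans a factor-of-two range of probabilities, and buckets are consumed from most to least probable, so only as many buckets as needed to reach k elements ever get sorted.

    // Rough sketch; assumes probabilities after softmax, i.e. values in (0, 1].
    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <utility>
    #include <vector>

    static std::vector<std::pair<int, float>> top_k_bucketed(const std::vector<float> & probs, size_t k) {
        const int n_buckets = 24;                       // ~24 power-of-two ranges reach down to ~1e-7
        std::vector<std::vector<std::pair<int, float>>> buckets(n_buckets);

        for (int id = 0; id < (int) probs.size(); ++id) {
            int e = 0;
            std::frexp(probs[id], &e);                  // probs[id] = m * 2^e with m in [0.5, 1)
            // p = 1.0 gives e = 1, p in [0.5, 1) gives e = 0, smaller p gives more negative e
            int b = std::min(n_buckets - 1, std::max(0, -e));
            if (probs[id] <= 0.0f) {
                b = n_buckets - 1;                      // underflowed probabilities go to the last bucket
            }
            buckets[b].emplace_back(id, probs[id]);
        }

        std::vector<std::pair<int, float>> out;
        for (auto & bucket : buckets) {                 // walk buckets from most to least probable
            std::sort(bucket.begin(), bucket.end(),
                      [](const auto & x, const auto & y) { return x.second > y.second; });
            for (const auto & t : bucket) {
                out.push_back(t);
                if (out.size() == k) {
                    return out;                         // later (cheaper) buckets are never sorted
                }
            }
        }
        return out;                                     // k >= number of tokens
    }

In the worst case every element lands in one bucket and this degenerates into a full sort, but for typical post-softmax distributions the top buckets are small.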

I think the ideal implementation for sampling operations includes parallelization on GPU to minimize time spent outside of the GPU.

Consider that the number of elements to be sorted here is at most the vocabulary size, i.e. 32000. An RTX 3090 has 82 streaming multiprocessors, which at the bare minimum should have 32 threads running on each of them. This would result in at least 2624 threads and in at most ~12 elements per thread. This is simply not enough to achieve good performance with a GPU over a CPU implementation. The GPU will spend most of its time idling and there will be significant overhead to coordinate the parallelism. I think sampling on the GPU is only worthwhile if it allows you to avoid CPU<->GPU data transfers or if you sample many tokens in parallel.

@cmp-nct (Contributor, Author) commented Jan 23, 2024

Currently the default sampling order seems to be top_k -> tail_free -> typical -> top_p -> min_p -> temp.
For this particular use case, wouldn't it make sense to do min_p first?

I agree with that; it's what I meant by "dyn-k", as in restricting k to a reasonable value beforehand.
But we can't run min_p without a sort:

    // from the min_p sampler: it relies on the candidates already being sorted by probability (descending)
    for (; i < candidates->size; ++i) {
        if (candidates->data[i].p < p * scale && i >= min_keep) {
            break; // prob too small
        }
    }

The sampling functions rely on the sorted status after top-k.
min_p also does a full softmax() on the candidates.

I'm wondering if we can use nth_element to prefilter with a "min_p" approximation, something that cuts away the unlikely candidates:
min_p_prefilter -> top-k -> ... -> min_p
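
For what it's worth, here is one sort-free shape such a prefilter could take (a hypothetical sketch, not the nth_element variant proposed above and not existing code): a single pass finds the maximum logit, and std::partition then keeps only the candidates that would survive min_p relative to that maximum, using p_i / p_max = exp(l_i - l_max), so no softmax is needed either.

    // Hypothetical prefilter sketch; p_min and min_keep mirror the min_p parameters.
    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    struct token_logit { int id; float logit; };

    static void min_p_prefilter(std::vector<token_logit> & cand, float p_min, size_t min_keep) {
        float max_logit = -INFINITY;
        for (const auto & t : cand) {
            max_logit = std::max(max_logit, t.logit);   // single pass for the maximum
        }
        // p_i >= p_min * p_max  <=>  l_i >= l_max + log(p_min)
        const float thresh = max_logit + std::log(p_min);

        auto it = std::partition(cand.begin(), cand.end(),
                                 [thresh](const token_logit & t) { return t.logit >= thresh; });

        const size_t n_keep = (size_t) (it - cand.begin());
        if (n_keep >= min_keep) {
            cand.resize(n_keep);                        // drop the unlikely tail before the expensive sort
        }                                               // otherwise leave the list untouched
    }

The surviving, much smaller candidate list would then go through the usual top-k -> ... -> min_p chain.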

@JohannesGaessler (Collaborator)

I did a bucket sort implementation that on my system is faster than both master and this PR: #5101

@cebtenzzre (Collaborator)

Obsoleted by #5109 (merged)

@cebtenzzre closed this on Jan 26, 2024