
CUDA: faster non k-quant mul_mat_q kernels #2483

Merged

Conversation

@JohannesGaessler (Collaborator) commented Aug 1, 2023

This PR adds CUDA performance optimizations for the non k-quant mul_mat_q kernels. The changes boil down to two things:

  1. I adopted the approach that @ikawrakow used for q4_K and q5_K, where multiple integers are processed per CUDA thread. For token generation 2 ints are processed per thread, while for prompt processing 4 or 8 ints are processed per thread.
  2. The q5 formats are now converted to q8 immediately when they are loaded into shared memory. This reduces the amount of computation needed to assemble q5 values from their lower and upper bits by a factor of 32. (Rough sketches of both changes are shown below.)
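As a rough illustration of the first change (a sketch only, not the actual mul_mat_q code from this PR; the function and parameter names are made up for the example), each thread accumulates an int8 dot product over several packed 32-bit integers with the `__dp4a` intrinsic. The number of ints per thread corresponds to 2 for token generation and 4 or 8 for prompt processing:

```cpp
// Hypothetical sketch of the per-thread work in a mul_mat_q-style kernel.
// Each 32-bit int packs four int8 quant values; __dp4a (compute capability >= 6.1)
// performs four int8 multiply-adds per call. ints_per_thread is the number of
// packed ints handled by one thread (2 for token generation, 4/8 for prompt processing).
template <int ints_per_thread>
static __device__ __forceinline__ int vec_dot_q8_sketch(const int * v, const int * u) {
    int sumi = 0;
#pragma unroll
    for (int i = 0; i < ints_per_thread; ++i) {
        sumi = __dp4a(v[i], u[i], sumi); // 4 int8 products accumulated into sumi
    }
    return sumi; // the block scales are applied elsewhere
}
```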

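The second change can be sketched as follows (a minimal standalone example of the unpacking arithmetic under the standard q5_0 block layout, not the PR's shared-memory load code): the 16 low nibbles in `qs` and the 32-bit mask of high bits `qh` are combined once into signed 8-bit values, so the dot product afterwards only has to deal with q8 data.

```cpp
// Hypothetical sketch: unpack one q5_0 block (32 quant values) into int8 values,
// e.g. into shared memory. The block scale d is omitted and applied separately.
__device__ void unpack_q5_0_to_q8_sketch(const uint8_t * qs, const uint32_t qh, int8_t * dst) {
    for (int j = 0; j < 16; ++j) {
        const int xh0 = ((qh >> (j +  0)) & 1) << 4; // 5th bit of value j
        const int xh1 = ((qh >> (j + 16)) & 1) << 4; // 5th bit of value j + 16

        dst[j +  0] = ((qs[j] & 0x0F) | xh0) - 16;   // low  nibble + high bit, centered around 0
        dst[j + 16] = ((qs[j] >>   4) | xh1) - 16;   // high nibble + high bit, centered around 0
    }
}
```
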
These are the results on my system:

| GPU | Model | Test | t/s master | t/s PR | Speedup |
|---|---|---|---|---|---|
| RTX 3090 | 7b q4_0 | pp | 1574 | 1993 | 1.27 |
| RTX 3090 | 7b q4_1 | pp | 1392 | 1848 | 1.33 |
| RTX 3090 | 7b q5_0 | pp | 1094 | 1536 | 1.40 |
| RTX 3090 | 7b q5_1 | pp | 1062 | 1580 | 1.49 |
| RTX 3090 | 7b q8_0 | pp | 1206 | 1553 | 1.29 |
| RTX 3090 | 7b q4_0 | tg128 | 130.26 | 132.46 | 1.02 |
| RTX 3090 | 7b q4_1 | tg128 | 123.48 | 124.55 | 1.01 |
| RTX 3090 | 7b q5_0 | tg128 | 112.62 | 113.97 | 1.01 |
| RTX 3090 | 7b q5_1 | tg128 | 107.79 | 107.66 | 1.00 |
| RTX 3090 | 7b q8_0 | tg128 | 82.48 | 83.54 | 1.01 |
| P40 | 7b q4_0 | pp | 703 | 826 | 1.17 |
| P40 | 7b q4_1 | pp | 455 | 757 | 1.66 |
| P40 | 7b q5_0 | pp | 401 | 681 | 1.70 |
| P40 | 7b q5_1 | pp | 378 | 738 | 1.95 |
| P40 | 7b q8_0 | pp | 559 | 724 | 1.30 |
| P40 | 7b q4_0 | tg128 | 50.45 | 54.09 | 1.07 |
| P40 | 7b q4_1 | tg128 | 49.92 | 51.81 | 1.04 |
| P40 | 7b q5_0 | tg128 | 43.36 | 46.78 | 1.08 |
| P40 | 7b q5_1 | tg128 | 45.72 | 47.10 | 1.03 |
| P40 | 7b q8_0 | tg128 | 32.38 | 33.85 | 1.05 |

For reference: the speed of cuBLAS is ~1500 t/s on my RTX 3090 and ~500 t/s on my P40. So for non k-quants the mul_mat_q kernels now seem to be universally faster than cuBLAS.

@Dampfinchen commented Aug 2, 2023

Another excellent contribution! Processing 1800 tokens with 13B q5_1 on an RTX 2060 laptop results in a prompt processing speed of 11.7 ms/t with this PR; with cuBLAS it is 14.8 ms/t. A very noticeable and most welcome speed boost!

@slaren (Member) left a comment

3090 Ti / WSL2 / 7B q4_0 pp:
Master: 1544 t/s
PR: 1776 t/s

@JohannesGaessler merged commit 468ea24 into ggml-org:master Aug 2, 2023
@Nexesenex (Contributor) commented Aug 2, 2023

There might be a hiccup here.
At least with K-quants, VRAM usage grows much faster with increasing context with this PR. As the context grows while running an Airoboros 33b K_3S model launched via cuBLAS (63 layers in VRAM), VRAM usage grows 3-5 times faster than with the previous mul_mat_q implementation, which quickly leads to an out-of-memory error.
Tested on KoboldCPP experimental, compiled without this PR and then with it.

@JohannesGaessler (Collaborator, Author):

Sorry, but I don't see how that could in any way be related to the changes I made in this PR.

@Nexesenex (Contributor):

It's weird for me too, so I redid the test.

At 3800 tokens of context with your PR (63/63 layers, batch size 256, max context 5632), VRAM fills up and I get "CUDA error 2 at B:\kobold2\ggml-cuda.cu:4194: out of memory" on the aforementioned 33b K_3S model.

Without your PR, at 3800 tokens of context, my 3090's VRAM usage is at 24075/24576 MB.

In both cases, VRAM usage at zero tokens of context is 23893 MB once the model is loaded.

I don't think I picked the wrong PR before compiling.

@cebtenzzre (Collaborator):

@Nexesenex Just so everyone is on the same page, you are comparing commits 4f6b60c and 468ea24?

@Nexesenex (Contributor) commented Aug 3, 2023

468ea24 is the one behaving abnormally.

4f6b60c is the one behaving normally.

Also, I compile under VS2019 with cuda_11.4.4_472.50.

@LostRuins (Collaborator) commented Aug 3, 2023

@Nexesenex To avoid any complicating factors, I'd recommend benchmarking directly against code from this repo for the comparison. Although I use the CUDA code verbatim when I merge downstream, there could be other components in koboldcpp that influence speed and memory usage.

I haven't tested this PR myself yet. I will revisit this when I merge it in the next release.

@JohannesGaessler (Collaborator, Author):

I can't reproduce the issue with llama.cpp. On my machine VRAM usage is exactly the same.

@Nexesenex (Contributor):

Then I might have compiled a little Frankenstein, Johannes.

Maybe it has something to do with CUDA 11.4, I don't know. I'll compile and test your next experimental build, LostRuins, once it includes this PR, and report back here if I still have the issue.

@Nexesenex (Contributor) commented Aug 6, 2023

The memory leak problem is solved for me with the latest experimental build of KoboldCPP, which includes the present commit and its later revision (f514d1b) and was also compiled with the additional PR #2529, on the same model with the same settings.
Also, 30% faster prompt processing (Q_K_3S)!
Sorry for the disturbance, and thanks for the great work!
