CUDA: faster q2_K, q3_K MMQ + int8 tensor cores #7921

Merged (6 commits) on Jun 14, 2024

Conversation

JohannesGaessler (Collaborator) commented Jun 13, 2024

This PR overhauls the q2_K and q3_K mul_mat_q kernels and adds int8 tensor core support. Now that all MMQ-supported quantization formats have int8 tensor core support, it was possible to simplify the code a little. To make q2_K work I had to implement an alternative data format for q8_1 where, instead of the per-block sum, the per-halfblock sums are stored (as int8 relative to the max value). The precision loss from doing this seems to be negligible.
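As a rough illustration of the idea (a sketch only; the struct and field names below are made up, not the layout this PR actually adds), the per-halfblock sums can be compressed to int8 relative to the largest absolute sum in the block:

```cpp
// Sketch of a q8_1-like block that stores per-halfblock sums as int8
// relative to the block's largest absolute halfblock sum. All names and
// sizes here are illustrative assumptions, not the actual ggml-cuda layout.
#include <cstdint>
#include <cmath>

struct block_q8_1_halfsums {
    float  d;      // scale of the quantized activations
    float  s_max;  // largest absolute per-halfblock sum (for reconstruction)
    int8_t qs[32]; // 32 activations quantized to int8
    int8_t s[2];   // per-halfblock sums, stored relative to s_max
};

static void set_halfblock_sums(block_q8_1_halfsums & b, float sum0, float sum1) {
    b.s_max = std::fmax(std::fabs(sum0), std::fabs(sum1));
    const float inv = b.s_max > 0.0f ? 127.0f / b.s_max : 0.0f;
    b.s[0] = (int8_t) std::round(sum0 * inv); // reconstruct later as s[0]*s_max/127
    b.s[1] = (int8_t) std::round(sum1 * inv);
}
```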

ggml-cuda.cu now queries and stores the maximum opt-in shared memory per streaming multiprocessor. This is the actual value that should be used as the upper limit for shared memory in kernel launches (raising the limit requires an additional, explicit function call).
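For reference, the general CUDA pattern for querying this limit and opting a kernel into it looks roughly like the following (a minimal sketch with a placeholder kernel, not the ggml-cuda.cu code):

```cpp
#include <cuda_runtime.h>

__global__ void my_kernel() {
    extern __shared__ char smem[]; // dynamic shared memory
    // ...
}

void launch_with_max_smem(int device, int blocks, int threads) {
    // Per-block opt-in limit; cudaDevAttrMaxSharedMemoryPerMultiprocessor gives
    // the per-SM maximum mentioned above.
    int smem_max = 0;
    cudaDeviceGetAttribute(&smem_max, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);

    // Without this explicit call, launches requesting more than the default
    // 48 KiB of dynamic shared memory fail.
    cudaFuncSetAttribute(my_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_max);

    my_kernel<<<blocks, threads, smem_max>>>();
}
```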

Performance vs. master MMQ
| GPU | Model | Microbatch size | Test | t/s master | t/s cuda-ptx-mma-17 | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| RTX 4090 | llama 8B Q2_K_M | 16 | pp2048 | 1190.56 | 1702.40 | 1.43 |
| RTX 4090 | llama 8B Q2_K_M | 32 | pp2048 | 1588.46 | 2689.24 | 1.69 |
| RTX 4090 | llama 8B Q2_K_M | 64 | pp2048 | 2462.88 | 4120.64 | 1.67 |
| RTX 4090 | llama 8B Q2_K_M | 128 | pp2048 | 2868.68 | 5635.31 | 1.96 |
| RTX 4090 | llama 8B Q2_K_M | 256 | pp2048 | 3553.10 | 7022.20 | 1.98 |
| RTX 4090 | llama 8B Q2_K_M | 512 | pp2048 | 3909.38 | 7639.35 | 1.95 |
| RTX 4090 | llama 8B Q2_K_M | 1024 | pp2048 | 4105.15 | 7755.40 | 1.89 |
| RTX 4090 | llama 8B Q2_K_M | 2048 | pp2048 | 3936.76 | 7083.78 | 1.80 |
| RTX 4090 | llama 8B Q3_K_S | 16 | pp2048 | 997.96 | 1611.31 | 1.61 |
| RTX 4090 | llama 8B Q3_K_S | 32 | pp2048 | 1346.93 | 2644.87 | 1.96 |
| RTX 4090 | llama 8B Q3_K_S | 64 | pp2048 | 2289.07 | 4190.48 | 1.83 |
| RTX 4090 | llama 8B Q3_K_S | 128 | pp2048 | 2836.75 | 6014.07 | 2.12 |
| RTX 4090 | llama 8B Q3_K_S | 256 | pp2048 | 3710.10 | 7712.51 | 2.08 |
| RTX 4090 | llama 8B Q3_K_S | 512 | pp2048 | 4183.62 | 8659.58 | 2.07 |
| RTX 4090 | llama 8B Q3_K_S | 1024 | pp2048 | 4397.62 | 8644.62 | 1.97 |
| RTX 4090 | llama 8B Q3_K_S | 2048 | pp2048 | 4248.61 | 7935.98 | 1.87 |
| RTX 3090 | llama 8B Q2_K_M | 16 | pp2048 | 595.56 | 977.44 | 1.64 |
| RTX 3090 | llama 8B Q2_K_M | 32 | pp2048 | 799.47 | 1378.14 | 1.72 |
| RTX 3090 | llama 8B Q2_K_M | 64 | pp2048 | 937.48 | 1932.08 | 2.06 |
| RTX 3090 | llama 8B Q2_K_M | 128 | pp2048 | 1211.96 | 2414.37 | 1.99 |
| RTX 3090 | llama 8B Q2_K_M | 256 | pp2048 | 1516.43 | 2956.45 | 1.95 |
| RTX 3090 | llama 8B Q2_K_M | 512 | pp2048 | 1582.72 | 3084.91 | 1.95 |
| RTX 3090 | llama 8B Q2_K_M | 1024 | pp2048 | 1646.13 | 3196.07 | 1.94 |
| RTX 3090 | llama 8B Q2_K_M | 2048 | pp2048 | 1628.12 | 3125.93 | 1.92 |
| RTX 3090 | llama 8B Q3_K_S | 16 | pp2048 | 492.46 | 923.08 | 1.87 |
| RTX 3090 | llama 8B Q3_K_S | 32 | pp2048 | 649.50 | 1315.24 | 2.03 |
| RTX 3090 | llama 8B Q3_K_S | 64 | pp2048 | 859.15 | 1943.93 | 2.26 |
| RTX 3090 | llama 8B Q3_K_S | 128 | pp2048 | 1230.25 | 2646.62 | 2.15 |
| RTX 3090 | llama 8B Q3_K_S | 256 | pp2048 | 1543.02 | 3305.31 | 2.14 |
| RTX 3090 | llama 8B Q3_K_S | 512 | pp2048 | 1665.09 | 3487.42 | 2.09 |
| RTX 3090 | llama 8B Q3_K_S | 1024 | pp2048 | 1743.56 | 3575.74 | 2.05 |
| RTX 3090 | llama 8B Q3_K_S | 2048 | pp2048 | 1735.05 | 3470.47 | 2.00 |
| RX 6800 | llama 8B Q2_K_M | 16 | pp2048 | 140.08 | 153.15 | 1.09 |
| RX 6800 | llama 8B Q2_K_M | 32 | pp2048 | 159.30 | 197.70 | 1.24 |
| RX 6800 | llama 8B Q2_K_M | 64 | pp2048 | 191.10 | 238.01 | 1.25 |
| RX 6800 | llama 8B Q2_K_M | 128 | pp2048 | 228.92 | 295.19 | 1.29 |
| RX 6800 | llama 8B Q2_K_M | 256 | pp2048 | 273.17 | 349.91 | 1.28 |
| RX 6800 | llama 8B Q2_K_M | 512 | pp2048 | 288.55 | 364.26 | 1.26 |
| RX 6800 | llama 8B Q2_K_M | 1024 | pp2048 | 277.80 | 345.47 | 1.24 |
| RX 6800 | llama 8B Q2_K_M | 2048 | pp2048 | 256.47 | 311.28 | 1.21 |
| RX 6800 | llama 8B Q3_K_S | 16 | pp2048 | 115.97 | 134.14 | 1.16 |
| RX 6800 | llama 8B Q3_K_S | 32 | pp2048 | 127.77 | 160.74 | 1.26 |
| RX 6800 | llama 8B Q3_K_S | 64 | pp2048 | 164.16 | 206.02 | 1.26 |
| RX 6800 | llama 8B Q3_K_S | 128 | pp2048 | 192.03 | 249.73 | 1.30 |
| RX 6800 | llama 8B Q3_K_S | 256 | pp2048 | 228.57 | 292.16 | 1.28 |
| RX 6800 | llama 8B Q3_K_S | 512 | pp2048 | 242.35 | 306.31 | 1.26 |
| RX 6800 | llama 8B Q3_K_S | 1024 | pp2048 | 235.30 | 294.04 | 1.25 |
| RX 6800 | llama 8B Q3_K_S | 2048 | pp2048 | 220.55 | 269.35 | 1.22 |
| P40 | llama 8B Q2_K_M | 16 | pp2048 | 273.10 | 314.75 | 1.15 |
| P40 | llama 8B Q2_K_M | 32 | pp2048 | 358.76 | 427.85 | 1.19 |
| P40 | llama 8B Q2_K_M | 64 | pp2048 | 464.29 | 558.64 | 1.20 |
| P40 | llama 8B Q2_K_M | 128 | pp2048 | 534.66 | 677.66 | 1.27 |
| P40 | llama 8B Q2_K_M | 256 | pp2048 | 582.77 | 731.19 | 1.25 |
| P40 | llama 8B Q2_K_M | 512 | pp2048 | 613.71 | 768.46 | 1.25 |
| P40 | llama 8B Q2_K_M | 1024 | pp2048 | 613.27 | 759.92 | 1.24 |
| P40 | llama 8B Q2_K_M | 2048 | pp2048 | 590.71 | 728.91 | 1.23 |
| P40 | llama 8B Q3_K_S | 16 | pp2048 | 241.94 | 249.37 | 1.03 |
| P40 | llama 8B Q3_K_S | 32 | pp2048 | 320.23 | 355.26 | 1.11 |
| P40 | llama 8B Q3_K_S | 64 | pp2048 | 452.54 | 491.01 | 1.09 |
| P40 | llama 8B Q3_K_S | 128 | pp2048 | 520.61 | 572.41 | 1.10 |
| P40 | llama 8B Q3_K_S | 256 | pp2048 | 564.99 | 620.82 | 1.10 |
| P40 | llama 8B Q3_K_S | 512 | pp2048 | 592.96 | 652.16 | 1.10 |
| P40 | llama 8B Q3_K_S | 1024 | pp2048 | 586.02 | 643.80 | 1.10 |
| P40 | llama 8B Q3_K_S | 2048 | pp2048 | 565.22 | 618.79 | 1.09 |
Performance vs. master cuBLAS
| GPU | Model | Microbatch size | Test | t/s master | t/s cuda-ptx-mma-17 | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| RTX 4090 | llama 8B Q2_K_M | 16 | pp2048 | 1189.99 | 1702.40 | 1.43 |
| RTX 4090 | llama 8B Q2_K_M | 32 | pp2048 | 1580.39 | 2689.24 | 1.70 |
| RTX 4090 | llama 8B Q2_K_M | 64 | pp2048 | 2454.15 | 4120.64 | 1.68 |
| RTX 4090 | llama 8B Q2_K_M | 128 | pp2048 | 3642.99 | 5635.31 | 1.55 |
| RTX 4090 | llama 8B Q2_K_M | 256 | pp2048 | 5897.75 | 7022.20 | 1.19 |
| RTX 4090 | llama 8B Q2_K_M | 512 | pp2048 | 7797.40 | 7639.35 | 0.98 |
| RTX 4090 | llama 8B Q2_K_M | 1024 | pp2048 | 9039.66 | 7755.40 | 0.86 |
| RTX 4090 | llama 8B Q2_K_M | 2048 | pp2048 | 8915.95 | 7083.78 | 0.79 |
| RTX 4090 | llama 8B Q3_K_S | 16 | pp2048 | 991.74 | 1611.31 | 1.62 |
| RTX 4090 | llama 8B Q3_K_S | 32 | pp2048 | 1344.66 | 2644.87 | 1.97 |
| RTX 4090 | llama 8B Q3_K_S | 64 | pp2048 | 2287.11 | 4190.48 | 1.83 |
| RTX 4090 | llama 8B Q3_K_S | 128 | pp2048 | 3542.60 | 6014.07 | 1.70 |
| RTX 4090 | llama 8B Q3_K_S | 256 | pp2048 | 5785.44 | 7712.51 | 1.33 |
| RTX 4090 | llama 8B Q3_K_S | 512 | pp2048 | 7704.78 | 8659.58 | 1.12 |
| RTX 4090 | llama 8B Q3_K_S | 1024 | pp2048 | 9015.20 | 8644.62 | 0.96 |
| RTX 4090 | llama 8B Q3_K_S | 2048 | pp2048 | 8939.09 | 7935.98 | 0.89 |
| RTX 3090 | llama 8B Q2_K_M | 16 | pp2048 | 585.72 | 977.44 | 1.67 |
| RTX 3090 | llama 8B Q2_K_M | 32 | pp2048 | 778.38 | 1378.14 | 1.77 |
| RTX 3090 | llama 8B Q2_K_M | 64 | pp2048 | 913.46 | 1932.08 | 2.12 |
| RTX 3090 | llama 8B Q2_K_M | 128 | pp2048 | 2274.05 | 2414.37 | 1.06 |
| RTX 3090 | llama 8B Q2_K_M | 256 | pp2048 | 3354.70 | 2956.45 | 0.88 |
| RTX 3090 | llama 8B Q2_K_M | 512 | pp2048 | 3984.17 | 3084.91 | 0.77 |
| RTX 3090 | llama 8B Q2_K_M | 1024 | pp2048 | 4692.10 | 3196.07 | 0.68 |
| RTX 3090 | llama 8B Q2_K_M | 2048 | pp2048 | 4739.62 | 3125.93 | 0.66 |
| RTX 3090 | llama 8B Q3_K_S | 16 | pp2048 | 477.93 | 923.08 | 1.93 |
| RTX 3090 | llama 8B Q3_K_S | 32 | pp2048 | 625.05 | 1315.24 | 2.10 |
| RTX 3090 | llama 8B Q3_K_S | 64 | pp2048 | 829.36 | 1943.93 | 2.34 |
| RTX 3090 | llama 8B Q3_K_S | 128 | pp2048 | 2102.58 | 2646.62 | 1.26 |
| RTX 3090 | llama 8B Q3_K_S | 256 | pp2048 | 3151.56 | 3305.31 | 1.05 |
| RTX 3090 | llama 8B Q3_K_S | 512 | pp2048 | 3847.53 | 3487.42 | 0.91 |
| RTX 3090 | llama 8B Q3_K_S | 1024 | pp2048 | 4594.82 | 3575.74 | 0.78 |
| RTX 3090 | llama 8B Q3_K_S | 2048 | pp2048 | 4694.42 | 3470.47 | 0.74 |

The performance on Ampere and Ada Lovelace seems to still be suboptimal relative to FP16 cuBLAS GEMM.

github-actions bot added the Nvidia GPU (Issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) labels on Jun 13, 2024
JohannesGaessler added the Review Complexity : High (Generally requires in-depth knowledge of LLMs or GPUs) label on Jun 13, 2024
slaren (Collaborator) commented Jun 14, 2024

I am not sure how big of a problem this is, but extending the test-backend-ops tests to include q2_K in base_types causes them to fail frequently.
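Roughly, the change being described is the following (a sketch; the exact contents of the list in test-backend-ops are assumed):

```cpp
// Adding q2_K to the set of types that get the full mul_mat test matrix.
// The surrounding code and the other list entries are assumptions.
#include <vector>
#include "ggml.h"

const std::vector<ggml_type> base_types = {
    GGML_TYPE_F32, GGML_TYPE_F16,
    GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
    GGML_TYPE_Q2_K, // added: exercises the new q2_K MMQ path
};
```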

  MUL_MAT(type_a=q2_K,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[1,1]): [MUL_MAT] NMSE = 0.000539679 > 0.000500000 FAIL
  MUL_MAT(type_a=q2_K,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[1,1]): [MUL_MAT] NMSE = 0.000567695 > 0.000500000 FAIL
  MUL_MAT(type_a=q2_K,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[2,1]): [MUL_MAT] NMSE = 0.000584948 > 0.000500000 FAIL
  MUL_MAT(type_a=q2_K,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[1,1]): [MUL_MAT] NMSE = 0.000510545 > 0.000500000 FAIL
  MUL_MAT(type_a=q2_K,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[2,1]): [MUL_MAT] NMSE = 0.000504169 > 0.000500000 FAIL

JohannesGaessler (Collaborator, Author) commented Jun 14, 2024

Sorry, it seems my precision-testing setup was broken due to #7809; I had checked the KL divergence using perplexity, which was the old binary that, due to the name change, was not getting removed by make clean. And since I got the exact same value (within the printed precision), I assumed the precision loss to be negligible. These are the correct results:

| Model | imatrix | Code | PPL | KL Divergence vs. FP16 | Mean Δp |
| --- | --- | --- | --- | --- | --- |
| LLaMA 3 iq2_M | WT 10m | master cuBLAS | 8.598441 ± 0.055100 | 0.325965 ± 0.001606 | -6.467 ± 0.046 % |
| LLaMA 3 q2_K_M | WT 10m | master cuBLAS | 8.646568 ± 0.055594 | 0.332531 ± 0.001572 | -6.507 ± 0.047 % |
| LLaMA 3 q2_K_M | WT 10m | master MMQ | 8.646308 ± 0.055604 | 0.332680 ± 0.001573 | -6.503 ± 0.047 % |
| LLaMA 3 q2_K_M | WT 10m | PR MMQ | 8.690794 ± 0.055918 | 0.337893 ± 0.001585 | -6.613 ± 0.047 % |
| LLaMA 3 q2_K_S | WT 10m | master cuBLAS | 9.321797 ± 0.061532 | 0.403376 ± 0.001787 | -7.137 ± 0.049 % |
| LLaMA 3 q2_K_S | WT 10m | master MMQ | 9.322530 ± 0.061534 | 0.403561 ± 0.001788 | -7.146 ± 0.049 % |
| LLaMA 3 q2_K_S | WT 10m | PR MMQ | 9.467689 ± 0.062546 | 0.418697 ± 0.001824 | -7.488 ± 0.050 % |
| LLaMA 3 iq2_s | WT 10m | master cuBLAS | 9.652453 ± 0.063226 | 0.439268 ± 0.001975 | -8.325 ± 0.052 % |

For q2_K_M I think this is fine, but for q2_K_S I would like a better solution; it would also be possible to quantize the activations as q8_1 with a block size of 16 in the first place, but that would require comparatively more changes. Since right now q2_K_S is, I think, not worth using over iq2_M and iq2_S anyway, the way I would like to proceed is to merge this PR as-is, optimize MMQ for those quant formats that are already implemented, figure out how to best refactor and simplify the code, and then add the unsupported quantization formats, at which point I will revisit this.

slaren (Collaborator) commented Jun 14, 2024

racecheck reports an error with mmq:

========= Error: Race reported between Write access at void load_tiles_q2_K<(int)64, (int)4, (bool)1>(const char *, int *, __half2 *, int *, const int &, const int &, const int &)+0x3f70 in /home/diego/code/llama.cpp/ggml-cuda/mmq.cuh:853
=========     and Write access at void load_tiles_q2_K<(int)64, (int)4, (bool)1>(const char *, int *, __half2 *, int *, const int &, const int &, const int &)+0x3f10 in /home/diego/code/llama.cpp/ggml-cuda/mmq.cuh:863 [64 hazards]
=========     and Write access at void load_tiles_q2_K<(int)64, (int)4, (bool)1>(const char *, int *, __half2 *, int *, const int &, const int &, const int &)+0x3f30 in /home/diego/code/llama.cpp/ggml-cuda/mmq.cuh:863 [64 hazards]
=========     and Write access at void load_tiles_q2_K<(int)64, (int)4, (bool)1>(const char *, int *, __half2 *, int *, const int &, const int &, const int &)+0x3f40 in /home/diego/code/llama.cpp/ggml-cuda/mmq.cuh:863 [64 hazards]
=========     and Write access at void load_tiles_q2_K<(int)64, (int)4, (bool)1>(const char *, int *, __half2 *, int *, const int &, const int &, const int &)+0x3f50 in /home/diego/code/llama.cpp/ggml-cuda/mmq.cuh:863 [64 hazards]
=========     and Write access at void load_tiles_q2_K<(int)64, (int)4, (bool)1>(const char *, int *, __half2 *, int *, const int &, const int &, const int &)+0x3f80 in /home/diego/code/llama.cpp/ggml-cuda/mmq.cuh:863 [64 hazards]
=========     and Write access at void load_tiles_q2_K<(int)64, (int)4, (bool)1>(const char *, int *, __half2 *, int *, const int &, const int &, const int &)+0x3f90 in /home/diego/code/llama.cpp/ggml-cuda/mmq.cuh:863 [64 hazards]
=========     and Write access at void load_tiles_q2_K<(int)64, (int)4, (bool)1>(const char *, int *, __half2 *, int *, const int &, const int &, const int &)+0x3fa0 in /home/diego/code/llama.cpp/ggml-cuda/mmq.cuh:863 [64 hazards]
=========     and Write access at void load_tiles_q2_K<(int)64, (int)4, (bool)1>(const char *, int *, __half2 *, int *, const int &, const int &, const int &)+0x3fb0 in /home/diego/code/llama.cpp/ggml-cuda/mmq.cuh:863 [64 hazards]
=========

I also noticed that the precision issue rarely happens with low batch sizes (mmvq).

JohannesGaessler (Collaborator, Author) commented:

> racecheck reports an error with mmq:

In this particular case it doesn't matter because multiple threads write the same value multiple times (that's faster than a conditional statement). But I think I can rewrite the code in such a way that this doesn't happen (at the cost of more register usage, which at this point in the kernel is probably irrelevant).
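To illustrate the pattern being described (a generic sketch, not the actual load_tiles_q2_K code):

```cpp
// Benign-but-reported race: several threads store the identical value to the
// same shared memory slot, so racecheck flags write-write hazards.
// Both kernels assume a launch with 32 threads per block.
__global__ void redundant_write(int * out) {
    __shared__ int tile[8];
    const int value = 42;          // identical for all threads in the block
    tile[threadIdx.x / 4] = value; // 4 threads write each slot
    __syncthreads();
    out[threadIdx.x] = tile[threadIdx.x / 4];
}

// Race-free alternative: guard the store with a conditional so that only one
// thread per slot writes.
__global__ void guarded_write(int * out) {
    __shared__ int tile[8];
    const int value = 42;
    if (threadIdx.x % 4 == 0) {
        tile[threadIdx.x / 4] = value;
    }
    __syncthreads();
    out[threadIdx.x] = tile[threadIdx.x / 4];
}
```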

> I also noticed that the precision issue rarely happens with low batch sizes (mmvq).

That sounds to me like a different issue. What command are you using for testing?

slaren (Collaborator) commented Jun 14, 2024

I'm using test-backend-ops. This test never fails the NMSE check (batch size=1):

test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q2_K, GGML_TYPE_F32, 16, 1, 256, { 1,  1}, {1, 1}));

But this test fails almost always (batch size=32):

test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q2_K, GGML_TYPE_F32, 16, 32, 256, { 1,  1}, {1, 1}));

ggerganov (Owner) commented:

> In this particular case it doesn't matter because multiple threads write the same value multiple times (that's faster than a conditional statement).

I believe this still classifies as UB. But regardless, it would be better not to have any data races, even if they are benign in practice.

JohannesGaessler (Collaborator, Author) commented:

The idea I had turned out to be slower than a conditional statement, so I used that instead.

> But this test fails almost always (batch size=32):

How does that relate to MMVQ? That kernel is only used for batch sizes <= 8.

Unrelated to that, should we maybe set a seed for the precision tests? (To my understanding the data that is used for testing is currently unseeded RNG output.) I think that would make it easier to differentiate between race conditions and other bugs.

slaren (Collaborator) commented Jun 14, 2024

> How does that relate to MMVQ? That kernel is only used for batch sizes <= 8.

Sorry, what I meant was that the issue does not happen when mmvq is used, only with mmq. I am assuming that mmvq also uses the new q8_1 formats and is affected by these changes, but maybe I am wrong.

JohannesGaessler (Collaborator, Author) commented:

> I am assuming that mmvq also uses the new q8_1 formats and is affected by these changes, but maybe I am wrong.

No, MMVQ uses the exact same data layout that it did prior to my changes.

slaren (Collaborator) commented Jun 14, 2024

> Unrelated to that, should we maybe set a seed for the precision tests? (To my understanding the data that is used for testing is currently unseeded RNG output.) I think that would make it easier to differentiate between race conditions and other bugs.

I would prefer it if the CI tests were run with random seeds. I am concerned that if we always use the same data, some bugs will never be detected just because the code happens to work with the random values that were generated. If you think that having a fixed seed would be useful during development, maybe we could instead add a command line option to set the seed.
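A minimal sketch of how such an option could look (the flag name and parsing here are assumptions, not the actual test-backend-ops interface):

```cpp
#include <cstdint>
#include <cstring>
#include <random>
#include <string>

// Random seed by default (as preferred for CI), with an optional --seed
// override for reproducing a specific failure during development.
static uint32_t parse_seed(int argc, char ** argv) {
    for (int i = 1; i + 1 < argc; i++) {
        if (std::strcmp(argv[i], "--seed") == 0) {
            return (uint32_t) std::stoul(argv[i + 1]);
        }
    }
    return std::random_device{}();
}
```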

Currently, some quant formats are only tested with bs=1. We should add another test for the other_types mul_mat tests with at least bs=32 so that these issues are caught more reliably. But that would also cause the CI to fail every time on q2_K, so we would have to increase the NMSE threshold for mul mat once again, and it is probably already too high.

The increase in ppl also seems very high to me. But I do not think that very low BPW formats like q2_K have many practical applications anyway; even with very large models the error is usually too high to be usable. So I am not even convinced that it is worth spending much effort on this.

JohannesGaessler (Collaborator, Author) commented:

> I would prefer it if the CI tests were run with random seeds. I am concerned that if we always use the same data, some bugs will never be detected just because the code happens to work with the random values that were generated.

A fair point. My first choice for testing correctness is llama-perplexity anyway, so if the current implementation is more convenient for you we should keep it as-is.

> We should add another test for the other_types mul_mat tests with at least bs=32 so that these issues are caught more reliably.

We should also test with batch sizes that are not powers of 2, since those are typically the ones that suffer from out-of-bounds issues.
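For example, following the existing test_mul_mat calls quoted above (the extra batch size of 97 is just an illustration):

```cpp
// bs=32 plus a batch size that is not a power of 2:
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q2_K, GGML_TYPE_F32, 16, 32, 256, { 1,  1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q2_K, GGML_TYPE_F32, 16, 97, 256, { 1,  1}, {1, 1}));
```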

> The increase in ppl also seems very high to me.

> But that would also cause the CI to fail every time on q2_K, so we would have to increase the NMSE threshold for mul mat once again, and it is probably already too high.

If this is causing too many issues I can, for now, revert the implementation back to the way it was before (on-the-fly calculation of the per-halfblock sums). It would be a simple change since I know exactly how to do it (but the performance would be worse).
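For context, the on-the-fly variant boils down to summing each 16-value half of a q8_1 block inside the kernel, for example with dp4a (a generic sketch, not the actual mmq.cuh code):

```cpp
// Sum one 16-value halfblock of int8 activations packed into four 32-bit words.
// __dp4a with 0x01010101 adds the four signed int8 lanes of each word.
static __device__ int halfblock_sum(const int * q8) {
    int sum = 0;
#pragma unroll
    for (int i = 0; i < 4; ++i) {
        sum = __dp4a(q8[i], 0x01010101, sum);
    }
    return sum;
}
```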

JohannesGaessler (Collaborator, Author) commented Jun 14, 2024

I reverted the changes related to q2_K precision. Precision looks like this:

| Model | imatrix | Code | PPL | KL Divergence vs. FP16 | Mean Δp |
| --- | --- | --- | --- | --- | --- |
| LLaMA 3 iq2_M | WT 10m | master cuBLAS | 8.598441 ± 0.055100 | 0.325965 ± 0.001606 | -6.467 ± 0.046 % |
| LLaMA 3 q2_K_M | WT 10m | master cuBLAS | 8.646568 ± 0.055594 | 0.332531 ± 0.001572 | -6.507 ± 0.047 % |
| LLaMA 3 q2_K_M | WT 10m | master MMQ | 8.646308 ± 0.055604 | 0.332680 ± 0.001573 | -6.503 ± 0.047 % |
| LLaMA 3 q2_K_M | WT 10m | PR MMQ | 8.647440 ± 0.055611 | 0.332767 ± 0.001573 | -6.506 ± 0.047 % |
| LLaMA 3 q2_K_S | WT 10m | master cuBLAS | 9.321797 ± 0.061532 | 0.403376 ± 0.001787 | -7.137 ± 0.049 % |
| LLaMA 3 q2_K_S | WT 10m | master MMQ | 9.322530 ± 0.061534 | 0.403561 ± 0.001788 | -7.146 ± 0.049 % |
| LLaMA 3 q2_K_S | WT 10m | PR MMQ | 9.323076 ± 0.061531 | 0.403553 ± 0.001787 | -7.151 ± 0.049 % |
| LLaMA 3 iq2_s | WT 10m | master cuBLAS | 9.652453 ± 0.063226 | 0.439268 ± 0.001975 | -8.325 ± 0.052 % |

The performance will be worse since you now need to do twice as many int8 operations.

JohannesGaessler (Collaborator, Author) commented:

Performance vs. master MMQ
| GPU | Model | Microbatch size | Test | t/s master | t/s cuda-ptx-mma-18 | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| RTX 4090 | llama 8B Q2_K_M | 16 | pp2048 | 1190.56 | 1855.20 | 1.56 |
| RTX 4090 | llama 8B Q2_K_M | 32 | pp2048 | 1588.46 | 2721.91 | 1.71 |
| RTX 4090 | llama 8B Q2_K_M | 64 | pp2048 | 2462.88 | 4103.09 | 1.67 |
| RTX 4090 | llama 8B Q2_K_M | 128 | pp2048 | 2868.68 | 5504.39 | 1.92 |
| RTX 4090 | llama 8B Q2_K_M | 256 | pp2048 | 3553.10 | 6880.45 | 1.94 |
| RTX 4090 | llama 8B Q2_K_M | 512 | pp2048 | 3909.38 | 7483.21 | 1.91 |
| RTX 4090 | llama 8B Q2_K_M | 1024 | pp2048 | 4105.15 | 7573.81 | 1.84 |
| RTX 4090 | llama 8B Q2_K_M | 2048 | pp2048 | 3936.76 | 7029.54 | 1.79 |
| RTX 3090 | llama 8B Q2_K_M | 16 | pp2048 | 595.56 | 916.59 | 1.54 |
| RTX 3090 | llama 8B Q2_K_M | 32 | pp2048 | 799.47 | 1240.97 | 1.55 |
| RTX 3090 | llama 8B Q2_K_M | 64 | pp2048 | 937.48 | 1675.81 | 1.79 |
| RTX 3090 | llama 8B Q2_K_M | 128 | pp2048 | 1211.96 | 2089.21 | 1.72 |
| RTX 3090 | llama 8B Q2_K_M | 256 | pp2048 | 1516.43 | 2618.58 | 1.73 |
| RTX 3090 | llama 8B Q2_K_M | 512 | pp2048 | 1582.72 | 2730.08 | 1.72 |
| RTX 3090 | llama 8B Q2_K_M | 1024 | pp2048 | 1646.13 | 2822.00 | 1.71 |
| RTX 3090 | llama 8B Q2_K_M | 2048 | pp2048 | 1628.12 | 2775.45 | 1.70 |
| RX 6800 | llama 8B Q2_K_M | 16 | pp2048 | 140.08 | 126.13 | 0.90 |
| RX 6800 | llama 8B Q2_K_M | 32 | pp2048 | 159.30 | 168.21 | 1.06 |
| RX 6800 | llama 8B Q2_K_M | 64 | pp2048 | 191.10 | 215.66 | 1.13 |
| RX 6800 | llama 8B Q2_K_M | 128 | pp2048 | 228.92 | 265.33 | 1.16 |
| RX 6800 | llama 8B Q2_K_M | 256 | pp2048 | 273.17 | 317.25 | 1.16 |
| RX 6800 | llama 8B Q2_K_M | 512 | pp2048 | 288.55 | 332.35 | 1.15 |
| RX 6800 | llama 8B Q2_K_M | 1024 | pp2048 | 277.80 | 317.75 | 1.14 |
| RX 6800 | llama 8B Q2_K_M | 2048 | pp2048 | 256.47 | 289.17 | 1.13 |
| P40 | llama 8B Q2_K_M | 16 | pp2048 | 273.10 | 326.75 | 1.20 |
| P40 | llama 8B Q2_K_M | 32 | pp2048 | 358.76 | 412.26 | 1.15 |
| P40 | llama 8B Q2_K_M | 64 | pp2048 | 464.29 | 509.30 | 1.10 |
| P40 | llama 8B Q2_K_M | 128 | pp2048 | 534.66 | 621.33 | 1.16 |
| P40 | llama 8B Q2_K_M | 256 | pp2048 | 582.77 | 681.40 | 1.17 |
| P40 | llama 8B Q2_K_M | 512 | pp2048 | 613.71 | 714.58 | 1.16 |
| P40 | llama 8B Q2_K_M | 1024 | pp2048 | 613.27 | 712.67 | 1.16 |
| P40 | llama 8B Q2_K_M | 2048 | pp2048 | 590.71 | 685.07 | 1.16 |
Performance vs. master cuBLAS
| GPU | Model | Microbatch size | Test | t/s master | t/s cuda-ptx-mma-18 | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| RTX 4090 | llama 8B Q2_K_M | 16 | pp2048 | 1189.99 | 1853.57 | 1.56 |
| RTX 4090 | llama 8B Q2_K_M | 32 | pp2048 | 1580.39 | 2719.81 | 1.72 |
| RTX 4090 | llama 8B Q2_K_M | 64 | pp2048 | 2454.15 | 4098.03 | 1.67 |
| RTX 4090 | llama 8B Q2_K_M | 128 | pp2048 | 3642.99 | 5474.59 | 1.50 |
| RTX 4090 | llama 8B Q2_K_M | 256 | pp2048 | 5897.75 | 6880.95 | 1.17 |
| RTX 4090 | llama 8B Q2_K_M | 512 | pp2048 | 7797.40 | 7489.97 | 0.96 |
| RTX 4090 | llama 8B Q2_K_M | 1024 | pp2048 | 9039.66 | 7563.97 | 0.84 |
| RTX 4090 | llama 8B Q2_K_M | 2048 | pp2048 | 8915.95 | 7025.17 | 0.79 |
| RTX 3090 | llama 8B Q2_K_M | 16 | pp2048 | 585.72 | 930.18 | 1.59 |
| RTX 3090 | llama 8B Q2_K_M | 32 | pp2048 | 778.38 | 1277.75 | 1.64 |
| RTX 3090 | llama 8B Q2_K_M | 64 | pp2048 | 913.46 | 1718.08 | 1.88 |
| RTX 3090 | llama 8B Q2_K_M | 128 | pp2048 | 2274.05 | 2125.18 | 0.93 |
| RTX 3090 | llama 8B Q2_K_M | 256 | pp2048 | 3354.70 | 2633.10 | 0.78 |
| RTX 3090 | llama 8B Q2_K_M | 512 | pp2048 | 3984.17 | 2755.28 | 0.69 |
| RTX 3090 | llama 8B Q2_K_M | 1024 | pp2048 | 4692.10 | 2853.54 | 0.61 |
| RTX 3090 | llama 8B Q2_K_M | 2048 | pp2048 | 4739.62 | 2804.06 | 0.59 |

As of right now the performance difference when using tensor cores is relatively small (~10%) since the overall utilization is poor. It will likely become more of a bottleneck as more optimizations are added.

slaren (Collaborator) commented Jun 14, 2024

For me, q2_K_M is about 12% slower with bs=512. I think this is still a good tradeoff compared to the memory cost of cuBLAS, since the only reason to use q2_K in the first place is memory-limited situations, especially since current models tend to have very large output tensors, which increases memory usage.

| GPU | Model | Microbatch size | Test | t/s master cuBLAS | t/s cuda-ptx-mma-17 MMQ | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| RTX 3090 Ti | 7B Q2_K_M | 16 | pp1024 | 551.50 | 907.57 | 1.65 |
| RTX 3090 Ti | 7B Q2_K_M | 32 | pp1024 | 729.71 | 1363.34 | 1.87 |
| RTX 3090 Ti | 7B Q2_K_M | 64 | pp1024 | 957.33 | 1986.98 | 2.08 |
| RTX 3090 Ti | 7B Q2_K_M | 128 | pp1024 | 2328.11 | 2711.16 | 1.16 |
| RTX 3090 Ti | 7B Q2_K_M | 256 | pp1024 | 3509.39 | 3357.53 | 0.96 |
| RTX 3090 Ti | 7B Q2_K_M | 512 | pp1024 | 4205.41 | 3703.15 | 0.88 |
| RTX 3090 Ti | 7B Q2_K_M | 1024 | pp1024 | 4659.57 | 3650.46 | 0.78 |

JohannesGaessler merged commit 76d66ee into ggerganov:master on Jun 14, 2024 (65 checks passed)
bartowski1182 (Contributor) commented:

Just out of curiosity, what is Q2_K_M? This PR and a couple of others from Johannes are the only places I see it mentioned.

JohannesGaessler (Collaborator, Author) commented Jun 14, 2024

It's just what llama.cpp calls plain q2_K, as opposed to q2_K_S.
