Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ggml: implement quantized KV cache for FA #7372

Merged
merged 1 commit into from
May 19, 2024

Conversation

JohannesGaessler
Copy link
Collaborator

This PR implements the use of quantized KV caches for the CPU backend when using FlashAttention. This is done via switch statements in ggml_compute_forward_flash_attn_ext_f16; from what I can tell this does not significantly affect performance. I also added comments and did some small performance optimizations such as converting Q only once to FP16/q8_0/q8_1 and using multiplications for scaling rather than divisions. The amount of wdata has been increased but only on the order of kiB which should be negligible. On my desktop with a Ryzen 5950X power limited to 95W the performance changes as follows:

model backend threads fa test t/s master t/s PR Speedup
llama 7B Q4_0 CPU 16 1 pp2048 33.43 35.29 1.06
llama 7B Q4_0 CPU 16 1 tg512 10.59 10.71 1.01

When benchmarking, be mindful of the order in which you run the tests since a hot CPU will perform worse and this difference can easily be larger than the performance difference from this PR. I did a warmup run prior to the actual measurements.

@slaren
Copy link
Collaborator

slaren commented May 18, 2024

It would be preferable to use the more generic function pointers in type_traits in the same way mul_mat does.

Copy link
Contributor

github-actions bot commented May 19, 2024

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 547 iterations 🚀

Expand details for performance related PR only
  • Concurrent users: 8, duration: 10m
  • HTTP request : avg=8537.9ms p(95)=20441.75ms fails=, finish reason: stop=493 truncated=54
  • Prompt processing (pp): avg=101.07tk/s p(95)=423.38tk/s
  • Token generation (tg): avg=51.07tk/s p(95)=48.17tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=fa-quantize-3 commit=b7da2e86db9836b4b85a30675f39c0571de0ec94

prompt_tokens_seconds

More
---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3
 duration=10m 547 iterations"
    y-axis "llamacpp:prompt_tokens_seconds"
    x-axis "llamacpp:prompt_tokens_seconds" 1716145964 --> 1716146596
    line [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 875.21, 875.21, 875.21, 875.21, 875.21, 827.76, 827.76, 827.76, 827.76, 827.76, 839.07, 839.07, 839.07, 839.07, 839.07, 887.27, 887.27, 887.27, 887.27, 887.27, 906.6, 906.6, 906.6, 906.6, 906.6, 892.65, 892.65, 892.65, 892.65, 892.65, 901.19, 901.19, 901.19, 901.19, 901.19, 905.11, 905.11, 905.11, 905.11, 905.11, 914.48, 914.48, 914.48, 914.48, 914.48, 914.94, 914.94, 914.94, 914.94, 914.94, 939.87, 939.87, 939.87, 939.87, 939.87, 917.31, 917.31, 917.31, 917.31, 917.31, 928.28, 928.28, 928.28, 928.28, 928.28, 932.96, 932.96, 932.96, 932.96, 932.96, 932.64, 932.64, 932.64, 932.64, 932.64, 932.47, 932.47, 932.47, 932.47, 932.47, 925.58, 925.58, 925.58, 925.58, 925.58, 938.79, 938.79, 938.79, 938.79, 938.79, 938.52, 938.52, 938.52, 938.52, 938.52, 941.35, 941.35, 941.35, 941.35, 941.35, 940.47, 940.47, 940.47, 940.47, 940.47, 940.23, 940.23, 940.23, 940.23, 940.23, 903.32, 903.32, 903.32, 903.32, 903.32, 903.77, 903.77, 903.77, 903.77, 903.77, 913.4, 913.4, 913.4, 913.4, 913.4, 916.13, 916.13, 916.13, 916.13, 916.13, 915.08, 915.08, 915.08, 915.08, 915.08, 915.15, 915.15, 915.15, 915.15, 915.15, 914.7, 914.7, 914.7, 914.7, 914.7, 915.77, 915.77, 915.77, 915.77, 915.77, 912.89, 912.89, 912.89, 912.89, 912.89, 914.47, 914.47, 914.47, 914.47, 914.47, 912.92, 912.92, 912.92, 912.92, 912.92, 912.52, 912.52, 912.52, 912.52, 912.52, 916.64, 916.64, 916.64, 916.64, 916.64, 912.77, 912.77, 912.77, 912.77, 912.77, 910.04, 910.04, 910.04, 910.04, 910.04, 911.36, 911.36, 911.36, 911.36, 911.36, 913.73, 913.73, 913.73, 913.73, 913.73, 911.06, 911.06, 911.06, 911.06, 911.06, 913.16, 913.16, 913.16, 913.16, 913.16, 911.52, 911.52, 911.52, 911.52, 911.52, 908.89, 908.89, 908.89, 908.89, 908.89, 907.08, 907.08, 907.08, 907.08, 907.08, 900.25, 900.25, 900.25, 900.25, 900.25, 903.67, 903.67, 903.67, 903.67, 903.67, 900.27, 900.27, 900.27, 900.27, 900.27, 899.22, 899.22, 899.22, 899.22, 899.22, 901.12, 901.12, 901.12, 901.12, 901.12, 901.99, 901.99, 901.99, 901.99, 901.99, 902.03, 902.03, 902.03, 902.03, 902.03, 904.79, 904.79, 904.79, 904.79, 904.79, 897.11, 897.11, 897.11, 897.11, 897.11, 903.84, 903.84, 903.84, 903.84, 903.84, 902.57, 902.57, 902.57, 902.57, 902.57, 903.66, 903.66, 903.66, 903.66, 903.66, 902.45, 902.45, 902.45, 902.45, 902.45, 902.98, 902.98, 902.98, 902.98, 902.98, 904.22, 904.22, 904.22, 904.22, 904.22, 905.74, 905.74, 905.74, 905.74, 905.74, 904.84, 904.84]
                    
Loading
predicted_tokens_seconds
More
---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3
 duration=10m 547 iterations"
    y-axis "llamacpp:predicted_tokens_seconds"
    x-axis "llamacpp:predicted_tokens_seconds" 1716145964 --> 1716146596
    line [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 37.77, 37.77, 37.77, 37.77, 37.77, 31.81, 31.81, 31.81, 31.81, 31.81, 30.19, 30.19, 30.19, 30.19, 30.19, 32.39, 32.39, 32.39, 32.39, 32.39, 32.6, 32.6, 32.6, 32.6, 32.6, 32.48, 32.48, 32.48, 32.48, 32.48, 33.71, 33.71, 33.71, 33.71, 33.71, 34.12, 34.12, 34.12, 34.12, 34.12, 33.95, 33.95, 33.95, 33.95, 33.95, 34.15, 34.15, 34.15, 34.15, 34.15, 34.32, 34.32, 34.32, 34.32, 34.32, 34.1, 34.1, 34.1, 34.1, 34.1, 33.47, 33.47, 33.47, 33.47, 33.47, 33.46, 33.46, 33.46, 33.46, 33.46, 32.31, 32.31, 32.31, 32.31, 32.31, 30.38, 30.38, 30.38, 30.38, 30.38, 30.71, 30.71, 30.71, 30.71, 30.71, 30.81, 30.81, 30.81, 30.81, 30.81, 30.77, 30.77, 30.77, 30.77, 30.77, 30.79, 30.79, 30.79, 30.79, 30.79, 30.85, 30.85, 30.85, 30.85, 30.85, 30.98, 30.98, 30.98, 30.98, 30.98, 30.68, 30.68, 30.68, 30.68, 30.68, 30.8, 30.8, 30.8, 30.8, 30.8, 31.07, 31.07, 31.07, 31.07, 31.07, 30.87, 30.87, 30.87, 30.87, 30.87, 31.0, 31.0, 31.0, 31.0, 31.0, 31.06, 31.06, 31.06, 31.06, 31.06, 31.23, 31.23, 31.23, 31.23, 31.23, 31.28, 31.28, 31.28, 31.28, 31.28, 31.39, 31.39, 31.39, 31.39, 31.39, 31.61, 31.61, 31.61, 31.61, 31.61, 31.66, 31.66, 31.66, 31.66, 31.66, 31.39, 31.39, 31.39, 31.39, 31.39, 31.24, 31.24, 31.24, 31.24, 31.24, 30.62, 30.62, 30.62, 30.62, 30.62, 30.54, 30.54, 30.54, 30.54, 30.54, 30.71, 30.71, 30.71, 30.71, 30.71, 30.91, 30.91, 30.91, 30.91, 30.91, 31.02, 31.02, 31.02, 31.02, 31.02, 30.91, 30.91, 30.91, 30.91, 30.91, 30.75, 30.75, 30.75, 30.75, 30.75, 30.54, 30.54, 30.54, 30.54, 30.54, 29.81, 29.81, 29.81, 29.81, 29.81, 28.77, 28.77, 28.77, 28.77, 28.77, 28.79, 28.79, 28.79, 28.79, 28.79, 28.82, 28.82, 28.82, 28.82, 28.82, 28.83, 28.83, 28.83, 28.83, 28.83, 28.9, 28.9, 28.9, 28.9, 28.9, 28.93, 28.93, 28.93, 28.93, 28.93, 29.02, 29.02, 29.02, 29.02, 29.02, 29.06, 29.06, 29.06, 29.06, 29.06, 28.98, 28.98, 28.98, 28.98, 28.98, 28.99, 28.99, 28.99, 28.99, 28.99, 29.0, 29.0, 29.0, 29.0, 29.0, 29.11, 29.11, 29.11, 29.11, 29.11, 29.28, 29.28, 29.28, 29.28, 29.28, 29.37, 29.37, 29.37, 29.37, 29.37, 29.42, 29.42, 29.42, 29.42, 29.42, 29.51, 29.51, 29.51, 29.51, 29.51, 29.58, 29.58]
                    
Loading

Details

kv_cache_usage_ratio

More
---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3
 duration=10m 547 iterations"
    y-axis "llamacpp:kv_cache_usage_ratio"
    x-axis "llamacpp:kv_cache_usage_ratio" 1716145964 --> 1716146596
    line [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.38, 0.38, 0.38, 0.38, 0.38, 0.29, 0.29, 0.29, 0.29, 0.29, 0.12, 0.12, 0.12, 0.12, 0.12, 0.2, 0.2, 0.2, 0.2, 0.2, 0.25, 0.25, 0.25, 0.25, 0.25, 0.1, 0.1, 0.1, 0.1, 0.1, 0.14, 0.14, 0.14, 0.14, 0.14, 0.17, 0.17, 0.17, 0.17, 0.17, 0.21, 0.21, 0.21, 0.21, 0.21, 0.12, 0.12, 0.12, 0.12, 0.12, 0.17, 0.17, 0.17, 0.17, 0.17, 0.23, 0.23, 0.23, 0.23, 0.23, 0.4, 0.4, 0.4, 0.4, 0.4, 0.36, 0.36, 0.36, 0.36, 0.36, 0.36, 0.36, 0.36, 0.36, 0.36, 0.24, 0.24, 0.24, 0.24, 0.24, 0.1, 0.1, 0.1, 0.1, 0.1, 0.17, 0.17, 0.17, 0.17, 0.17, 0.26, 0.26, 0.26, 0.26, 0.26, 0.17, 0.17, 0.17, 0.17, 0.17, 0.19, 0.19, 0.19, 0.19, 0.19, 0.14, 0.14, 0.14, 0.14, 0.14, 0.27, 0.27, 0.27, 0.27, 0.27, 0.11, 0.11, 0.11, 0.11, 0.11, 0.2, 0.2, 0.2, 0.2, 0.2, 0.13, 0.13, 0.13, 0.13, 0.13, 0.09, 0.09, 0.09, 0.09, 0.09, 0.14, 0.14, 0.14, 0.14, 0.14, 0.11, 0.11, 0.11, 0.11, 0.11, 0.15, 0.15, 0.15, 0.15, 0.15, 0.16, 0.16, 0.16, 0.16, 0.16, 0.17, 0.17, 0.17, 0.17, 0.17, 0.3, 0.3, 0.3, 0.3, 0.3, 0.21, 0.21, 0.21, 0.21, 0.21, 0.39, 0.39, 0.39, 0.39, 0.39, 0.26, 0.26, 0.26, 0.26, 0.26, 0.1, 0.1, 0.1, 0.1, 0.1, 0.14, 0.14, 0.14, 0.14, 0.14, 0.09, 0.09, 0.09, 0.09, 0.09, 0.34, 0.34, 0.34, 0.34, 0.34, 0.52, 0.52, 0.52, 0.52, 0.52, 0.68, 0.68, 0.68, 0.68, 0.68, 0.62, 0.62, 0.62, 0.62, 0.62, 0.46, 0.46, 0.46, 0.46, 0.46, 0.1, 0.1, 0.1, 0.1, 0.1, 0.27, 0.27, 0.27, 0.27, 0.27, 0.21, 0.21, 0.21, 0.21, 0.21, 0.21, 0.21, 0.21, 0.21, 0.21, 0.16, 0.16, 0.16, 0.16, 0.16, 0.15, 0.15, 0.15, 0.15, 0.15, 0.14, 0.14, 0.14, 0.14, 0.14, 0.24, 0.24, 0.24, 0.24, 0.24, 0.23, 0.23, 0.23, 0.23, 0.23, 0.27, 0.27, 0.27, 0.27, 0.27, 0.11, 0.11, 0.11, 0.11, 0.11, 0.13, 0.13, 0.13, 0.13, 0.13, 0.11, 0.11, 0.11, 0.11, 0.11, 0.16, 0.16, 0.16, 0.16, 0.16, 0.16, 0.16, 0.16, 0.16, 0.16, 0.18, 0.18, 0.18, 0.18, 0.18, 0.22, 0.22]
                    
Loading
requests_processing
More
---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3
 duration=10m 547 iterations"
    y-axis "llamacpp:requests_processing"
    x-axis "llamacpp:requests_processing" 1716145964 --> 1716146596
    line [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 6.0, 6.0, 6.0, 6.0, 6.0, 8.0, 8.0, 8.0, 8.0, 8.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 2.0, 2.0, 2.0, 2.0, 2.0, 6.0, 6.0, 6.0, 6.0, 6.0, 4.0, 4.0, 4.0, 4.0, 4.0, 2.0, 2.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0, 4.0, 5.0, 5.0, 5.0, 5.0, 5.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.0, 3.0, 3.0, 3.0, 3.0, 5.0, 5.0, 5.0, 5.0, 5.0, 6.0, 6.0, 6.0, 6.0, 6.0, 2.0, 2.0, 2.0, 2.0, 2.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.0, 3.0, 3.0, 3.0, 3.0, 5.0, 5.0, 5.0, 5.0, 5.0, 6.0, 6.0, 6.0, 6.0, 6.0, 5.0, 5.0, 5.0, 5.0, 5.0, 6.0, 6.0, 6.0, 6.0, 6.0, 8.0, 8.0, 8.0, 8.0, 8.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 6.0, 6.0, 6.0, 6.0, 6.0, 4.0, 4.0, 4.0, 4.0, 4.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.0, 3.0, 3.0, 3.0, 3.0, 6.0, 6.0, 6.0, 6.0, 6.0, 7.0, 7.0, 7.0, 7.0, 7.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 7.0, 7.0, 7.0, 7.0, 7.0, 8.0, 8.0, 8.0, 8.0, 8.0, 7.0, 7.0, 7.0, 7.0, 7.0, 8.0, 8.0, 8.0, 8.0, 8.0, 4.0, 4.0, 4.0, 4.0, 4.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 5.0, 5.0, 5.0, 5.0, 5.0, 6.0, 6.0, 6.0, 6.0, 6.0, 1.0, 1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0, 5.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 2.0, 2.0, 2.0, 2.0, 2.0, 5.0, 5.0, 5.0, 5.0, 5.0, 6.0, 6.0, 6.0, 6.0, 6.0, 4.0, 4.0, 4.0, 4.0, 4.0, 1.0, 1.0]
                    
Loading

@JohannesGaessler
Copy link
Collaborator Author

I tested using the same FP32 code for both FP16 and quantized V cache but the performance was worse.

ggml.c Outdated
GGML_ASSERT(nbq0 == sizeof(float));
GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
// input tensors must be contiguous
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs just contiguous rows, not the entire tensors

@JohannesGaessler JohannesGaessler merged commit 5ca49cb into ggerganov:master May 19, 2024
43 of 45 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants