
CUDA: quantized KV support for FA vec #7527

Merged

Conversation

JohannesGaessler
Collaborator

This PR adds CUDA support for quantizing the KV cache. Only the kernels optimized for small batch sizes are implemented, so prompt processing speed is still bad. Also, as of right now I have not added any mechanism for reducing the compilation time, which currently sits at ~7 minutes on my desktop machine and is definitely too long. I think the way to go about this is to not compile all possible combinations of KV cache types but only those that actually make sense to use based on the results in #7412. In my opinion compiling the reduced version should be the default, but I don't have a strong opinion on this. I would then add a compile flag to compile all combinations if one wants to. I consider a solution for the long compile time the only thing still missing from this PR.

I think the following type combinations make sense:

  • K=f16, V=f16, 16.0 BPV
  • K=q8_0, V=q8_0, 8.50 BPV
  • K=q8_0, V=q5_0, 7.00 BPV
  • K=q8_0, V=q4_0, 6.50 BPV
  • K=q5_1, V=q5_1, 6.00 BPV
  • K=q5_1, V=q5_0, 5.75 BPV
  • K=q5_1, V=q4_1, 5.50 BPV
  • K=q5_1, V=q4_0, 5.25 BPV
  • K=q5_0, V=q4_0, 5.00 BPV
  • K=q4_0, V=q4_0, 4.50 BPV

That would reduce the number of possible combinations from 36 to 10. Compilation may still take too long, so maybe adding yet another compile flag for fast compilation would make sense.
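
For illustration, the reduced set above could be expressed as a compile-time whitelist along these lines (the ggml_type values are real, but the helper itself is a hypothetical sketch and not code from this PR):

```cpp
#include "ggml.h"

// Hypothetical whitelist of the K/V cache type combinations listed above.
// Shown only to make the reduced set concrete; not part of this PR.
constexpr bool fattn_kv_combo_supported(ggml_type K, ggml_type V) {
    return (K == GGML_TYPE_F16  &&  V == GGML_TYPE_F16)                          ||
           (K == GGML_TYPE_Q8_0 && (V == GGML_TYPE_Q8_0 || V == GGML_TYPE_Q5_0 ||
                                    V == GGML_TYPE_Q4_0))                        ||
           (K == GGML_TYPE_Q5_1 && (V == GGML_TYPE_Q5_1 || V == GGML_TYPE_Q5_0 ||
                                    V == GGML_TYPE_Q4_1 || V == GGML_TYPE_Q4_0)) ||
           (K == GGML_TYPE_Q5_0 &&  V == GGML_TYPE_Q4_0)                         ||
           (K == GGML_TYPE_Q4_0 &&  V == GGML_TYPE_Q4_0);
}
```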

Note: currently there is a bug on master that causes incorrect results when quantizing the K cache, see #7492 . A workaround is to disable CUDA graphs via the environment variable GGML_CUDA_DISABLE_GRAPHS=1 but obviously that is going to result in worse performance.

Performance stays mostly the same with a quantized KV cache; quantizing K makes the performance slightly better, quantizing V makes the performance slightly worse.

I was unfortunately not able to recycle much of the existing code. The vector FlashAttention kernels are run with the same number of threads as the head size, so the existing dequantization code, which produces 2 values at a time, is not a good fit for V. Similarly for K, the existing q4/q5 dot products are also a poor fit because they work on chunks of 8*WARP_SIZE == 256 values in parallel while the head size of e.g. LLaMA 3 is only 128.
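
To make the mismatch concrete, here is a minimal sketch (assumed names, not this PR's actual code) of the kind of per-thread dequantization the vector kernels need for V: each of the head-size-many threads fetches exactly one value from a quantized block instead of two at a time. The q4_0 block layout shown matches ggml's format.

```cpp
#include <cuda_fp16.h>
#include <cstdint>

#define QK4_0 32

// ggml q4_0 block: one fp16 scale plus 32 packed 4-bit quants.
struct block_q4_0_sketch {
    half    d;
    uint8_t qs[QK4_0/2];
};

// Hypothetical helper: dequantize only the i-th value (0 <= i < 32) of one block,
// so that a single thread can load a single V element.
static __device__ __forceinline__ float dequantize_1_q4_0_sketch(const block_q4_0_sketch * b, const int i) {
    const float d = __half2float(b->d);
    const int   q = i < QK4_0/2 ? (b->qs[i] & 0x0F) : (b->qs[i - QK4_0/2] >> 4);
    return d*(q - 8); // q4_0 stores unsigned nibbles with an offset of 8
}
```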

@github-actions github-actions bot added the Nvidia GPU (Issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) labels May 24, 2024
Contributor

github-actions bot commented May 24, 2024

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 568 iterations 🚀

Expand details for performance related PR only
  • Concurrent users: 8, duration: 10m
  • HTTP request : avg=8206.57ms p(95)=18904.6ms fails=, finish reason: stop=518 truncated=50
  • Prompt processing (pp): avg=94.04tk/s p(95)=414.2tk/s
  • Token generation (tg): avg=70.85tk/s p(95)=51.78tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=cuda-fattn-vec-quant-3 commit=cc7aef6829b1d3f7889abc05290bd557d394d0fd

[Benchmark charts omitted: time series of llamacpp:prompt_tokens_seconds, llamacpp:predicted_tokens_seconds, llamacpp:kv_cache_usage_ratio, and llamacpp:requests_processing for the bench-server-baseline run on Standard_NC4as_T4_v3 (duration=10m, 568 iterations).]

@Nexesenex
Contributor

Considering the results you obtained, I'd say KV q8_0/q5_1 at 7.25 BPW is a must: it fills the big gap between 7.00 BPW and 8.50 BPW, and most importantly it loses noticeably less quality relative to q8_0/q8_0 than q8_0/q5_0 does, for a difference of only 0.25 BPW.

q8_0/q5_1 is the compromise I'd likely use daily, and I think it should be part of the "short list"!

@mofosyne mofosyne added the Review Complexity : High (generally requires in-depth knowledge of LLMs or GPUs) label May 25, 2024
@JohannesGaessler
Collaborator Author

I have a prototype for restricting the compilation only to those types listed in the OP. On my desktop the compile time with LLAMA_NO_CCACHE still increases from ~1 minute to ~2 minutes.

@ggerganov @slaren can you share some insights regarding what maximum compile time you would still find acceptable and whether you think a full or a reduced build should be the default? Another approach would be to split the compilation across more files which would help on e.g. user desktop machines but maybe not in CI?

@ggerganov
Owner

My opinion is that we should probably keep just 3 options:

  • K=f16, V=f16, 16.0 BPV
  • K=q8_0, V=q8_0, 8.50 BPV
  • whichever one is fastest from the sub-6 BPV combinations

The compile time even on master is already quite slow for my taste, so definitely we should try to improve this somehow

@github-actions github-actions bot added the build (Compilation issues) label May 25, 2024
@JohannesGaessler JohannesGaessler marked this pull request as ready for review May 25, 2024 19:16
@JohannesGaessler
Collaborator Author

The compile time even on master is already quite slow for my taste, so definitely we should try to improve this somehow

Which scenario are you talking about? Multi-threaded or single-threaded?

@JohannesGaessler
Collaborator Author

The type combination I chose for "the fastest" is q4_0/q4_0; there are still performance issues with the FP32 -> q4_0 conversion code though so end-to-end performance is still suboptimal.

@ggerganov
Owner

The default multi-threaded build (i.e. using the subset of 3 KV quants) takes the same amount of time as on master which is good:

time LLAMA_NO_CCACHE=1 LLAMA_CUDA=1 make -j main

real	0m35,972s
user	3m57,118s
sys	0m28,550s

This is on an AMD Ryzen 9 5950X 16-Core Processor.

The build using all quants is indeed quite slow, but since it is optional I think it is acceptable:

time LLAMA_NO_CCACHE=1 LLAMA_CUDA=1 LLAMA_CUDA_FA_ALL_QUANTS=1 make -j main

real	6m18,667s
user	11m51,335s
sys	0m32,225s

@JohannesGaessler
Collaborator Author

If you're concerned about compilation time for developers in particular it would definitely be possible to speed that up. A lot of the code is not strictly needed in terms of correctness but only exists for better performance, e.g. in mmvq.cu there are kernel versions for all batch sizes from 1-8 even though you could in theory use batch size 1 for all of those cases (but with virtually no performance gain). For the FlashAttention kernels it would in principle also be sufficient to compile only versions for batch size 1 and batch sizes >> 1 at the cost of worse performance for intermediate batch sizes.
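
As a rough sketch of that trade-off (hypothetical names and flag, not code from this PR): a fast-compile build could route every small batch size through the batch-size-1 instantiation instead of compiling one kernel per column count.

```cpp
#include <cuda_runtime.h>

// Stand-in for a batch-size-templated kernel like those in mmvq.cu.
template <int ncols>
__global__ void mat_vec_sketch(const float * x, float * dst, const int nrows) {
    const int row = blockIdx.x*blockDim.x + threadIdx.x;
    if (row >= nrows) {
        return;
    }
    for (int c = 0; c < ncols; ++c) {
        dst[c*nrows + row] = x[c*nrows + row]; // the real kernel would do a quantized dot product here
    }
}

void launch_mat_vec_sketch(const int ncols, const float * x, float * dst, const int nrows, cudaStream_t stream) {
    const dim3 grid((nrows + 255)/256);
#ifdef FAST_COMPILE_SKETCH // hypothetical flag, not part of this PR
    // Only the ncols == 1 variant is compiled; loop it per column. Correct for any ncols, just slower.
    for (int c = 0; c < ncols; ++c) {
        mat_vec_sketch<1><<<grid, 256, 0, stream>>>(x + c*nrows, dst + c*nrows, nrows);
    }
#else
    switch (ncols) {
        case 1: mat_vec_sketch<1><<<grid, 256, 0, stream>>>(x, dst, nrows); break;
        case 2: mat_vec_sketch<2><<<grid, 256, 0, stream>>>(x, dst, nrows); break;
        // ... one case per batch size up to 8, as in the full build
        default: break;
    }
#endif
}
```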

There is a relatively recent NVCC option called --split-compile that is supposed to parallelize CUDA compilation but on my machine the difference is only ~10%; manually splitting the files definitely works much better.

Using -O1 instead of -O3 reduces compilation time by ~3% on my machine (this only affects host code).

I noticed that --device-debug greatly speeds up the compilation since it disables device code optimization; I made a PR for adding it to debug builds: #7542 .

@ggerganov
Owner

If you're concerned about compilation time for developers in particular it would definitely be possible to speed that up.

Yes, I think any reasonable reorganization of the build that reduces the compile time improves the developer experience. I build the CUDA backend to run some tests from time to time; my main development environment is macOS, so it does not affect me a lot. But I can see it becoming a problem if the build time keeps increasing, so we have to be mindful and look for opportunities to improve. The multi-file reorganization of the CUDA code is already a big win and it's paying off a lot.

@slaren
Collaborator

slaren commented May 26, 2024

These changes would only be useful when working on the affected files, e.g. when working on the mmq kernels. But we could not use them when testing the performance of the CUDA backend, and they also wouldn't help when working on other parts of the code since these files do not need to be recompiled.

@slaren
Collaborator

slaren commented May 26, 2024

I think the build time with only the 3 options is already very good. Lately I am far more concerned about the build time of llama.cpp (the file): a full CUDA build (this PR) takes 26 seconds for me, while building llama.cpp alone takes 13 seconds. However, the solution for that is easy; there is a lot of code in there that could be split into multiple files.

@JohannesGaessler
Collaborator Author

I was also thinking about the compilation time in terms of the CI, though; I feel like it would be useful if the latency between opening a PR and getting feedback about a failed run were lower. And for that, parallelization probably won't do much since (I assume) the CI is working with very limited resources.

@ggerganov
Owner

@slaren Yes, we should do that. And also for the ggml sources eventually

@JohannesGaessler ggml-ci is quite responsive and covers a lot of the CI, but it only monitors branches in this repository

@JohannesGaessler
Collaborator Author

These changes would only be useful when working on the affected files, e.g. when working on the mmq kernels.

Strictly speaking they would also be useful when working on e.g. common.cuh but I agree that when using ccache the compilation time is in most cases unproblematic. The scenario which affects me the most is when I compile code on one of my secondary machines where there are more cache misses because I rarely use them for development (but by definition that also means that it's an infrequent problem for me).

@slaren
Collaborator

slaren commented May 26, 2024

We should have tests for this in test-backend-ops, i.e. add type_k and type_v parameters to test_flash_attn_ext, add a quick test for all type combinations, and full tests only for f16 and maybe one or two quantized type combinations.
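
A rough sketch of that registration (a fragment meant to live inside test-backend-ops; the test_flash_attn_ext constructor arguments here are assumed for illustration and may not match the actual test):

```cpp
// Hypothetical: quick correctness coverage of every K/V type combination,
// with full tests reserved for f16 and one or two quantized pairs.
const std::vector<ggml_type> kv_types = {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0};
for (ggml_type type_K : kv_types) {
    for (ggml_type type_V : kv_types) {
        // assumed arguments: head size, n_heads, KV length, batch size, type_k, type_v
        test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 1, type_K, type_V));
    }
}
```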

@slaren
Collaborator

slaren commented May 26, 2024

Performance stays mostly the same with a quantized KV cache; quantizing K makes the performance slightly better, quantizing V makes the performance slightly worse.

This is not really what I am seeing; with q8_0 I see a 15-20% drop in performance. Is this the expected performance?

With GGML_CUDA_DISABLE_GRAPHS=1:

| model | size | params | backend | ngl | type_k | type_v | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | CUDA | 99 | f16 | f16 | 1 | pp512 | 4946.49 ± 7.31 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | CUDA | 99 | f16 | f16 | 1 | tg128 | 138.23 ± 0.12 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | CUDA | 99 | f16 | q8_0 | 1 | pp512 | 3757.84 ± 5.52 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | CUDA | 99 | f16 | q8_0 | 1 | tg128 | 115.21 ± 0.42 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | CUDA | 99 | q8_0 | f16 | 1 | pp512 | 3814.50 ± 13.03 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | CUDA | 99 | q8_0 | f16 | 1 | tg128 | 119.53 ± 0.22 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | CUDA | 99 | q8_0 | q8_0 | 1 | pp512 | 3602.09 ± 15.58 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | CUDA | 99 | q8_0 | q8_0 | 1 | tg128 | 112.82 ± 0.19 |

build: a5436a0 (2981)

@JohannesGaessler
Collaborator Author

When I refactored the code I disabled the compilation of most kernel variants, including the one for batch size 1 since I was only checking for correctness. But I accidentally committed this change so the tg performance was bad. This is the performance you should be seeing:

| model | backend | ngl | type_k | type_v | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| llama 7B Q4_0 | CUDA | 99 | f16 | f16 | 1 | pp512 | 4386.01 ± 20.19 |
| llama 7B Q4_0 | CUDA | 99 | f16 | f16 | 1 | tg128 | 145.10 ± 0.36 |
| llama 7B Q4_0 | CUDA | 99 | f16 | f16 | 1 | tg1024 | 139.17 ± 0.35 |
| llama 7B Q4_0 | CUDA | 99 | f16 | f16 | 1 | tg4096 | 121.29 ± 0.71 |
| llama 7B Q4_0 | CUDA | 99 | q8_0 | q8_0 | 1 | pp512 | 3177.33 ± 11.17 |
| llama 7B Q4_0 | CUDA | 99 | q8_0 | q8_0 | 1 | tg128 | 141.48 ± 0.67 |
| llama 7B Q4_0 | CUDA | 99 | q8_0 | q8_0 | 1 | tg1024 | 135.32 ± 0.65 |
| llama 7B Q4_0 | CUDA | 99 | q8_0 | q8_0 | 1 | tg4096 | 120.35 ± 0.05 |
| llama 7B Q4_0 | CUDA | 99 | q4_0 | q4_0 | 1 | pp512 | 3128.80 ± 18.66 |
| llama 7B Q4_0 | CUDA | 99 | q4_0 | q4_0 | 1 | tg128 | 141.26 ± 0.25 |
| llama 7B Q4_0 | CUDA | 99 | q4_0 | q4_0 | 1 | tg1024 | 135.96 ± 0.47 |
| llama 7B Q4_0 | CUDA | 99 | q4_0 | q4_0 | 1 | tg4096 | 119.91 ± 0.23 |

The pp performance is expected to be worse because there are no efficient kernels for batch sizes >> 1.
(When you're testing, remember to set GGML_CUDA_DISABLE_GRAPHS=1 since quantizing the KV cache is currently broken without it.)

@slaren
Collaborator

slaren commented May 26, 2024

Yes, that's close to what I see now.

It would be fairly easy to parallelize the compilation of the different template instantiations across multiple source files by using extern templates. Essentially you would have to remove the static from the templates, declare each instantiation as extern, and then instantiate each instance manually in different source files. I can make an example of how to do this if you are interested, but I am not sure which templates are the most expensive and would benefit the most from this.
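
A minimal sketch of that approach, with illustrative names rather than the exact templates in this PR:

```cpp
// fattn-vec-f16.cuh — template declaration (the definition also lives in this header,
// omitted here), no longer marked static:
template <int D, ggml_type type_K, ggml_type type_V>
void flash_attn_vec_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

// Extern declarations suppress implicit instantiation in every TU that includes the header:
extern template void flash_attn_vec_f16_case<128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0>(ggml_backend_cuda_context &, ggml_tensor *);
extern template void flash_attn_vec_f16_case<128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0>(ggml_backend_cuda_context &, ggml_tensor *);

// fattn-vec-f16-q8_0-q8_0-hs128.cu — one explicit instantiation per .cu file, so the
// variants compile in parallel. The header with the template definition is included
// here so the instantiation has a body to emit:
//   #include "fattn-vec-f16.cuh"
template void flash_attn_vec_f16_case<128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0>(ggml_backend_cuda_context &, ggml_tensor *);
```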

@slaren
Collaborator

slaren commented May 26, 2024

Let's add this to disable CUDA graphs with quantized KV for now:

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index b82167cb..cd8b2e31 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -2585,6 +2585,14 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
 #endif
             }

+            if (node->op == GGML_OP_FLASH_ATTN_EXT && (ggml_is_quantized(node->src[1]->type) || ggml_is_quantized(node->src[2]->type))) {
+                // disable CUDA graphs for quantized FLASH_ATTN_EXT for now
+                use_cuda_graph = false;
+#ifndef NDEBUG
+                GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to quantized FLASH_ATTN_EXT\n", __func__);
+#endif
+            }
+
             if (node->op == GGML_OP_CPY) {
                 // store the copy op parameter which changes with each token.
                 cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));

@github-actions github-actions bot added the testing (Everything test related) label May 26, 2024
@JohannesGaessler
Collaborator Author

I can't manage to wrangle the compiler into doing what I want; I think I'm doing something subtly wrong. @slaren can you post the discussed example? The way I would split the template variants across files is by K type, V type, and head size, since you can treat each tuple of those three values in the same way.

@slaren
Collaborator

slaren commented May 27, 2024

Here it is: 1ca802a

Probably the first macro is unnecessary (DECL_FATTN_VEC_F16_INST), but I wasn't sure.

@JohannesGaessler
Collaborator Author

Thank you very much for the example, I pushed a prototype for what I have in mind. Specifically:

  • Change the CUDA kernel template parameters to 2x ggml_type for K and V.
  • Define __device__ constexpr functions that convert ggml_type to the corresponding type-specific __device__ functions (see the sketch after this list).
  • Create a host template with head size, K type, and V type that internally launches the correct CUDA kernel depending on Q shape.
  • Create a directory ggml-cuda/template-instances for all of the mostly non-functional .cu files, write a simple Python script to autogenerate them. I plan to re-use this directory for future kernels.
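
A minimal sketch of the second point above, the constexpr ggml_type → device-function mapping (simplified stub bodies, not the exact code pushed here):

```cpp
#include "ggml.h"

typedef float (*dequantize_1_f32_t)(const void * vx, const int i);

// Placeholder bodies standing in for the real per-type dequantization functions.
static __device__ float dequantize_1_f16_stub (const void * vx, const int i) { (void) vx; (void) i; return 0.0f; }
static __device__ float dequantize_1_q8_0_stub(const void * vx, const int i) { (void) vx; (void) i; return 0.0f; }

// Maps the V cache's ggml_type to the matching per-value dequantization function at
// compile time, so the kernel template only needs ggml_type parameters for K and V.
template <ggml_type type_V>
constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32() {
    return type_V == GGML_TYPE_F16  ? dequantize_1_f16_stub  :
           type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0_stub :
           nullptr;
}

// Usage inside a kernel: constexpr dequantize_1_f32_t dequantize_1_v = get_dequantize_1_f32<type_V>();
```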

Thoughts?

@slaren
Collaborator

slaren commented May 28, 2024

Looks good to me. I compared the compilation time with the previous commit:

Before:

ggml-cuda/mmvq.cu 14.54 s
ggml-cuda/fattn-vec-f32.cu 23.81 s
ggml-cuda/fattn.cu 23.98 s
ggml-cuda/fattn-vec-f16.cu 25.49 s

After:

ggml-cuda/fattn-vec-f16-variant-q8_0-q8_0-hs128.cu 10.93 s
ggml-cuda/fattn-vec-f16-variant-q4_0-q4_0-hs128.cu 11.30 s
ggml-cuda/fattn-vec-f16-variant-f16-f16-hs128.cu 13.04 s
ggml-cuda/mmvq.cu 15.36 s
ggml-cuda/fattn-vec-f16.cu 15.56 s
ggml-cuda/fattn.cu 24.36 s

So it seems quite effective, although compilation with a single core will be significantly slower.

@JohannesGaessler
Collaborator Author

How are you getting the compilation times?

@slaren
Collaborator

slaren commented May 28, 2024

I modified the Makefile as such:

diff --git a/Makefile b/Makefile
index 278a06d2..db64b76b 100644
--- a/Makefile
+++ b/Makefile
@@ -504,7 +504,7 @@ define NVCC_COMPILE
 endef # NVCC_COMPILE
 else
 define NVCC_COMPILE
-       $(NVCC) $(NVCCFLAGS) $(CPPFLAGS) -Xcompiler "$(CUDA_CXXFLAGS)" -c $< -o $@
+       /bin/time -f "$< %e s" $(NVCC) $(NVCCFLAGS) $(CPPFLAGS) -Xcompiler "$(CUDA_CXXFLAGS)" -c $< -o $@
 endef # NVCC_COMPILE
 endif # JETSON_EOL_MODULE_DETECT

@JohannesGaessler
Collaborator Author

JohannesGaessler commented May 29, 2024

I pushed an implementation for split compilation where each individual FlashAttention compilation job takes roughly as long as mmvq.cu. I moved the device code in fattn.cu to fattn-wmma-f16.cuh and also split the compilation into multiple jobs using extern templates. On my main development machine with a Ryzen 5950X power-limited to 95 W, the compilation time for commands like

make clean && time /usr/bin/make -j 32 main LLAMA_CUDA=1 LLAMA_NO_CCACHE=1

is as follows:

| -j | b3012 [s] | Concentrated templates [s] | extern templates [s] |
| --- | --- | --- | --- |
| 1 | 177.66 | 192.91 | 255.90 |
| 2 | 86.74 | 108.96 | 131.29 |
| 4 | 48.01 | 55.08 | 73.75 |
| 8 | 36.54 | 38.71 | 43.34 |
| 16 | 33.49 | 35.06 | 30.33 |
| 32 | 33.93 | 35.53 | 29.47 |

b3012 is the master commit that this PR is based on. I think timing individual compilation jobs in a setting with >> 1 jobs overestimates how long the compilation would take with -j 1 because the scaling is nonlinear and the time per individual job is higher (I think due to resource limitations that are shared between jobs). Still, for low numbers of parallel jobs split compilation seems to be detrimental. For me personally however this is not an issue since the machines I use for development all have at least 16 cores.

I didn't test it in the same detail but the compilation time for LLAMA_CUDA_FA_ALL_QUANTS with -j 32 decreases from ~7 minutes to 107s.

@github-actions github-actions bot added the python (python script changes) label May 29, 2024
@slaren
Collaborator

slaren commented May 29, 2024

Shouldn't the generated files be added to the repository? The python script is also missing the executable permission.

@JohannesGaessler
Collaborator Author

I was thinking it would be better to add the autogenerated files towards the end of the review process in case there are requests for changes.

@JohannesGaessler
Collaborator Author

The CI says: Error: The operation was canceled.
Does that just mean the compilation took too long?

@slaren
Collaborator

slaren commented May 30, 2024

I don't think it's a timeout, and it was cancelled before the compilation actually started, so I think it is unlikely to be caused by these changes. Let's see if it happens again...

@slaren
Collaborator

slaren commented May 31, 2024

We should add a check in llama_new_context_with_model to avoid using a quantized V cache without flash attention; currently it crashes in random locations (or doesn't, which is probably worse).
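
Something along these lines, for illustration (hypothetical placement, not the check that eventually landed):

```cpp
// Inside llama_new_context_with_model(): reject a quantized V cache unless flash attention is on.
if (!cparams.flash_attn && ggml_is_quantized(params.type_v)) {
    LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
    return nullptr;
}
```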

@JohannesGaessler JohannesGaessler merged commit 9b59641 into ggerganov:master Jun 1, 2024
63 of 70 checks passed
@slaren
Collaborator

slaren commented Jun 1, 2024

test-backend-ops is failing on debug builds after this change:

test-backend-ops: ggml.c:3078: ggml_row_size: Assertion `ne % ggml_blck_size(type) == 0' failed.

@RachidAR

RachidAR commented Jun 1, 2024

After 9b59641, flash attention doesn't work with phi-3 medium.

./main -ngl 99 -t 14 -c 4096 -m ./Phi-3-medium-4k-instruct-q4.gguf -fa -p "Once upon on time "

Log start
main: build = 3062 (9b596417)
main: built with cc (GCC) 14.1.1 20240522 for x86_64-pc-linux-gnu
main: seed  = 1717261994
llama_model_loader: loaded meta data with 26 key-value pairs and 243 tensors from ./Phi-3-medium-4k-instruct-q4.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = phi3
llama_model_loader: - kv   1:                               general.name str              = Phi3
llama_model_loader: - kv   2:                        phi3.context_length u32              = 4096
llama_model_loader: - kv   3:  phi3.rope.scaling.original_context_length u32              = 4096
llama_model_loader: - kv   4:                      phi3.embedding_length u32              = 5120
llama_model_loader: - kv   5:                   phi3.feed_forward_length u32              = 17920
llama_model_loader: - kv   6:                           phi3.block_count u32              = 40
llama_model_loader: - kv   7:                  phi3.attention.head_count u32              = 40
llama_model_loader: - kv   8:               phi3.attention.head_count_kv u32              = 10
llama_model_loader: - kv   9:      phi3.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  10:                  phi3.rope.dimension_count u32              = 128
llama_model_loader: - kv  11:                        phi3.rope.freq_base f32              = 10000.000000
llama_model_loader: - kv  12:                          general.file_type u32              = 15
llama_model_loader: - kv  13:                       tokenizer.ggml.model str              = llama
llama_model_loader: - kv  14:                         tokenizer.ggml.pre str              = default
llama_model_loader: - kv  15:                      tokenizer.ggml.tokens arr[str,32064]   = ["<unk>", "<s>", "</s>", "<0x00>", "<...
llama_model_loader: - kv  16:                      tokenizer.ggml.scores arr[f32,32064]   = [-1000.000000, -1000.000000, -1000.00...
llama_model_loader: - kv  17:                  tokenizer.ggml.token_type arr[i32,32064]   = [3, 3, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...
llama_model_loader: - kv  18:                tokenizer.ggml.bos_token_id u32              = 1
llama_model_loader: - kv  19:                tokenizer.ggml.eos_token_id u32              = 32000
llama_model_loader: - kv  20:            tokenizer.ggml.unknown_token_id u32              = 0
llama_model_loader: - kv  21:            tokenizer.ggml.padding_token_id u32              = 32000
llama_model_loader: - kv  22:               tokenizer.ggml.add_bos_token bool             = false
llama_model_loader: - kv  23:               tokenizer.ggml.add_eos_token bool             = false
llama_model_loader: - kv  24:                    tokenizer.chat_template str              = {% for message in messages %}{% if (m...
llama_model_loader: - kv  25:               general.quantization_version u32              = 2
llama_model_loader: - type  f32:   81 tensors
llama_model_loader: - type q4_K:  101 tensors
llama_model_loader: - type q5_K:   40 tensors
llama_model_loader: - type q6_K:   21 tensors
llm_load_vocab: special tokens cache size = 323
llm_load_vocab: token to piece cache size = 0.3372 MB
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = phi3
llm_load_print_meta: vocab type       = SPM
llm_load_print_meta: n_vocab          = 32064
llm_load_print_meta: n_merges         = 0
llm_load_print_meta: n_ctx_train      = 4096
llm_load_print_meta: n_embd           = 5120
llm_load_print_meta: n_head           = 40
llm_load_print_meta: n_head_kv        = 10
llm_load_print_meta: n_layer          = 40
llm_load_print_meta: n_rot            = 128
llm_load_print_meta: n_embd_head_k    = 128
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 4
llm_load_print_meta: n_embd_k_gqa     = 1280
llm_load_print_meta: n_embd_v_gqa     = 1280
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-05
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 0.0e+00
llm_load_print_meta: n_ff             = 17920
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = 2
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 10000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_yarn_orig_ctx  = 4096
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 0
llm_load_print_meta: ssm_d_inner      = 0
llm_load_print_meta: ssm_d_state      = 0
llm_load_print_meta: ssm_dt_rank      = 0
llm_load_print_meta: model type       = 14B
llm_load_print_meta: model ftype      = Q4_K - Medium
llm_load_print_meta: model params     = 13.96 B
llm_load_print_meta: model size       = 7.98 GiB (4.91 BPW) 
llm_load_print_meta: general.name     = Phi3
llm_load_print_meta: BOS token        = 1 '<s>'
llm_load_print_meta: EOS token        = 32000 '<|endoftext|>'
llm_load_print_meta: UNK token        = 0 '<unk>'
llm_load_print_meta: PAD token        = 32000 '<|endoftext|>'
llm_load_print_meta: LF token         = 13 '<0x0A>'
llm_load_print_meta: EOT token        = 32007 '<|end|>'
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:   no
ggml_cuda_init: CUDA_USE_TENSOR_CORES: yes
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 3060, compute capability 8.6, VMM: yes
llm_load_tensors: ggml ctx size =    0.28 MiB
llm_load_tensors: offloading 40 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 41/41 layers to GPU
llm_load_tensors:        CPU buffer size =    88.07 MiB
llm_load_tensors:      CUDA0 buffer size =  8081.18 MiB
...............................................................................................
llama_new_context_with_model: n_ctx      = 4096
llama_new_context_with_model: n_batch    = 2048
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 1
llama_new_context_with_model: freq_base  = 10000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:      CUDA0 KV buffer size =   800.00 MiB
llama_new_context_with_model: KV self size  =  800.00 MiB, K (f16):  400.00 MiB, V (f16):  400.00 MiB
llama_new_context_with_model:  CUDA_Host  output buffer size =     0.12 MiB
llama_new_context_with_model:      CUDA0 compute buffer size =   182.00 MiB
llama_new_context_with_model:  CUDA_Host compute buffer size =    18.01 MiB
llama_new_context_with_model: graph nodes  = 1447
llama_new_context_with_model: graph splits = 2

system_info: n_threads = 14 / 28 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | AVX512_BF16 = 0 | FMA = 1 | NEON = 0 | SVE = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 | 
sampling: 
	repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
	top_k = 40, tfs_z = 1.000, top_p = 0.950, min_p = 0.050, typical_p = 1.000, temp = 0.800
	mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampling order: 
CFG -> Penalties -> top_k -> tfs_z -> typical_p -> top_p -> min_p -> temperature 
generate: n_ctx = 4096, n_batch = 2048, n_predict = -1, n_keep = 0


 Once upon on time 
GGML_ASSERT: ggml-cuda/template-instances/../fattn-vec-f32.cuh:285: precision == GGML_PREC_DEFAULT
ptrace: Operation not permitted.
No stack.
The program is not being run.
