
CUDA: faster FlashAttention for batch sizes > 1 #6646

Merged

Conversation

JohannesGaessler
Collaborator

This PR does the following:

  • Generalize the code for running multiple FlashAttention blocks in parallel to work for batch sizes > 1. This helps with performance for small batch sizes and small numbers of attention heads.
  • Extend the CUDA device info to also include the number of streaming multiprocessors.
  • Refactor the FlashAttention host code: instead of macros it now uses templates. Instead of fixed values for parallel_blocks, the value is chosen based on the number of streaming multiprocessors; as long as the increased number of blocks still finishes in a single wave, the extra parallelism is essentially always worthwhile (see the sketch below).
  • Add two new macros FP16_AVAILABLE and FP16_MMA_AVAILABLE that can be used to determine the availability of general FP16 intrinsics and FP16 tensor cores in device code.
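
A rough sketch of the single-wave heuristic and the new capability macros (illustrative only; the architecture thresholds and the helper function below are placeholders, not the literal PR code):

```cpp
// Conceptual guards for device code: which FP16 features the target architecture offers.
// The exact compute-capability thresholds used by the PR may differ; these are placeholders.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
#define FP16_AVAILABLE      // general FP16 intrinsics usable
#endif
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
#define FP16_MMA_AVAILABLE  // FP16 tensor core (MMA) instructions usable
#endif

// Host-side heuristic: keep doubling parallel_blocks while the resulting grid
// still fits into a single wave across the streaming multiprocessors.
static int choose_parallel_blocks(int blocks_per_unit, int n_sm, int max_parallel) {
    int parallel_blocks = 1;
    while (2*parallel_blocks <= max_parallel &&
           2*parallel_blocks*blocks_per_unit <= n_sm) {
        parallel_blocks *= 2; // still a single wave, so the extra blocks are ~free
    }
    return parallel_blocks;
}
```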

Performance relative to current FlashAttention kernels:

GPU Model Batch size Test t/s gg/flash-attn t/s jg/flash-attn-20 Speedup
RTX 4090 gemma 2B F16 1 pp4096 156.93 157.16 1.00
RTX 4090 gemma 2B F16 2 pp4096 245.72 275.30 1.12
RTX 4090 gemma 2B F16 4 pp4096 501.66 562.81 1.12
RTX 4090 gemma 2B F16 8 pp4096 994.84 1111.80 1.12
RTX 4090 gemma 2B F16 16 pp4096 1963.72 2180.48 1.11
RTX 4090 gemma 2B F16 32 pp4096 3763.81 4252.16 1.13
RTX 4090 gemma 2B F16 64 pp4096 6609.81 7601.76 1.15
RTX 4090 gemma 2B F16 128 pp4096 12166.87 13375.50 1.10
RTX 4090 gemma 2B F16 256 pp4096 20393.39 20825.66 1.02
RTX 4090 gemma 2B F16 512 pp4096 28404.84 28339.54 1.00
RTX 4090 gemma 2B F16 1024 pp4096 29462.16 29472.00 1.00
RTX 4090 gemma 2B F16 2048 pp4096 30085.55 30055.65 1.00
RTX 4090 gemma 2B F16 4096 pp4096 30343.08 30317.35 1.00
RTX 4090 llama 7B Q4_0 1 pp4096 146.68 146.38 1.00
RTX 4090 llama 7B Q4_0 2 pp4096 276.31 279.32 1.01
RTX 4090 llama 7B Q4_0 4 pp4096 547.74 551.97 1.01
RTX 4090 llama 7B Q4_0 8 pp4096 929.35 931.75 1.00
RTX 4090 llama 7B Q4_0 16 pp4096 1317.17 1310.32 0.99
RTX 4090 llama 7B Q4_0 32 pp4096 1600.31 1606.50 1.00
RTX 4090 llama 7B Q4_0 64 pp4096 2119.35 2168.12 1.02
RTX 4090 llama 7B Q4_0 128 pp4096 3705.01 3710.46 1.00
RTX 4090 llama 7B Q4_0 256 pp4096 6069.55 6057.19 1.00
RTX 4090 llama 7B Q4_0 512 pp4096 7829.50 7821.58 1.00
RTX 4090 llama 7B Q4_0 1024 pp4096 7843.75 7836.71 1.00
RTX 4090 llama 7B Q4_0 2048 pp4096 7864.92 7846.92 1.00
RTX 4090 llama 7B Q4_0 4096 pp4096 7854.40 7847.28 1.00
RTX 4090 phi2 3B F16 1 pp4096 123.51 123.77 1.00
RTX 4090 phi2 3B F16 2 pp4096 203.90 218.77 1.07
RTX 4090 phi2 3B F16 4 pp4096 405.53 434.01 1.07
RTX 4090 phi2 3B F16 8 pp4096 803.90 852.98 1.06
RTX 4090 phi2 3B F16 16 pp4096 1571.15 1673.18 1.06
RTX 4090 phi2 3B F16 32 pp4096 3085.40 3206.54 1.04
RTX 4090 phi2 3B F16 64 pp4096 5461.86 5863.08 1.07
RTX 4090 phi2 3B F16 128 pp4096 9513.56 9499.56 1.00
RTX 4090 phi2 3B F16 256 pp4096 15202.10 15237.75 1.00
RTX 4090 phi2 3B F16 512 pp4096 17564.81 17577.25 1.00
RTX 4090 phi2 3B F16 1024 pp4096 17506.02 17510.25 1.00
RTX 4090 phi2 3B F16 2048 pp4096 17565.62 17571.12 1.00
RTX 4090 phi2 3B F16 4096 pp4096 17604.26 17618.98 1.00
RTX 3090 gemma 2B F16 1 pp4096 135.58 135.96 1.00
RTX 3090 gemma 2B F16 2 pp4096 199.72 229.19 1.15
RTX 3090 gemma 2B F16 4 pp4096 401.66 462.41 1.15
RTX 3090 gemma 2B F16 8 pp4096 795.51 912.27 1.15
RTX 3090 gemma 2B F16 16 pp4096 1553.86 1768.38 1.14
RTX 3090 gemma 2B F16 32 pp4096 2909.64 3246.80 1.12
RTX 3090 gemma 2B F16 64 pp4096 5255.95 5941.40 1.13
RTX 3090 gemma 2B F16 128 pp4096 9130.67 10118.23 1.11
RTX 3090 gemma 2B F16 256 pp4096 11761.76 11962.81 1.02
RTX 3090 gemma 2B F16 512 pp4096 13284.64 13149.46 0.99
RTX 3090 gemma 2B F16 1024 pp4096 13546.97 13429.50 0.99
RTX 3090 gemma 2B F16 2048 pp4096 13625.12 13651.72 1.00
RTX 3090 gemma 2B F16 4096 pp4096 13673.57 13569.27 0.99
RTX 3090 llama 7B Q4_0 1 pp4096 125.30 126.03 1.01
RTX 3090 llama 7B Q4_0 2 pp4096 225.69 240.07 1.06
RTX 3090 llama 7B Q4_0 4 pp4096 383.69 406.05 1.06
RTX 3090 llama 7B Q4_0 8 pp4096 533.13 542.02 1.02
RTX 3090 llama 7B Q4_0 16 pp4096 569.25 571.30 1.00
RTX 3090 llama 7B Q4_0 32 pp4096 629.32 634.80 1.01
RTX 3090 llama 7B Q4_0 64 pp4096 1320.76 1348.39 1.02
RTX 3090 llama 7B Q4_0 128 pp4096 2236.75 2244.31 1.00
RTX 3090 llama 7B Q4_0 256 pp4096 3149.86 3168.56 1.01
RTX 3090 llama 7B Q4_0 512 pp4096 3706.48 3737.65 1.01
RTX 3090 llama 7B Q4_0 1024 pp4096 3718.74 3745.87 1.01
RTX 3090 llama 7B Q4_0 2048 pp4096 3730.79 3746.94 1.00
RTX 3090 llama 7B Q4_0 4096 pp4096 3735.57 3749.07 1.00
RTX 3090 phi2 3B F16 1 pp4096 106.93 106.64 1.00
RTX 3090 phi2 3B F16 2 pp4096 170.42 185.58 1.09
RTX 3090 phi2 3B F16 4 pp4096 337.30 369.12 1.09
RTX 3090 phi2 3B F16 8 pp4096 666.42 726.71 1.09
RTX 3090 phi2 3B F16 16 pp4096 1293.18 1339.44 1.04
RTX 3090 phi2 3B F16 32 pp4096 2434.70 2535.57 1.04
RTX 3090 phi2 3B F16 64 pp4096 4246.23 4539.63 1.07
RTX 3090 phi2 3B F16 128 pp4096 6835.86 6897.72 1.01
RTX 3090 phi2 3B F16 256 pp4096 7753.84 7776.34 1.00
RTX 3090 phi2 3B F16 512 pp4096 8814.01 8762.43 0.99
RTX 3090 phi2 3B F16 1024 pp4096 8585.54 8605.73 1.00
RTX 3090 phi2 3B F16 2048 pp4096 8668.73 8666.90 1.00
RTX 3090 phi2 3B F16 4096 pp4096 8557.30 8580.58 1.00

Performance relative to master:

GPU Model Batch size Test t/s master t/s jg/flash-attn-20 Speedup
RTX 4090 gemma 2B F16 1 pp4096 148.07 157.16 1.06
RTX 4090 gemma 2B F16 2 pp4096 271.97 275.30 1.01
RTX 4090 gemma 2B F16 4 pp4096 557.04 562.81 1.01
RTX 4090 gemma 2B F16 8 pp4096 1101.79 1111.80 1.01
RTX 4090 gemma 2B F16 16 pp4096 2155.89 2180.48 1.01
RTX 4090 gemma 2B F16 32 pp4096 4169.48 4252.16 1.02
RTX 4090 gemma 2B F16 64 pp4096 7436.70 7601.76 1.02
RTX 4090 gemma 2B F16 128 pp4096 12903.88 13375.50 1.04
RTX 4090 gemma 2B F16 256 pp4096 20643.73 20825.66 1.01
RTX 4090 gemma 2B F16 512 pp4096 26408.48 28339.54 1.07
RTX 4090 gemma 2B F16 1024 pp4096 27284.12 29472.00 1.08
RTX 4090 gemma 2B F16 2048 pp4096 27874.57 30055.65 1.08
RTX 4090 gemma 2B F16 4096 pp4096 28060.68 30317.35 1.08
RTX 4090 llama 7B Q4_0 1 pp4096 135.48 146.38 1.08
RTX 4090 llama 7B Q4_0 2 pp4096 270.53 279.32 1.03
RTX 4090 llama 7B Q4_0 4 pp4096 532.41 551.97 1.04
RTX 4090 llama 7B Q4_0 8 pp4096 902.26 931.75 1.03
RTX 4090 llama 7B Q4_0 16 pp4096 1268.07 1310.32 1.03
RTX 4090 llama 7B Q4_0 32 pp4096 1550.51 1606.50 1.04
RTX 4090 llama 7B Q4_0 64 pp4096 2078.30 2168.12 1.04
RTX 4090 llama 7B Q4_0 128 pp4096 3380.61 3710.46 1.10
RTX 4090 llama 7B Q4_0 256 pp4096 4705.68 6057.19 1.29
RTX 4090 llama 7B Q4_0 512 pp4096 5487.18 7821.58 1.43
RTX 4090 llama 7B Q4_0 1024 pp4096 5502.76 7836.71 1.42
RTX 4090 llama 7B Q4_0 2048 pp4096 5498.61 7846.92 1.43
RTX 4090 llama 7B Q4_0 4096 pp4096 5507.25 7847.28 1.42
RTX 4090 phi2 3B F16 1 pp4096 119.42 123.77 1.04
RTX 4090 phi2 3B F16 2 pp4096 211.06 218.77 1.04
RTX 4090 phi2 3B F16 4 pp4096 418.31 434.01 1.04
RTX 4090 phi2 3B F16 8 pp4096 816.89 852.98 1.04
RTX 4090 phi2 3B F16 16 pp4096 1563.67 1673.18 1.07
RTX 4090 phi2 3B F16 32 pp4096 2922.31 3206.54 1.10
RTX 4090 phi2 3B F16 64 pp4096 5083.17 5863.08 1.15
RTX 4090 phi2 3B F16 128 pp4096 7471.62 9499.56 1.27
RTX 4090 phi2 3B F16 256 pp4096 9204.74 15237.75 1.66
RTX 4090 phi2 3B F16 512 pp4096 9661.48 17577.25 1.82
RTX 4090 phi2 3B F16 1024 pp4096 9702.15 17510.25 1.80
RTX 4090 phi2 3B F16 2048 pp4096 9720.99 17571.12 1.81
RTX 4090 phi2 3B F16 4096 pp4096 9733.62 17618.98 1.81
RTX 3090 gemma 2B F16 1 pp4096 125.05 135.96 1.09
RTX 3090 gemma 2B F16 2 pp4096 224.34 229.19 1.02
RTX 3090 gemma 2B F16 4 pp4096 447.26 462.41 1.03
RTX 3090 gemma 2B F16 8 pp4096 862.20 912.27 1.06
RTX 3090 gemma 2B F16 16 pp4096 1592.68 1768.38 1.11
RTX 3090 gemma 2B F16 32 pp4096 3192.15 3246.80 1.02
RTX 3090 gemma 2B F16 64 pp4096 5759.86 5941.40 1.03
RTX 3090 gemma 2B F16 128 pp4096 9357.61 10118.23 1.08
RTX 3090 gemma 2B F16 256 pp4096 10963.78 11962.81 1.09
RTX 3090 gemma 2B F16 512 pp4096 12014.60 13149.46 1.09
RTX 3090 gemma 2B F16 1024 pp4096 12209.17 13429.50 1.10
RTX 3090 gemma 2B F16 2048 pp4096 12230.41 13651.72 1.12
RTX 3090 gemma 2B F16 4096 pp4096 12390.53 13569.27 1.10
RTX 3090 llama 7B Q4_0 1 pp4096 114.09 126.03 1.10
RTX 3090 llama 7B Q4_0 2 pp4096 214.85 240.07 1.12
RTX 3090 llama 7B Q4_0 4 pp4096 351.06 406.05 1.16
RTX 3090 llama 7B Q4_0 8 pp4096 466.90 542.02 1.16
RTX 3090 llama 7B Q4_0 16 pp4096 476.37 571.30 1.20
RTX 3090 llama 7B Q4_0 32 pp4096 615.67 634.80 1.03
RTX 3090 llama 7B Q4_0 64 pp4096 1214.87 1348.39 1.11
RTX 3090 llama 7B Q4_0 128 pp4096 1949.06 2244.31 1.15
RTX 3090 llama 7B Q4_0 256 pp4096 2635.02 3168.56 1.20
RTX 3090 llama 7B Q4_0 512 pp4096 3008.32 3737.65 1.24
RTX 3090 llama 7B Q4_0 1024 pp4096 3009.86 3745.87 1.24
RTX 3090 llama 7B Q4_0 2048 pp4096 3015.45 3746.94 1.24
RTX 3090 llama 7B Q4_0 4096 pp4096 3011.98 3749.07 1.24
RTX 3090 phi2 3B F16 1 pp4096 101.03 106.64 1.06
RTX 3090 phi2 3B F16 2 pp4096 176.91 185.58 1.05
RTX 3090 phi2 3B F16 4 pp4096 343.75 369.12 1.07
RTX 3090 phi2 3B F16 8 pp4096 659.73 726.71 1.10
RTX 3090 phi2 3B F16 16 pp4096 1218.31 1339.44 1.10
RTX 3090 phi2 3B F16 32 pp4096 2154.24 2535.57 1.18
RTX 3090 phi2 3B F16 64 pp4096 3501.02 4539.63 1.30
RTX 3090 phi2 3B F16 128 pp4096 4817.51 6897.72 1.43
RTX 3090 phi2 3B F16 256 pp4096 5339.66 7776.34 1.46
RTX 3090 phi2 3B F16 512 pp4096 5731.00 8762.43 1.53
RTX 3090 phi2 3B F16 1024 pp4096 5747.72 8605.73 1.50
RTX 3090 phi2 3B F16 2048 pp4096 5766.92 8666.90 1.50
RTX 3090 phi2 3B F16 4096 pp4096 5782.52 8580.58 1.48

On my systems FlashAttention now seems to be universally faster than master.

@JohannesGaessler JohannesGaessler changed the title from "Jg/flash attn 20" to "CUDA: faster FlashAttention for batch sizes > 1" on Apr 12, 2024
@kalomaze
Contributor

kalomaze commented Apr 13, 2024

Maybe there's something I'm missing here, but doesn't the mention of ~100t/s prompt processing for batch size 1 seem inaccurate? I think you should be able to hit 1000t/s on q4_0 7b for bs1 on master easy with a 3090, unless the actual token batch count is what is described here and it's not about concurrent requests?

I think I'm just mixing up my definitions of "batching" here and it's not about concurrent requests.

@kalomaze
Contributor

kalomaze commented Apr 13, 2024

On my machine, hosting a llama.cpp server, I'm able to get a couple of coherent tokens, sometimes up to a short paragraph in length, before garbage output consistently comes out for Cohere's Command-R 35b at q5_K_S:

but it's still frustrating. being nothing more than an error  as if i was an afterthoughtserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceNameserviceName

It always degenerates into "Nameservice" repeatedly.

@JohannesGaessler
Collaborator Author

Maybe there's something I'm missing here, but doesn't the mention of ~100t/s prompt processing for batch size 1 seem inaccurate? I think you should be able to hit 1000t/s on q4_0 7b for bs1 on master easy with a 3090, unless the actual token batch count is what is described here and it's not about concurrent requests?

I think I'm just mixing up my definitions of "batching" here and it's not about concurrent requests.

The numbers are from running something like

export model_name=llama_2-7b && export quantization=q4_0
./llama-bench --model models/opt/${model_name}-${quantization}.gguf -n 0 -p 4096 -r 1 -b 1,2,4,8,16,32,64,128,256,512,1024,2048,4096

So they are representative of a single concurrent context with a varying number of tokens processed in parallel. You would essentially get these speeds for prompt processing if you were to set this batch size; batch size 1 is equivalent to generation.

On my machine, hosting a llama.cpp server, I'm able to get a couple of coherent tokens, sometimes up to a short paragraph in length, before garbage output consistently comes out for Cohere's Command-R 35b at q5_K_S:

I can reproduce the issue. It's a bug in the batch size 1 kernel that is already on gg/flash-attn.

@JohannesGaessler
Collaborator Author

The problem seemed to be that there was a race condition that - depending on the specifics - could sometimes result in garbage outputs. @kalomaze can you confirm that the fix works?
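
(Illustrative aside: in tiled kernels a race like this typically means shared memory is reused before all threads are done with it, which only misbehaves for certain launch configurations. A minimal generic sketch of the pattern, assuming a 128-thread block; this is not the actual FlashAttention kernel:)

```cpp
// Minimal sketch of the shared-memory reuse pattern; assumes blockDim.x == 128.
__global__ void tile_sum(const float * in, float * out, const int n_iters) {
    __shared__ float tile[128];
    float acc = 0.0f;
    for (int it = 0; it < n_iters; ++it) {
        tile[threadIdx.x] = in[it*blockDim.x + threadIdx.x];
        __syncthreads(); // make this iteration's writes visible to all threads
        for (int j = 0; j < 128; ++j) {
            acc += tile[j];
        }
        __syncthreads(); // without this barrier, a fast thread starts the next
                         // iteration and overwrites tile[] while slower threads
                         // are still reading it -> garbage results
    }
    out[blockIdx.x*blockDim.x + threadIdx.x] = acc;
}
```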

Contributor

github-actions bot commented Apr 14, 2024

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 422 iterations 🚀

Expand details for performance related PR only
  • Concurrent users: 8, duration: 10m
  • HTTP request : avg=11225.04ms p(95)=28440.42ms fails=, finish reason: stop=371 truncated=51
  • Prompt processing (pp): avg=122.65tk/s p(95)=544.61tk/s
  • Token generation (tg): avg=27.63tk/s p(95)=34.85tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=jg/flash-attn-20 commit=44ca5764d621d8693f8f8c01b9b920d7620c9076

Charts: llama.cpp bench-server-baseline on Standard_NC4as_T4_v3, duration=10m, 422 iterations (time-series plots omitted):
  • prompt_tokens_seconds
  • predicted_tokens_seconds
  • kv_cache_usage_ratio
  • requests_processing

@kalomaze
Contributor

kalomaze commented Apr 14, 2024

The problem seemed to be that there was a race condition that - depending on the specifics - could sometimes result in garbage outputs. @kalomaze can you confirm that the fix works?

Works fine on my end now! Responses are perfectly normal on regen / fresh context ingest / etc.

I ran a full offload of a q6_K 20b model (not a new or noteworthy one, it was just convenient) to test out different nvidia-smi power limits.

GPU Model Batch size Test t/s master t/s jg/flash-attn-20 Speedup
RTX 3090 (250w) internlm 20B Q6_K 512 pp4096 873 1112 1.27
RTX 3090 (300w) internlm 20B Q6_K 512 pp4096 1159 1403 1.21

Seems great @JohannesGaessler. Is there anything else left for FlashAttention to be merged into mainline?

@phymbert
Collaborator

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 272 iterations 🚀

@JohannesGaessler Thanks for the work here. I am not involved in the flash attention branches, but I would like to understand why the bench workflow is slower: 272 iterations here versus 426 on master. It runs on phi-2 Q4_0. Thanks

@JohannesGaessler
Collaborator Author

Seems great @JohannesGaessler. Is there anything else left for FlashAttention to be merged into mainline?

You'll have to ask @ggerganov about that.

@JohannesGaessler Thanks for the work here. I am not involved in the flash attention branches, but I would like to understand why the bench workflow is slower: 272 iterations here versus #6635 (comment) on master. It runs on phi-2 Q4_0. Thanks

For this PR the bot reports 241.82 t/s pp and 42.58 t/s tg. In the post you linked it's 118.38 t/s pp and 24.06 t/s tg. Aren't those the relevant numbers?

@phymbert
Collaborator

phymbert commented Apr 14, 2024

For this PR the bot reports 241.82 t/s pp and 42.58 t/s tg. In the post you linked it's 118.38 t/s pp and 24.06 t/s tg. Aren't those the relevant numbers?

Yes, you are right, thanks. It's a lot faster per sequence. Then I will investigate why the server is "globally" slower with 8 concurrent sequences/slots/users. @ggerganov I think we need to fix that before merging.

EDIT: No, these are not the relevant numbers.

@phymbert
Collaborator

phymbert commented Apr 14, 2024

For this PR the bot reports 241.82 t/s pp and 42.58 t/s tg. In the post you linked it's 118.38 t/s pp and 24.06 t/s tg. Aren't those the relevant numbers?

No, the bench has been fixed since then; it now uses the metrics from the server side, not the client side.
If you look at the PP graph (which is built by the server + Prometheus), prompt processing is twice as slow:

Branch: master | gg/flash-attn | JohannesGaessler:jg/flash-attn-20
Iterations: 426 | 264 | 272
PP: prompt_tokens_seconds chart per branch (images omitted)
TG: predicted_tokens_seconds chart per branch (images omitted)

@JohannesGaessler
Collaborator Author

I don't have a comprehensive overview of the changes in gg/flash-attn, and I don't know the details of how the server benchmark presented here works either. If the issue is due to the CUDA kernels and not some other changes, it could just be that I tested and optimized the code on Ampere and Ada Lovelace and that for whatever reason the performance is bad on Turing. It could also be an issue specific to Phi-2, since that particular model has a janky head size of 80. If you have access to T4s, do a quick benchmark on them to check whether the issue is specific to T4s, to the server, or to Phi-2. Or tell me how I can run this benchmark locally on my own hardware.

@phymbert
Collaborator

Thanks @JohannesGaessler for your explanation.

I don't have a comprehensive overview of the changes in gg/flash-attn

The gg/flash-attn branch is 63 commits behind master; I do not know if that is related.

Or tell me how I can run this benchmark locally on my own hardware.

It would be cool if you could confirm. You can have a look at the bench README.md:

cd examples/server/bench
mkdir models
LLAMA_SERVER_BIN_PATH=../../../build/bin/server python bench.py \
    --runner-label local \
    --name local \
    --branch `git rev-parse --abbrev-ref HEAD` \
    --commit `git rev-parse HEAD` \
    --scenario script.js \
    --duration 10m \
    --hf-repo ggml-org/models \
    --hf-file phi-2/ggml-model-q4_0.gguf \
    --model-path-prefix models \
    --parallel 8 \
    -ngl 33 \
    --batch-size 2048 \
    --ubatch-size 256 \
    --ctx-size 16384 \
    --n-prompts 1000 \
    --max-prompt-tokens 1024 \
    --max-tokens 2048

As the current branch is behind master, it's better to look at the '/metrics' endpoint directly after the test to have accurate metrics:

curl http://localhost:8080/metrics

@JohannesGaessler
Collaborator Author

The benchmark script crashes after the run, so I cannot retrieve any data from the endpoint, but I was able to reproduce master being faster than this PR on an RTX 4090 in this specific benchmark.

Console out master
ens=0 truncated=false
INFO [           print_timings] prompt eval time     =      59.93 ms /   318 tokens (    0.19 ms per token,  5306.19 tokens per second) | tid="140073612861440" timestamp=1713129099 id_slot=2 id_task=25800 t_prompt_processing=59.93 n_prompt_tokens_processed=318 t_token=0.18845911949685534 n_tokens_second=5306.190555648256
INFO [           print_timings] generation eval time =    3680.41 ms /   280 runs   (   13.14 ms per token,    76.08 tokens per second) | tid="140073612861440" timestamp=1713129099 id_slot=2 id_task=25800 t_token_generation=3680.405 n_decoded=280 t_token=13.144303571428573 n_tokens_second=76.07858374282178
INFO [           print_timings]           total time =    3740.34 ms | tid="140073612861440" timestamp=1713129099 id_slot=2 id_task=25800 t_prompt_processing=59.93 t_token_generation=3680.405 t_total=3740.335
INFO [      log_server_request] request | tid="140064000561152" timestamp=1713129099 remote_addr="127.0.0.1" remote_port=38208 status=200 method="POST" path="/v1/chat/completions" params={}
INFO [            update_slots] slot released | tid="140073612861440" timestamp=1713129099 id_slot=2 id_task=25800 n_ctx=16384 n_past=597 n_system_tokens=0 n_cache_tokens=0 truncated=false
INFO [           print_timings] prompt eval time     =      25.20 ms /   112 tokens (    0.22 ms per token,  4444.62 tokens per second) | tid="140073612861440" timestamp=1713129099 id_slot=1 id_task=26025 t_prompt_processing=25.199 n_prompt_tokens_processed=112 t_token=0.22499107142857144 n_tokens_second=4444.620818286439
INFO [           print_timings] generation eval time =    1237.35 ms /    83 runs   (   14.91 ms per token,    67.08 tokens per second) | tid="140073612861440" timestamp=1713129099 id_slot=1 id_task=26025 t_token_generation=1237.354 n_decoded=83 t_token=14.90787951807229 n_tokens_second=67.07862099286056
INFO [           print_timings]           total time =    1262.55 ms | tid="140073612861440" timestamp=1713129099 id_slot=1 id_task=26025 t_prompt_processing=25.199 t_token_generation=1237.354 t_total=1262.553
INFO [      log_server_request] request | tid="140063782350848" timestamp=1713129099 remote_addr="127.0.0.1" remote_port=38290 status=200 method="POST" path="/v1/chat/completions" params={}
INFO [            update_slots] slot released | tid="140073612861440" timestamp=1713129099 id_slot=1 id_task=26025 n_ctx=16384 n_past=194 n_system_tokens=0 n_cache_tokens=0 truncated=false
INFO [           print_timings] prompt eval time     =      32.47 ms /   142 tokens (    0.23 ms per token,  4373.54 tokens per second) | tid="140073612861440" timestamp=1713129100 id_slot=7 id_task=25731 t_prompt_processing=32.468 n_prompt_tokens_processed=142 t_token=0.2286478873239437 n_tokens_second=4373.537021066896
INFO [           print_timings] generation eval time =    5204.21 ms /   433 runs   (   12.02 ms per token,    83.20 tokens per second) | tid="140073612861440" timestamp=1713129100 id_slot=7 id_task=25731 t_token_generation=5204.207 n_decoded=433 t_token=12.018953810623557 n_tokens_second=83.20191721812756
INFO [           print_timings]           total time =    5236.68 ms | tid="140073612861440" timestamp=1713129100 id_slot=7 id_task=25731 t_prompt_processing=32.468 t_token_generation=5204.207 t_total=5236.675
INFO [            update_slots] slot released | tid="140073612861440" timestamp=1713129100 id_slot=7 id_task=25731 n_ctx=16384 n_past=574 n_system_tokens=0 n_cache_tokens=0 truncated=false
INFO [      log_server_request] request | tid="140071852384256" timestamp=1713129100 remote_addr="127.0.0.1" remote_port=38192 status=200 method="POST" path="/v1/chat/completions" params={}
INFO [           print_timings] prompt eval time     =      36.98 ms /   100 tokens (    0.37 ms per token,  2704.16 tokens per second) | tid="140073612861440" timestamp=1713129101 id_slot=6 id_task=25810 t_prompt_processing=36.98 n_prompt_tokens_processed=100 t_token=0.36979999999999996 n_tokens_second=2704.1644131963226
INFO [           print_timings] generation eval time =    5292.01 ms /   512 runs   (   10.34 ms per token,    96.75 tokens per second) | tid="140073612861440" timestamp=1713129101 id_slot=6 id_task=25810 t_token_generation=5292.011 n_decoded=512 t_token=10.335958984375 n_tokens_second=96.74960993089394
INFO [           print_timings]           total time =    5328.99 ms | tid="140073612861440" timestamp=1713129101 id_slot=6 id_task=25810 t_prompt_processing=36.98 t_token_generation=5292.011 t_total=5328.991
INFO [            update_slots] slot released | tid="140073612861440" timestamp=1713129101 id_slot=6 id_task=25810 n_ctx=16384 n_past=611 n_system_tokens=0 n_cache_tokens=0 truncated=false
INFO [      log_server_request] request | tid="140071189323776" timestamp=1713129101 remote_addr="127.0.0.1" remote_port=38220 status=200 method="POST" path="/v1/chat/completions" params={}
INFO [           print_timings] prompt eval time     =      26.09 ms /   119 tokens (    0.22 ms per token,  4561.13 tokens per second) | tid="140073612861440" timestamp=1713129102 id_slot=5 id_task=25927 t_prompt_processing=26.09 n_prompt_tokens_processed=119 t_token=0.21924369747899158 n_tokens_second=4561.134534304331
INFO [           print_timings] generation eval time =    4619.03 ms /   512 runs   (    9.02 ms per token,   110.85 tokens per second) | tid="140073612861440" timestamp=1713129102 id_slot=5 id_task=25927 t_token_generation=4619.031 n_decoded=512 t_token=9.021544921875 n_tokens_second=110.84575964092902
INFO [           print_timings]           total time =    4645.12 ms | tid="140073612861440" timestamp=1713129102 id_slot=5 id_task=25927 t_prompt_processing=26.09 t_token_generation=4619.031 t_total=4645.121
INFO [            update_slots] slot released | tid="140073612861440" timestamp=1713129102 id_slot=5 id_task=25927 n_ctx=16384 n_past=630 n_system_tokens=0 n_cache_tokens=0 truncated=false
INFO [            update_slots] all slots are idle | tid="140073612861440" timestamp=1713129102
INFO [      log_server_request] request | tid="140063941812224" timestamp=1713129102 remote_addr="127.0.0.1" remote_port=38270 status=200 method="POST" path="/v1/chat/completions" params={}
INFO [            update_slots] all slots are idle | tid="140073612861440" timestamp=1713129102

     ✓ success completion

     █ setup

     checks.....................................: 100.00% ✓ 1000       ✗ 0  
     data_received..............................: 0 B     0 B/s
     data_sent..................................: 0 B     0 B/s
     http_req_duration..........................: avg=2.28s       min=157.23ms   med=1.23s    max=11.8s       p(90)=6.31s       p(95)=6.83s      
     http_req_sending...........................: avg=970µs       min=619.47µs   med=940.36µs max=1.72ms      p(90)=1.15ms      p(95)=1.21ms     
     http_reqs..................................: 1000    3.070198/s
     iteration_duration.........................: avg=2.58s       min=211.25µs   med=1.53s    max=12.1s       p(90)=6.61s       p(95)=7.13s      
     iterations.................................: 1000    3.070198/s
     llamacpp_completion_tokens.................: avg=166.233     min=1          med=89       max=512         p(90)=512         p(95)=512        
     llamacpp_completion_tokens_total_counter...: 166233  510.368305/s
     llamacpp_completions_stop_rate.............: 86.70%  ✓ 867        ✗ 133
   ✓ llamacpp_completions_truncated_rate........: 13.30%  ✓ 133        ✗ 867
     llamacpp_prompt_processing_second..........: avg=2961.728831 min=106.557377 med=2750     max=7252.252252 p(90)=4884.833091 p(95)=5460.975158
     llamacpp_prompt_tokens.....................: avg=233.008     min=57         med=85       max=1881        p(90)=675.9       p(95)=1124.05    
     llamacpp_prompt_tokens_total_counter.......: 233008  715.380809/s
     llamacpp_tokens_second.....................: avg=75.198288   min=13.468013  med=77.76343 max=129.032258  p(90)=92.008758   p(95)=97.341552  
     sse_event..................................: 160597  493.064666/s
     vus........................................: 1       min=1        max=8
     vus_max....................................: 8       min=8        max=8


running (05m25.7s), 0/8 VUs, 1000 complete and 0 interrupted iterations
default ✓ [======================================] 8 VUs  05m25.7s/10m0s  1000/1000 shared iters
bench: shutting down server pid=41308 ...
INFO [            update_slots] all slots are idle | tid="140073612861440" timestamp=1713129103
Traceback (most recent call last):
  File "/home/johannesg/Projects/llama.cpp/examples/server/bench/bench.py", line 308, in <module>
    main()
  File "/home/johannesg/Projects/llama.cpp/examples/server/bench/bench.py", line 189, in main
    "0": round(mean(prometheus_metrics['prompt_tokens_seconds']), 2),
                    ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^
KeyError: 'prompt_tokens_seconds'
Console out PR
INFO [           print_timings] prompt eval time     =      75.78 ms /    88 tokens (    0.86 ms per token,  1161.27 tokens per second) | tid="140137934626816" timestamp=1713130052 id_slot=6 id_task=52158 t_prompt_processing=75.779 n_prompt_tokens_processed=88 t_token=0.8611249999999999 n_tokens_second=1161.2715923936712
INFO [           print_timings] generation eval time =    5106.49 ms /   512 runs   (    9.97 ms per token,   100.26 tokens per second) | tid="140137934626816" timestamp=1713130052 id_slot=6 id_task=52158 t_token_generation=5106.49 n_decoded=512 t_token=9.97361328125 n_tokens_second=100.26456528848583
INFO [           print_timings]           total time =    5182.27 ms | tid="140137934626816" timestamp=1713130052 id_slot=6 id_task=52158 t_prompt_processing=75.779 t_token_generation=5106.49 t_total=5182.269
INFO [      log_server_request] request | tid="140136394125312" timestamp=1713130052 remote_addr="127.0.0.1" remote_port=55456 status=200 method="POST" path="/v1/chat/completions" params={}
INFO [           print_timings] prompt eval time     =      75.95 ms /    59 tokens (    1.29 ms per token,   776.78 tokens per second) | tid="140137934626816" timestamp=1713130052 id_slot=7 id_task=52159 t_prompt_processing=75.955 n_prompt_tokens_processed=59 t_token=1.2873728813559322 n_tokens_second=776.7757224672504
INFO [           print_timings] generation eval time =    5106.52 ms /   512 runs   (    9.97 ms per token,   100.26 tokens per second) | tid="140137934626816" timestamp=1713130052 id_slot=7 id_task=52159 t_token_generation=5106.522 n_decoded=512 t_token=9.97367578125 n_tokens_second=100.26393698098236
INFO [           print_timings]           total time =    5182.48 ms | tid="140137934626816" timestamp=1713130052 id_slot=7 id_task=52159 t_prompt_processing=75.955 t_token_generation=5106.522 t_total=5182.477
INFO [            update_slots] slot released | tid="140137934626816" timestamp=1713130052 id_slot=1 id_task=52153 n_ctx=16384 n_past=676 n_system_tokens=0 n_cache_tokens=0 truncated=false
INFO [            update_slots] slot released | tid="140137934626816" timestamp=1713130052 id_slot=2 id_task=52154 n_ctx=16384 n_past=722 n_system_tokens=0 n_cache_tokens=0 truncated=false
INFO [            update_slots] slot released | tid="140137934626816" timestamp=1713130052 id_slot=3 id_task=52155 n_ctx=16384 n_past=571 n_system_tokens=0 n_cache_tokens=0 truncated=false
INFO [            update_slots] slot released | tid="140137934626816" timestamp=1713130052 id_slot=4 id_task=52156 n_ctx=16384 n_past=649 n_system_tokens=0 n_cache_tokens=0 truncated=false
INFO [            update_slots] slot released | tid="140137934626816" timestamp=1713130052 id_slot=5 id_task=52157 n_ctx=16384 n_past=627 n_system_tokens=0 n_cache_tokens=0 truncated=false
INFO [            update_slots] slot released | tid="140137934626816" timestamp=1713130052 id_slot=6 id_task=52158 n_ctx=16384 n_past=599 n_system_tokens=0 n_cache_tokens=0 truncated=false
INFO [            update_slots] slot released | tid="140137934626816" timestamp=1713130052 id_slot=7 id_task=52159 n_ctx=16384 n_past=570 n_system_tokens=0 n_cache_tokens=0 truncated=false
INFO [            update_slots] all slots are idle | tid="140137934626816" timestamp=1713130052
INFO [      log_server_request] request | tid="140136377339904" timestamp=1713130052 remote_addr="127.0.0.1" remote_port=55462 status=200 method="POST" path="/v1/chat/completions" params={}
INFO [   launch_slot_with_task] slot is processing task | tid="140137934626816" timestamp=1713130052 id_slot=0 id_task=52672
INFO [            update_slots] kv cache rm [p0, end) | tid="140137934626816" timestamp=1713130052 id_slot=0 id_task=52672 p0=0
INFO [   launch_slot_with_task] slot is processing task | tid="140137934626816" timestamp=1713130052 id_slot=1 id_task=52674
INFO [   launch_slot_with_task] slot is processing task | tid="140137934626816" timestamp=1713130052 id_slot=2 id_task=52675
INFO [   launch_slot_with_task] slot is processing task | tid="140137934626816" timestamp=1713130052 id_slot=3 id_task=52676
INFO [   launch_slot_with_task] slot is processing task | tid="140137934626816" timestamp=1713130052 id_slot=4 id_task=52677
INFO [   launch_slot_with_task] slot is processing task | tid="140137934626816" timestamp=1713130052 id_slot=5 id_task=52678
INFO [   launch_slot_with_task] slot is processing task | tid="140137934626816" timestamp=1713130052 id_slot=6 id_task=52679
INFO [   launch_slot_with_task] slot is processing task | tid="140137934626816" timestamp=1713130052 id_slot=7 id_task=52680
INFO [            update_slots] kv cache rm [p0, end) | tid="140137934626816" timestamp=1713130052 id_slot=1 id_task=52674 p0=0
INFO [            update_slots] kv cache rm [p0, end) | tid="140137934626816" timestamp=1713130052 id_slot=2 id_task=52675 p0=0
INFO [            update_slots] kv cache rm [p0, end) | tid="140137934626816" timestamp=1713130052 id_slot=3 id_task=52676 p0=0
INFO [            update_slots] kv cache rm [p0, end) | tid="140137934626816" timestamp=1713130052 id_slot=4 id_task=52677 p0=0
INFO [            update_slots] kv cache rm [p0, end) | tid="140137934626816" timestamp=1713130052 id_slot=5 id_task=52678 p0=0
INFO [            update_slots] kv cache rm [p0, end) | tid="140137934626816" timestamp=1713130052 id_slot=6 id_task=52679 p0=0
INFO [            update_slots] kv cache rm [p0, end) | tid="140137934626816" timestamp=1713130052 id_slot=7 id_task=52680 p0=0
INFO [           print_timings] prompt eval time     =      18.14 ms /    66 tokens (    0.27 ms per token,  3637.97 tokens per second) | tid="140137934626816" timestamp=1713130057 id_slot=0 id_task=52672 t_prompt_processing=18.142 n_prompt_tokens_processed=66 t_token=0.2748787878787879 n_tokens_second=3637.967148054239
INFO [           print_timings] generation eval time =    5501.98 ms /   512 runs   (   10.75 ms per token,    93.06 tokens per second) | tid="140137934626816" timestamp=1713130057 id_slot=0 id_task=52672 t_token_generation=5501.976 n_decoded=512 t_token=10.746046875 n_tokens_second=93.05747607768555
INFO [           print_timings]           total time =    5520.12 ms | tid="140137934626816" timestamp=1713130057 id_slot=0 id_task=52672 t_prompt_processing=18.142 t_token_generation=5501.976 t_total=5520.1179999999995
INFO [      log_server_request] request | tid="140136436088832" timestamp=1713130057 remote_addr="127.0.0.1" remote_port=55388 status=200 method="POST" path="/v1/chat/completions" params={}
INFO [            update_slots] slot released | tid="140137934626816" timestamp=1713130057 id_slot=0 id_task=52672 n_ctx=16384 n_past=577 n_system_tokens=0 n_cache_tokens=0 truncated=false
INFO [           print_timings] prompt eval time     =     163.13 ms /   338 tokens (    0.48 ms per token,  2071.95 tokens per second) | tid="140137934626816" timestamp=1713130057 id_slot=1 id_task=52674 t_prompt_processing=163.131 n_prompt_tokens_processed=338 t_token=0.4826360946745562 n_tokens_second=2071.9544415224573
INFO [           print_timings] generation eval time =    5350.26 ms /   512 runs   (   10.45 ms per token,    95.70 tokens per second) | tid="140137934626816" timestamp=1713130057 id_slot=1 id_task=52674 t_token_generation=5350.262 n_decoded=512 t_token=10.44973046875 n_tokens_second=95.69624814635246
INFO [           print_timings]           total time =    5513.39 ms | tid="140137934626816" timestamp=1713130057 id_slot=1 id_task=52674 t_prompt_processing=163.131 t_token_generation=5350.262 t_total=5513.393
INFO [      log_server_request] request | tid="140136427696128" timestamp=1713130057 remote_addr="127.0.0.1" remote_port=55404 status=200 method="POST" path="/v1/chat/completions" params={}
INFO [           print_timings] prompt eval time     =     162.01 ms /   170 tokens (    0.95 ms per token,  1049.34 tokens per second) | tid="140137934626816" timestamp=1713130057 id_slot=2 id_task=52675 t_prompt_processing=162.006 n_prompt_tokens_processed=170 t_token=0.9529764705882353 n_tokens_second=1049.3438514622917
INFO [           print_timings] generation eval time =    5350.35 ms /   512 runs   (   10.45 ms per token,    95.69 tokens per second) | tid="140137934626816" timestamp=1713130057 id_slot=2 id_task=52675 t_token_generation=5350.354 n_decoded=512 t_token=10.44991015625 n_tokens_second=95.69460263750771
INFO [           print_timings]           total time =    5512.36 ms | tid="140137934626816" timestamp=1713130057 id_slot=2 id_task=52675 t_prompt_processing=162.006 t_token_generation=5350.354 t_total=5512.360000000001
INFO [      log_server_request] request | tid="140136419303424" timestamp=1713130057 remote_addr="127.0.0.1" remote_port=55412 status=200 method="POST" path="/v1/chat/completions" params={}
INFO [           print_timings] prompt eval time     =     161.74 ms /   271 tokens (    0.60 ms per token,  1675.53 tokens per second) | tid="140137934626816" timestamp=1713130057 id_slot=3 id_task=52676 t_prompt_processing=161.74 n_prompt_tokens_processed=271 t_token=0.5968265682656827 n_tokens_second=1675.5286261901817
INFO [           print_timings] generation eval time =    5350.41 ms /   512 runs   (   10.45 ms per token,    95.69 tokens per second) | tid="140137934626816" timestamp=1713130057 id_slot=3 id_task=52676 t_token_generation=5350.406 n_decoded=512 t_token=10.45001171875 n_tokens_second=95.6936725923229
INFO [           print_timings]           total time =    5512.15 ms | tid="140137934626816" timestamp=1713130057 id_slot=3 id_task=52676 t_prompt_processing=161.74 t_token_generation=5350.406 t_total=5512.146
INFO [      log_server_request] request | tid="140136410910720" timestamp=1713130057 remote_addr="127.0.0.1" remote_port=55422 status=200 method="POST" path="/v1/chat/completions" params={}
INFO [           print_timings] prompt eval time     =     161.11 ms /   905 tokens (    0.18 ms per token,  5617.18 tokens per second) | tid="140137934626816" timestamp=1713130057 id_slot=4 id_task=52677 t_prompt_processing=161.113 n_prompt_tokens_processed=905 t_token=0.1780254143646409 n_tokens_second=5617.175522769733
INFO [           print_timings] generation eval time =    5350.45 ms /   512 runs   (   10.45 ms per token,    95.69 tokens per second) | tid="140137934626816" timestamp=1713130057 id_slot=4 id_task=52677 t_token_generation=5350.451 n_decoded=512 t_token=10.450099609375 n_tokens_second=95.69286776011965
INFO [           print_timings]           total time =    5511.56 ms | tid="140137934626816" timestamp=1713130057 id_slot=4 id_task=52677 t_prompt_processing=161.113 t_token_generation=5350.451 t_total=5511.564
INFO [      log_server_request] request | tid="140136402518016" timestamp=1713130057 remote_addr="127.0.0.1" remote_port=55438 status=200 method="POST" path="/v1/chat/completions" params={}
INFO [           print_timings] prompt eval time     =     157.78 ms /    88 tokens (    1.79 ms per token,   557.72 tokens per second) | tid="140137934626816" timestamp=1713130057 id_slot=5 id_task=52678 t_prompt_processing=157.784 n_prompt_tokens_processed=88 t_token=1.793 n_tokens_second=557.7244841048522
INFO [           print_timings] generation eval time =    5350.49 ms /   512 runs   (   10.45 ms per token,    95.69 tokens per second) | tid="140137934626816" timestamp=1713130057 id_slot=5 id_task=52678 t_token_generation=5350.488 n_decoded=512 t_token=10.450171875 n_tokens_second=95.6922060193388
INFO [           print_timings]           total time =    5508.27 ms | tid="140137934626816" timestamp=1713130057 id_slot=5 id_task=52678 t_prompt_processing=157.784 t_token_generation=5350.488 t_total=5508.272
INFO [      log_server_request] request | tid="140136385732608" timestamp=1713130057 remote_addr="127.0.0.1" remote_port=55440 status=200 method="POST" path="/v1/chat/completions" params={}
INFO [           print_timings] prompt eval time     =     157.93 ms /    62 tokens (    2.55 ms per token,   392.59 tokens per second) | tid="140137934626816" timestamp=1713130057 id_slot=6 id_task=52679 t_prompt_processing=157.926 n_prompt_tokens_processed=62 t_token=2.5471935483870967 n_tokens_second=392.5889340577233
INFO [           print_timings] generation eval time =    5350.52 ms /   512 runs   (   10.45 ms per token,    95.69 tokens per second) | tid="140137934626816" timestamp=1713130057 id_slot=6 id_task=52679 t_token_generation=5350.52 n_decoded=512 t_token=10.450234375 n_tokens_second=95.69163371036834
INFO [           print_timings]           total time =    5508.45 ms | tid="140137934626816" timestamp=1713130057 id_slot=6 id_task=52679 t_prompt_processing=157.926 t_token_generation=5350.52 t_total=5508.446000000001
INFO [      log_server_request] request | tid="140136394125312" timestamp=1713130057 remote_addr="127.0.0.1" remote_port=55456 status=200 method="POST" path="/v1/chat/completions" params={}
INFO [           print_timings] prompt eval time     =     158.20 ms /    65 tokens (    2.43 ms per token,   410.88 tokens per second) | tid="140137934626816" timestamp=1713130057 id_slot=7 id_task=52680 t_prompt_processing=158.197 n_prompt_tokens_processed=65 t_token=2.4338 n_tokens_second=410.8801051853069
INFO [           print_timings] generation eval time =    5350.55 ms /   512 runs   (   10.45 ms per token,    95.69 tokens per second) | tid="140137934626816" timestamp=1713130057 id_slot=7 id_task=52680 t_token_generation=5350.551 n_decoded=512 t_token=10.450294921875 n_tokens_second=95.69107929258126
INFO [           print_timings]           total time =    5508.75 ms | tid="140137934626816" timestamp=1713130057 id_slot=7 id_task=52680 t_prompt_processing=158.197 t_token_generation=5350.551 t_total=5508.7480000000005
INFO [            update_slots] slot released | tid="140137934626816" timestamp=1713130057 id_slot=1 id_task=52674 n_ctx=16384 n_past=849 n_system_tokens=0 n_cache_tokens=0 truncated=false
INFO [            update_slots] slot released | tid="140137934626816" timestamp=1713130057 id_slot=2 id_task=52675 n_ctx=16384 n_past=681 n_system_tokens=0 n_cache_tokens=0 truncated=false
INFO [            update_slots] slot released | tid="140137934626816" timestamp=1713130057 id_slot=3 id_task=52676 n_ctx=16384 n_past=782 n_system_tokens=0 n_cache_tokens=0 truncated=false
INFO [            update_slots] slot released | tid="140137934626816" timestamp=1713130057 id_slot=4 id_task=52677 n_ctx=16384 n_past=1416 n_system_tokens=0 n_cache_tokens=0 truncated=false
INFO [            update_slots] slot released | tid="140137934626816" timestamp=1713130057 id_slot=5 id_task=52678 n_ctx=16384 n_past=599 n_system_tokens=0 n_cache_tokens=0 truncated=false
INFO [            update_slots] slot released | tid="140137934626816" timestamp=1713130057 id_slot=6 id_task=52679 n_ctx=16384 n_past=573 n_system_tokens=0 n_cache_tokens=0 truncated=false
INFO [            update_slots] slot released | tid="140137934626816" timestamp=1713130057 id_slot=7 id_task=52680 n_ctx=16384 n_past=576 n_system_tokens=0 n_cache_tokens=0 truncated=false
INFO [            update_slots] all slots are idle | tid="140137934626816" timestamp=1713130057
INFO [      log_server_request] request | tid="140136377339904" timestamp=1713130057 remote_addr="127.0.0.1" remote_port=55462 status=200 method="POST" path="/v1/chat/completions" params={}

     ✓ success completion

     █ setup

     checks.....................................: 100.00% ✓ 816        ✗ 0  
     data_received..............................: 781 kB  1.3 kB/s
     data_sent..................................: 882 kB  1.5 kB/s
     dropped_iterations.........................: 184     0.306237/s
     http_req_blocked...........................: avg=53.95µs    min=1.7µs     med=4.46µs    max=513.88µs   p(90)=241µs      p(95)=272.39µs  
     http_req_connecting........................: avg=32.48µs    min=0s        med=0s        max=234.66µs   p(90)=157.12µs   p(95)=179.1µs   
     http_req_duration..........................: avg=5.58s      min=5.08s     med=5.35s     max=8.28s      p(90)=6.32s      p(95)=7.1s      
       { expected_response:true }...............: avg=5.58s      min=5.08s     med=5.35s     max=8.28s      p(90)=6.32s      p(95)=7.1s      
     http_req_failed............................: 0.00%   ✓ 0          ✗ 816
     http_req_receiving.........................: avg=63.4µs     min=23.61µs   med=56.8µs    max=156.69µs   p(90)=90.29µs    p(95)=105.2µs   
     http_req_sending...........................: avg=42.64µs    min=11.22µs   med=28.51µs   max=251.56µs   p(90)=86.24µs    p(95)=95.94µs   
     http_req_tls_handshaking...................: avg=0s         min=0s        med=0s        max=0s         p(90)=0s         p(95)=0s        
     http_req_waiting...........................: avg=5.58s      min=5.08s     med=5.35s     max=8.28s      p(90)=6.32s      p(95)=7.1s      
     http_reqs..................................: 816     1.358094/s
     iteration_duration.........................: avg=5.88s      min=140.25µs  med=5.65s     max=8.58s      p(90)=6.62s      p(95)=7.4s      
     iterations.................................: 816     1.358094/s
     llamacpp_completion_tokens.................: avg=512        min=512       med=512       max=512        p(90)=512        p(95)=512       
     llamacpp_completion_tokens_total_counter...: 417792  695.344334/s
     llamacpp_completions_stop_rate.............: 100.00% ✓ 816        ✗ 0  
   ✓ llamacpp_completions_truncated_rate........: 0.00%   ✓ 0          ✗ 816
     llamacpp_prompt_tokens.....................: avg=242.089461 min=57        med=85        max=1881       p(90)=737.5      p(95)=1135.25   
     llamacpp_prompt_tokens_total_counter.......: 197545  328.780341/s
     llamacpp_tokens_second.....................: avg=133.177589 min=69.716681 med=113.59654 max=385.707008 p(90)=208.049168 p(95)=251.191193
     vus........................................: 8       min=8        max=8
     vus_max....................................: 8       min=8        max=8


running (10m00.8s), 0/8 VUs, 816 complete and 0 interrupted iterations
default ✗ [==============================>-------] 8 VUs  10m00.8s/10m0s  0816/1000 shared iters
bench: shutting down server pid=42224 ...
INFO [            update_slots] all slots are idle | tid="140137934626816" timestamp=1713130058
Traceback (most recent call last):
  File "/home/johannesg/Projects/llama.cpp/examples/server/bench/bench.py", line 309, in <module>
    main()
  File "/home/johannesg/Projects/llama.cpp/examples/server/bench/bench.py", line 190, in main
    "0": round(mean(prometheus_metrics['prompt_tokens_seconds']), 2),
                    ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^
KeyError: 'prompt_tokens_seconds'

If you compare the logs you'll see that on master the slots are used asynchronously, while with this PR they are used synchronously. I don't know whether this is the result of changes on gg/flash-attn or whether that branch simply lacks some performance optimizations from master, but in any case this has nothing to do with the CUDA kernels.

@phymbert
Collaborator

Thanks for taking the time to confirm this is not related to the CUDA kernels.
FYI, the test is crashing because Prometheus is not started; not important.

Let's wait for master to be synchronized, then.

@JohannesGaessler
Copy link
Collaborator Author

The master commit immediately preceding the last merge into gg/flash-attn is still fast, so the issue has to be caused by some commit on that branch.

@kalomaze
Copy link
Contributor

kalomaze commented Apr 14, 2024

While 35b and 20b were functional, I also tried a 70b at q4_K_M; it's still affected by the bug and puts out pure gibberish. (alto alto alto alto...)

./server -m '/miqu-1-70b.q4_k_m.gguf' -c 8192 -ngl 81 -b 512

Not sure if this is relevant, but I notice that these are much smaller too:

llama_new_context_with_model:      CUDA0 compute buffer size =   457.75 MiB
llama_new_context_with_model:      CUDA1 compute buffer size =   168.00 MiB

when compared to mainline:

llama_new_context_with_model:      CUDA0 compute buffer size =  1108.00 MiB
llama_new_context_with_model:      CUDA1 compute buffer size =  1104.00 MiB

Oddly enough, generation speed with this PR is also slower for the 70b, which wasn't the case for the other (smaller) models I ran successfully, where it was a consistent speedup.

@JohannesGaessler
Copy link
Collaborator Author

I cannot reproduce the issue with LLaMA 2 70b q4_K_M. I'll download Miqu and see if there is a difference.

@kalomaze
Copy link
Contributor

kalomaze commented Apr 15, 2024

Could be related to rotary embeddings, maybe, if it only happens on Miqu.

@JohannesGaessler
Copy link
Collaborator Author

I can reproduce the issue with Miqu. It could be the same issue as with Phi-2, where (according to Georgi) IEEE 754 half precision floats are not sufficient. One solution would be to use bfloat16 or FP32 instead, but bfloat16 has hardware support only since Ampere and FP32 needs more memory. I'll revisit Miqu once there is a solution for Phi-2 and check whether it works then.
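For a sense of scale (these are standard format limits, not something measured in this PR): the largest finite FP16 value is about 6.55e4, while bfloat16 and FP32 go up to about 3.4e38, so a KQ value that overflows half precision is still far inside the range of the other two formats. A standalone illustration:

#include <cstdio>
#include <cuda_fp16.h>

int main() {
    const float kq = 1.0e5f;                    // hypothetical KQ value beyond the FP16 range
    const __half h = __float2half(kq);          // rounds to +inf in IEEE 754 half precision
    printf("as half:  %f\n", __half2float(h));  // prints inf
    printf("as float: %f\n", kq);               // representable in FP32 (and within bfloat16's exponent range)
    return 0;
}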

@ggerganov
Copy link
Owner

@JohannesGaessler @phymbert

I suspect that the benchmark using phi-2 is invalid because of the precision issues: likely each submitted request keeps generating garbage tokens without ever hitting EOS. Will try to confirm this and move the FA branch forward this week.

@JohannesGaessler
Copy link
Collaborator Author

@ggerganov I think it's not an issue with the precision but rather with the numerical range. Do you know which parts of the calculation specifically are problematic?

@ggerganov
Copy link
Owner

On RTX 2060, using master you can try the following:

LLAMA_CUBLAS=1 make -j main && ./main -m ./models/phi-2/ggml-model-f16.gguf -p "I believe the meaning of life is" -s 1 -n 32 -ngl 99

Generation is fine:

I believe the meaning of life is to be happy, but what does that even mean?

Ben: It means finding fulfillment and contentment in our daily lives and pursuing our passions.

Claire: I think the meaning of life is to help others and make a positive impact on the world.

Alex: I believe the meaning of
llama_print_timings:        load time =    1224,92 ms
llama_print_timings:      sample time =       1,92 ms /    64 runs   (    0,03 ms per token, 33281,33 tokens per second)
llama_print_timings: prompt eval time =      19,06 ms /     7 tokens (    2,72 ms per token,   367,18 tokens per second)
llama_print_timings:        eval time =    1044,99 ms /    63 runs   (   16,59 ms per token,    60,29 tokens per second)
llama_print_timings:       total time =    1083,51 ms /    70 tokens

Now, disable GGML_PREC_F32 for the K*Q matrix multiplication:

diff --git a/llama.cpp b/llama.cpp
index cf95cea1..9882b3b8 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -5994,7 +5994,7 @@ static struct ggml_tensor * llm_build_kqv(
     if (model.arch == LLM_ARCH_PHI2) {
         // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
         // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
-        ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
+        //ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
     }
 
     if (model.arch == LLM_ARCH_GROK) {

Generation is now garbage:

I believe the meaning of life isGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGG

llama_print_timings:        load time =    1208,21 ms
llama_print_timings:      sample time =       0,81 ms /    62 runs   (    0,01 ms per token, 76827,76 tokens per second)
llama_print_timings: prompt eval time =      19,17 ms /     7 tokens (    2,74 ms per token,   365,17 tokens per second)
llama_print_timings:        eval time =    1007,04 ms /    61 runs   (   16,51 ms per token,    60,57 tokens per second)
llama_print_timings:       total time =    1048,63 ms /    68 tokens

So this makes me think that the FA kernels need to start respecting the GGML_PREC_F32 mode and, if it is configured, carry out the K*Q multiplication with an F32 accumulator.

@JohannesGaessler
Copy link
Collaborator Author

If you add a check at the end of the FP16 kernel to write back 0.0f instead of NaN, the results from Phi-2 and Miqu are correct. So presumably it will be possible to fix the outputs without having to write a kernel with higher precision, because the problem is underflows rather than overflows. I'll investigate where exactly the problem is and how best to fix it.

@ggerganov
Copy link
Owner

Btw, this #6685 (comment) reminded me that the build_defrag() function is not updated in the FA branch to account for the V cache no longer being transposed, so the server slowdown is likely related to that as well. Will fix when we merge this PR into gg/flash-attn.

@JohannesGaessler
Copy link
Collaborator Author

JohannesGaessler commented Apr 15, 2024

I pushed a partial fix for the numerical precision issues. One of the problems was that if the model produces KQ values with a wide range, the FP16 exponentiation can result in arithmetic underflow. Unfortunately the result is then not 0 but NaN. This can be fixed by flushing results below a certain threshold to 0. I chose a difference of more than 20 relative to the max value, which is equivalent to flushing all post-exponentiation values below ~ $2 \cdot 10^{-9}$ to 0 and should be negligible anyway. This fixes the problems for Miqu.

In addition to that, if you were to also add a check at the end that avoids NaNs from 0.0f/0.0f division, you could get Phi-2 to produce coherent outputs. With the way I did the implementation, NaN values from the KQ matrix multiplication also get set to 0, so doing this would essentially mean ignoring those problematic values. However, it severely affects quality: FP16 perplexity becomes roughly equivalent to q4_K_S. So I don't think this is a good way of handling the Phi-2 precision issues.
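For reference, a minimal sketch of the flush-to-zero idea (the threshold part only, not the 0.0f/0.0f guard); the helper name and the threshold constant are illustrative, not the actual kernel code:

#include <cuda_fp16.h>

#define SOFTMAX_FTZ_THRESHOLD (-20.0f) // exp(-20) ~= 2e-9, a negligible softmax contribution

// One softmax numerator term exp(kq_val - kq_max) in FP16. Terms more than 20 below the
// row maximum are flushed to zero explicitly, so an FP16 underflow can never leave a NaN behind.
static __device__ __forceinline__ half softmax_term_ftz(const half kq_val, const half kq_max) {
    const float diff = __half2float(kq_val) - __half2float(kq_max); // always <= 0
    if (diff <= SOFTMAX_FTZ_THRESHOLD) {
        return __float2half(0.0f);
    }
    return hexp(__float2half(diff));
}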

@JohannesGaessler
Copy link
Collaborator Author

I forgot: I can confirm that the server performance issues are mostly caused by NaN outputs; with the hacky Phi-2 fix the performance is much better (but still slower than master).

@phymbert
Copy link
Collaborator

phymbert commented Apr 16, 2024

Hello, I just started a server with 32 slots on a llama 70b arch with this branch. It generates garbage: "â-...â-...â-...". Does it support continuous or parallel batching?

Note: added --cont-batching, no defragmentation. n_batch=4096, ubatch=256, n_ctx=32768, n_parallel=32.

@JohannesGaessler
Copy link
Collaborator Author

JohannesGaessler commented Apr 16, 2024

I pushed an implementation for calculating KQ and the corresponding softmax at FP32 precision; Phi-2 should work now.

The KQ values produced by Phi-2 are just a mess. They frequently fall outside the max representable range of IEEE 754 half precision floats, while the values produced by e.g. LLaMA 2 are significantly smaller. In any case, it works now. The performance impact of using some FP32 is most notable at large batch sizes, where the difference is 3-5% (it may be a lot more on Volta/Turing, where there is less shared memory per SM).
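As a rough sketch of what calculating KQ at FP32 precision means here (illustrative names, not the actual kernel code), the dot product is accumulated in float and only scaled/converted afterwards:

#include <cuda_fp16.h>

// One KQ entry: dot product of a Q row with a K column, accumulated in FP32 so that the
// intermediate sums cannot overflow the FP16 range, then scaled by 1/sqrt(head_dim).
static __device__ float kq_dot_f32(const float * Q_row, const half * K_col, const int D, const float scale) {
    float sum = 0.0f;
    for (int i = 0; i < D; ++i) {
        sum += Q_row[i] * __half2float(K_col[i]);
    }
    return scale * sum;
}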

In any case this is the current level of performance on my systems:

Vs. master
GPU Model Batch size Test t/s master t/s 1ca6454 Speedup
RTX 4090 gemma 2B F16 1 pp4096 148.46 157.07 1.06
RTX 4090 gemma 2B F16 2 pp4096 273.49 275.65 1.01
RTX 4090 gemma 2B F16 4 pp4096 558.89 564.36 1.01
RTX 4090 gemma 2B F16 8 pp4096 1104.55 1115.63 1.01
RTX 4090 gemma 2B F16 16 pp4096 2161.74 2179.20 1.01
RTX 4090 gemma 2B F16 32 pp4096 4187.60 4259.44 1.02
RTX 4090 gemma 2B F16 64 pp4096 7446.43 7590.81 1.02
RTX 4090 gemma 2B F16 128 pp4096 13042.99 13343.23 1.02
RTX 4090 gemma 2B F16 256 pp4096 20696.70 20859.75 1.01
RTX 4090 gemma 2B F16 512 pp4096 26401.29 28185.16 1.07
RTX 4090 gemma 2B F16 1024 pp4096 27309.42 29182.72 1.07
RTX 4090 gemma 2B F16 2048 pp4096 27846.42 29831.07 1.07
RTX 4090 gemma 2B F16 4096 pp4096 28062.58 30092.97 1.07
RTX 4090 llama 7B Q4_0 1 pp4096 136.04 146.37 1.08
RTX 4090 llama 7B Q4_0 2 pp4096 271.98 279.24 1.03
RTX 4090 llama 7B Q4_0 4 pp4096 533.00 552.01 1.04
RTX 4090 llama 7B Q4_0 8 pp4096 903.12 932.79 1.03
RTX 4090 llama 7B Q4_0 16 pp4096 1272.05 1313.42 1.03
RTX 4090 llama 7B Q4_0 32 pp4096 1552.26 1607.01 1.04
RTX 4090 llama 7B Q4_0 64 pp4096 2079.20 2164.93 1.04
RTX 4090 llama 7B Q4_0 128 pp4096 3382.61 3711.51 1.10
RTX 4090 llama 7B Q4_0 256 pp4096 4710.19 6072.54 1.29
RTX 4090 llama 7B Q4_0 512 pp4096 5492.51 7843.51 1.43
RTX 4090 llama 7B Q4_0 1024 pp4096 5498.51 7849.78 1.43
RTX 4090 llama 7B Q4_0 2048 pp4096 5505.36 7852.64 1.43
RTX 4090 llama 7B Q4_0 4096 pp4096 5503.67 7859.96 1.43
RTX 4090 phi2 3B F16 1 pp4096 119.85 123.62 1.03
RTX 4090 phi2 3B F16 2 pp4096 212.59 218.55 1.03
RTX 4090 phi2 3B F16 4 pp4096 420.16 434.08 1.03
RTX 4090 phi2 3B F16 8 pp4096 821.16 856.37 1.04
RTX 4090 phi2 3B F16 16 pp4096 1568.08 1671.98 1.07
RTX 4090 phi2 3B F16 32 pp4096 2932.84 3211.49 1.10
RTX 4090 phi2 3B F16 64 pp4096 5097.44 5627.12 1.10
RTX 4090 phi2 3B F16 128 pp4096 7491.61 9358.63 1.25
RTX 4090 phi2 3B F16 256 pp4096 9217.04 14745.22 1.60
RTX 4090 phi2 3B F16 512 pp4096 9664.00 16823.78 1.74
RTX 4090 phi2 3B F16 1024 pp4096 9701.17 16941.94 1.75
RTX 4090 phi2 3B F16 2048 pp4096 9720.84 17006.79 1.75
RTX 4090 phi2 3B F16 4096 pp4096 9739.12 17046.60 1.75
RTX 3090 gemma 2B F16 1 pp4096 124.19 135.28 1.09
RTX 3090 gemma 2B F16 2 pp4096 223.09 227.95 1.02
RTX 3090 gemma 2B F16 4 pp4096 446.77 460.88 1.03
RTX 3090 gemma 2B F16 8 pp4096 858.03 908.72 1.06
RTX 3090 gemma 2B F16 16 pp4096 1577.65 1752.94 1.11
RTX 3090 gemma 2B F16 32 pp4096 3169.02 3221.29 1.02
RTX 3090 gemma 2B F16 64 pp4096 5732.01 5910.87 1.03
RTX 3090 gemma 2B F16 128 pp4096 9279.19 10026.04 1.08
RTX 3090 gemma 2B F16 256 pp4096 10880.27 11767.17 1.08
RTX 3090 gemma 2B F16 512 pp4096 11941.20 13097.30 1.10
RTX 3090 gemma 2B F16 1024 pp4096 12090.88 13359.76 1.10
RTX 3090 gemma 2B F16 2048 pp4096 12349.03 13488.39 1.09
RTX 3090 gemma 2B F16 4096 pp4096 12231.93 13485.40 1.10
RTX 3090 llama 7B Q4_0 1 pp4096 112.55 124.27 1.10
RTX 3090 llama 7B Q4_0 2 pp4096 212.74 237.70 1.12
RTX 3090 llama 7B Q4_0 4 pp4096 349.07 400.35 1.15
RTX 3090 llama 7B Q4_0 8 pp4096 465.10 532.83 1.15
RTX 3090 llama 7B Q4_0 16 pp4096 473.60 564.76 1.19
RTX 3090 llama 7B Q4_0 32 pp4096 614.24 635.46 1.03
RTX 3090 llama 7B Q4_0 64 pp4096 1211.23 1339.85 1.11
RTX 3090 llama 7B Q4_0 128 pp4096 1934.80 2227.29 1.15
RTX 3090 llama 7B Q4_0 256 pp4096 2616.74 3137.70 1.20
RTX 3090 llama 7B Q4_0 512 pp4096 2986.04 3711.51 1.24
RTX 3090 llama 7B Q4_0 1024 pp4096 2999.42 3719.73 1.24
RTX 3090 llama 7B Q4_0 2048 pp4096 3001.76 3726.12 1.24
RTX 3090 llama 7B Q4_0 4096 pp4096 3009.03 3734.50 1.24
RTX 3090 phi2 3B F16 1 pp4096 100.35 106.01 1.06
RTX 3090 phi2 3B F16 2 pp4096 176.19 184.28 1.05
RTX 3090 phi2 3B F16 4 pp4096 344.45 365.32 1.06
RTX 3090 phi2 3B F16 8 pp4096 658.08 723.23 1.10
RTX 3090 phi2 3B F16 16 pp4096 1206.10 1350.49 1.12
RTX 3090 phi2 3B F16 32 pp4096 2127.28 2485.34 1.17
RTX 3090 phi2 3B F16 64 pp4096 3465.86 4275.48 1.23
RTX 3090 phi2 3B F16 128 pp4096 4764.21 6339.79 1.33
RTX 3090 phi2 3B F16 256 pp4096 5286.81 7294.76 1.38
RTX 3090 phi2 3B F16 512 pp4096 5641.33 8159.98 1.45
RTX 3090 phi2 3B F16 1024 pp4096 5690.37 8174.25 1.44
RTX 3090 phi2 3B F16 2048 pp4096 5694.05 8225.21 1.44
RTX 3090 phi2 3B F16 4096 pp4096 5707.51 8197.68 1.44
Vs. gg/flash-attn
GPU Model Batch size Test t/s gg/flash-attn t/s 1ca6454 Speedup
RTX 4090 gemma 2B F16 1 pp4096 156.93 157.07 1.00
RTX 4090 gemma 2B F16 2 pp4096 245.72 275.65 1.12
RTX 4090 gemma 2B F16 4 pp4096 501.66 564.36 1.12
RTX 4090 gemma 2B F16 8 pp4096 994.84 1115.63 1.12
RTX 4090 gemma 2B F16 16 pp4096 1963.72 2179.20 1.11
RTX 4090 gemma 2B F16 32 pp4096 3763.81 4259.44 1.13
RTX 4090 gemma 2B F16 64 pp4096 6609.81 7590.81 1.15
RTX 4090 gemma 2B F16 128 pp4096 12166.87 13343.23 1.10
RTX 4090 gemma 2B F16 256 pp4096 20393.39 20859.75 1.02
RTX 4090 gemma 2B F16 512 pp4096 28404.84 28185.16 0.99
RTX 4090 gemma 2B F16 1024 pp4096 29462.16 29182.72 0.99
RTX 4090 gemma 2B F16 2048 pp4096 30085.55 29831.07 0.99
RTX 4090 gemma 2B F16 4096 pp4096 30343.08 30092.97 0.99
RTX 4090 llama 7B Q4_0 1 pp4096 146.68 146.37 1.00
RTX 4090 llama 7B Q4_0 2 pp4096 276.31 279.24 1.01
RTX 4090 llama 7B Q4_0 4 pp4096 547.74 552.01 1.01
RTX 4090 llama 7B Q4_0 8 pp4096 929.35 932.79 1.00
RTX 4090 llama 7B Q4_0 16 pp4096 1317.17 1313.42 1.00
RTX 4090 llama 7B Q4_0 32 pp4096 1600.31 1607.01 1.00
RTX 4090 llama 7B Q4_0 64 pp4096 2119.35 2164.93 1.02
RTX 4090 llama 7B Q4_0 128 pp4096 3705.01 3711.51 1.00
RTX 4090 llama 7B Q4_0 256 pp4096 6069.55 6072.54 1.00
RTX 4090 llama 7B Q4_0 512 pp4096 7829.50 7843.51 1.00
RTX 4090 llama 7B Q4_0 1024 pp4096 7843.75 7849.78 1.00
RTX 4090 llama 7B Q4_0 2048 pp4096 7864.92 7852.64 1.00
RTX 4090 llama 7B Q4_0 4096 pp4096 7854.40 7859.96 1.00
RTX 4090 phi2 3B F16 1 pp4096 123.51 123.62 1.00
RTX 4090 phi2 3B F16 2 pp4096 203.90 218.55 1.07
RTX 4090 phi2 3B F16 4 pp4096 405.53 434.08 1.07
RTX 4090 phi2 3B F16 8 pp4096 803.90 856.37 1.07
RTX 4090 phi2 3B F16 16 pp4096 1571.15 1671.98 1.06
RTX 4090 phi2 3B F16 32 pp4096 3085.40 3211.49 1.04
RTX 4090 phi2 3B F16 64 pp4096 5461.86 5627.12 1.03
RTX 4090 phi2 3B F16 128 pp4096 9513.56 9358.63 0.98
RTX 4090 phi2 3B F16 256 pp4096 15202.10 14745.22 0.97
RTX 4090 phi2 3B F16 512 pp4096 17564.81 16823.78 0.96
RTX 4090 phi2 3B F16 1024 pp4096 17506.02 16941.94 0.97
RTX 4090 phi2 3B F16 2048 pp4096 17565.62 17006.79 0.97
RTX 4090 phi2 3B F16 4096 pp4096 17604.26 17046.60 0.97
RTX 3090 gemma 2B F16 1 pp4096 135.58 135.28 1.00
RTX 3090 gemma 2B F16 2 pp4096 199.72 227.95 1.14
RTX 3090 gemma 2B F16 4 pp4096 401.66 460.88 1.15
RTX 3090 gemma 2B F16 8 pp4096 795.51 908.72 1.14
RTX 3090 gemma 2B F16 16 pp4096 1553.86 1752.94 1.13
RTX 3090 gemma 2B F16 32 pp4096 2909.64 3221.29 1.11
RTX 3090 gemma 2B F16 64 pp4096 5255.95 5910.87 1.12
RTX 3090 gemma 2B F16 128 pp4096 9130.67 10026.04 1.10
RTX 3090 gemma 2B F16 256 pp4096 11761.76 11767.17 1.00
RTX 3090 gemma 2B F16 512 pp4096 13284.64 13097.30 0.99
RTX 3090 gemma 2B F16 1024 pp4096 13546.97 13359.76 0.99
RTX 3090 gemma 2B F16 2048 pp4096 13625.12 13488.39 0.99
RTX 3090 gemma 2B F16 4096 pp4096 13673.57 13485.40 0.99
RTX 3090 llama 7B Q4_0 1 pp4096 125.30 124.27 0.99
RTX 3090 llama 7B Q4_0 2 pp4096 225.69 237.70 1.05
RTX 3090 llama 7B Q4_0 4 pp4096 383.69 400.35 1.04
RTX 3090 llama 7B Q4_0 8 pp4096 533.13 532.83 1.00
RTX 3090 llama 7B Q4_0 16 pp4096 569.25 564.76 0.99
RTX 3090 llama 7B Q4_0 32 pp4096 629.32 635.46 1.01
RTX 3090 llama 7B Q4_0 64 pp4096 1320.76 1339.85 1.01
RTX 3090 llama 7B Q4_0 128 pp4096 2236.75 2227.29 1.00
RTX 3090 llama 7B Q4_0 256 pp4096 3149.86 3137.70 1.00
RTX 3090 llama 7B Q4_0 512 pp4096 3706.48 3711.51 1.00
RTX 3090 llama 7B Q4_0 1024 pp4096 3718.74 3719.73 1.00
RTX 3090 llama 7B Q4_0 2048 pp4096 3730.79 3726.12 1.00
RTX 3090 llama 7B Q4_0 4096 pp4096 3735.57 3734.50 1.00
RTX 3090 phi2 3B F16 1 pp4096 106.93 106.01 0.99
RTX 3090 phi2 3B F16 2 pp4096 170.42 184.28 1.08
RTX 3090 phi2 3B F16 4 pp4096 337.30 365.32 1.08
RTX 3090 phi2 3B F16 8 pp4096 666.42 723.23 1.09
RTX 3090 phi2 3B F16 16 pp4096 1293.18 1350.49 1.04
RTX 3090 phi2 3B F16 32 pp4096 2434.70 2485.34 1.02
RTX 3090 phi2 3B F16 64 pp4096 4246.23 4275.48 1.01
RTX 3090 phi2 3B F16 128 pp4096 6835.86 6339.79 0.93
RTX 3090 phi2 3B F16 256 pp4096 7753.84 7294.76 0.94
RTX 3090 phi2 3B F16 512 pp4096 8814.01 8159.98 0.93
RTX 3090 phi2 3B F16 1024 pp4096 8585.54 8174.25 0.95
RTX 3090 phi2 3B F16 2048 pp4096 8668.73 8225.21 0.95
RTX 3090 phi2 3B F16 4096 pp4096 8557.30 8197.68 0.96

@ggerganov
Copy link
Owner

@phymbert Does it also produce garbage using main?

@ggerganov ggerganov mentioned this pull request Apr 17, 2024
@phymbert
Copy link
Collaborator

phymbert commented Apr 17, 2024

@phymbert Does it also produce garbage using main?

Yes, I tested just now with the latest commit here, a llama 70b on 2x A100, and it only generates #######. Testing the target branch...

@ggerganov
Copy link
Owner

The target branch only fixes the V cache defrag, which I thought was causing the problem that you observed. However, if you observe the garbage using main, this means that there is another issue which is still not resolved.

I tested 70B llama with Metal + FA and it works OK, so maybe there is still something wrong with the CUDA implementation.

Comment on lines -236 to 249
const float * Q_f = (const float *) (Q + nb02* blockIdx.y);
const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0);
const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio));
const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
Copy link
Collaborator Author


I just remembered: for calculating the offsets I am not using blockIdx.z and nb03/nb13, because I didn't understand what their purpose was and in the test cases ne3 was always 1. Are they used for continuous batching? If so, what is the expected memory layout?

Copy link
Owner


I think dim 3 is not used in llama.cpp during inference (it's probably used only during training and in other ggml projects). But it would be nice to take it into account. The easiest thing to do is to add tests for that in test-backend-ops.

But in any case, I don't think this is causing the garbage output with 70B that @phymbert reports
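For illustration, folding the dim-3 strides into the quoted pointer setup could look roughly like this, assuming a grid dimension (written here as blockIdx.z) is free to iterate over ne3 (a sketch, not code from the PR):

const float * Q_f = (const float *) (Q + nb03*blockIdx.z + nb02*blockIdx.y + nb01*ic0);
const half  * K_h = (const half  *) (K + nb13*blockIdx.z + nb12*(blockIdx.y / gqa_ratio));
const half  * V_h = (const half  *) (V + nb13*blockIdx.z + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape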

@phymbert
Copy link
Collaborator

phymbert commented Apr 17, 2024

@JohannesGaessler Maybe I am also facing a precision issue; I am trying to force PREC_F32 on the llama arch as well.

But I am getting "__hgt2_mask is undefined"; is it not supported in CUDA 11.6?

@JohannesGaessler
Copy link
Collaborator Author

But I am getting "__hgt2_mask is undefined"; is it not supported in CUDA 11.6?

Looking at the documentation, it seems that this particular intrinsic is indeed not available in CUDA 11.6. But if you're going to force FP32 precision anyway, you can just delete those lines since they will not be used. Conversely, they may be the fix for running your model at FP16 precision, so it would be worthwhile to also test with CUDA 12.

@JohannesGaessler
Copy link
Collaborator Author

I did a quick reimplementation of the function for CUDA 11; it should compile now.
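For reference, a fallback along these lines can be written with plain per-half comparisons. __hgt2_mask compares the two halves of a half2 and returns a 32-bit mask whose 16-bit lanes are all ones where a > b; the sketch below (with a hypothetical name) emulates that, and is not necessarily the exact code that was pushed:

#include <cstdint>
#include <cuda_fp16.h>

static __device__ __forceinline__ uint32_t hgt2_mask_compat(const half2 a, const half2 b) {
    // emulate __hgt2_mask for CUDA toolkits that do not provide it
    const uint32_t mask_low  = 0x0000FFFFu * (__half2float( __low2half(a)) > __half2float( __low2half(b)));
    const uint32_t mask_high = 0xFFFF0000u * (__half2float(__high2half(a)) > __half2float(__high2half(b)));
    return mask_low | mask_high;
}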

@ggerganov
Copy link
Owner

On RTX 2060 and V100, the Phi-2 F16 model generates garbage (bf6a496):

LLAMA_CUBLAS=1 make -j main && ./main -m ./models-mnt/phi-2/ggml-model-f16.gguf -p "I believe the meaning of life is" -s 200 -n 64 -ngl 99
I believe the meaning of life isGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGG
llama_print_timings:        load time =    1609.50 ms
llama_print_timings:      sample time =       2.11 ms /    64 runs   (    0.03 ms per token, 30288.69 tokens per second)
llama_print_timings: prompt eval time =      14.47 ms /     7 tokens (    2.07 ms per token,   483.59 tokens per second)
llama_print_timings:        eval time =     650.88 ms /    63 runs   (   10.33 ms per token,    96.79 tokens per second)
llama_print_timings:       total time =     701.26 ms /    70 tokens
Log end

@JohannesGaessler Does it work on your end?

@phymbert
Copy link
Collaborator

@ggerganov It looks like you need CUDA 12. My issue disappears on this branch with CUDA 12. So we are good to go if we consider this a breaking change.

@ggerganov
Copy link
Owner

Hm I think I already use CUDA 12.3. Will double-check later

@JohannesGaessler
Copy link
Collaborator Author

The issue has nothing to do with CUDA 12. The code works correctly for parallel_blocks == 1. The issue is that, when I turned parallel blocks back on, I forgot that you need more than the IEEE 754 half precision range to store the max KQ value.

@phymbert
Copy link
Collaborator

My bad then, sorry for the confusion and thanks for the explanation.

@JohannesGaessler
Copy link
Collaborator Author

I pushed a fix; it should work now.
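For context, a minimal sketch of why the combination step needs more than the FP16 range: when the partial results of the parallel blocks are merged, the per-block row maxima and softmax denominators are kept in FP32 so the rescaling factors stay representable (illustrative names, not the actual kernel code):

#include <math.h>

// partial_meta[b] holds (row max, softmax denominator) of block b; partial_vals[b] holds the
// corresponding un-normalized output value.
static __device__ float combine_parallel_blocks(const float * partial_vals, const float2 * partial_meta, const int parallel_blocks) {
    float kq_max = -INFINITY;
    for (int b = 0; b < parallel_blocks; ++b) {
        kq_max = fmaxf(kq_max, partial_meta[b].x);
    }
    float numerator   = 0.0f;
    float denominator = 0.0f;
    for (int b = 0; b < parallel_blocks; ++b) {
        const float factor = expf(partial_meta[b].x - kq_max); // safe in FP32, can overflow/underflow in FP16
        numerator   += factor * partial_vals[b];
        denominator += factor * partial_meta[b].y;
    }
    return numerator / denominator;
}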

@phymbert
Copy link
Collaborator

@JohannesGaessler Would it be possible to resolve the conflicts here? I would like to test the server with --flash-attn. Thanks.

Copy link
Owner

@ggerganov ggerganov left a comment


Nice, merge at will

@JohannesGaessler JohannesGaessler merged commit 87968de into ggerganov:gg/flash-attn Apr 18, 2024
38 of 61 checks passed