
CUDA: 4D FlashAttention support #14628


Merged

Conversation

JohannesGaessler
Collaborator

This PR adds 4-dimensional CUDA FlashAttention support for #14363. The data layout for the fixup was changed, but there should be no change in performance. As discussed in #14505 (comment), the CUDA code requires mask->ne[2] == 1; otherwise additional complexity would be needed to ensure that the GQA-specific optimizations in fattn-mma-f16.cuh produce correct results.
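
A minimal sketch of the guard this constraint implies (illustrative only; the helper name is hypothetical and this is not the actual dispatch code in ggml-cuda):

    #include "ggml.h"

    // Hypothetical helper mirroring the constraint described above: the CUDA
    // FlashAttention path assumes the mask is not broadcast along dimension 2,
    // i.e. mask->ne[2] == 1. A mask with ne[2] > 1 would interfere with the
    // GQA-specific tiling in fattn-mma-f16.cuh.
    static bool cuda_fattn_mask_is_supported(const struct ggml_tensor * mask) {
        return mask == NULL || mask->ne[2] == 1;
    }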

@github-actions github-actions bot added the testing, Nvidia GPU, and ggml labels on Jul 11, 2025
Member

@ggerganov ggerganov left a comment

Tests are passing on RTX 2060

@ggerganov ggerganov force-pushed the gg/llama-high-throughput branch from f23950a to ab82dc2 on July 11, 2025 08:27
@JohannesGaessler
Collaborator Author

There was an issue with the WMMA kernel (which is now fixed); merge whenever it is convenient for you.

@ggerganov ggerganov merged commit c43f275 into ggml-org:gg/llama-high-throughput Jul 11, 2025
47 checks passed
ggerganov pushed a commit that referenced this pull request Jul 12, 2025
* CUDA: 4D FlashAttention support

* CUDA: fix WMMA FA kernel
@CISC
Collaborator

CISC commented Jul 13, 2025

Something is wrong; I'm getting a ton of failures on a 3090 Ti (CUDA 12.9):

[...]
  FLASH_ATTN_EXT(hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]): OK
  FLASH_ATTN_EXT(hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,2,1,3]): OK
  FLASH_ATTN_EXT(hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]): OK
  FLASH_ATTN_EXT(hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,2,1,3]): OK
[FLASH_ATTN_EXT] NMSE = 0.421540541 > 0.000500000   FLASH_ATTN_EXT(hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]): FAIL
[FLASH_ATTN_EXT] NMSE = 0.471500105 > 0.000500000   FLASH_ATTN_EXT(hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]): FAIL
  FLASH_ATTN_EXT(hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]): not supported [CUDA0] 
  FLASH_ATTN_EXT(hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,2,1,3]): not supported [CUDA0] 
[FLASH_ATTN_EXT] NMSE = 0.458659731 > 0.000500000   FLASH_ATTN_EXT(hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]): FAIL
[FLASH_ATTN_EXT] NMSE = 0.460324585 > 0.000500000   FLASH_ATTN_EXT(hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,2,1,3]): FAIL
[FLASH_ATTN_EXT] NMSE = 0.445988407 > 0.000500000   FLASH_ATTN_EXT(hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]): FAIL
[FLASH_ATTN_EXT] NMSE = 0.465820280 > 0.000500000   FLASH_ATTN_EXT(hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,2,1,3]): FAIL
[FLASH_ATTN_EXT] NMSE = 0.409744725 > 0.000500000   FLASH_ATTN_EXT(hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]): FAIL
[FLASH_ATTN_EXT] NMSE = 0.420985664 > 0.000500000   FLASH_ATTN_EXT(hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,2,1,3]): FAIL
[...]
  6407/6551 tests passed
  Backend CUDA0: FAIL
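
For reference, NMSE here is a normalized mean squared error of the backend output against the CPU reference; a rough sketch of that metric, assuming the usual definition rather than quoting the exact test-backend-ops implementation:

    #include <cstddef>
    #include <vector>

    // Rough sketch (assumed definition): squared error of the backend output
    // relative to the squared magnitude of the reference output. Values above
    // the 5e-4 threshold show up as FAIL in the log above.
    static double nmse(const std::vector<float> & out, const std::vector<float> & ref) {
        double err = 0.0, norm = 0.0;
        for (size_t i = 0; i < ref.size(); ++i) {
            const double d = (double) out[i] - (double) ref[i];
            err  += d*d;
            norm += (double) ref[i] * (double) ref[i];
        }
        return norm > 0.0 ? err/norm : err;
    }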

@ggerganov
Member

You are testing master. This was merged into another branch.

@CISC
Collaborator

CISC commented Jul 13, 2025

You are testing master. This was merged into another branch.

Ah, LOL, sorry. :)

Why is master failing though?

@JohannesGaessler
Collaborator Author

If master is failing, can you do a git bisect to determine since when?

@ggerganov
Member

It's failing the mask->ne[2] != 1 tests. These are not relevant.
