
Add GGML_HIP_ROCWMMA_FATTN to enable rocWMMA for FlashAttention #12032

Open · wants to merge 6 commits into master
Conversation

@hjc4869 commented Feb 22, 2025

  • Add a new option GGML_HIP_ROCWMMA_FATTN, defaulting to OFF
  • Check for rocWMMA header availability when GGML_HIP_ROCWMMA_FATTN is enabled
  • Define FP16_MMA_AVAILABLE when GGML_HIP_ROCWMMA_FATTN is enabled and the target is supported by rocWMMA (CDNA / RDNA3); see the sketch below
  • Use rocWMMA in the FlashAttention kernel when possible

Related issue: #10439
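
As a rough illustration of the FP16_MMA_AVAILABLE bullet above, the compile-time gate could look like the following. This is a hedged sketch rather than the PR's exact diff; the CDNA / RDNA3 macros are assumed to come from ggml's existing HIP target detection, and the option itself would be enabled at configure time with -DGGML_HIP_ROCWMMA_FATTN=ON.

#if defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3))
// Sketch only: CDNA / RDNA3 are assumed to be defined by the HIP target checks.
#define FP16_MMA_AVAILABLE
#endif // defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3))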

@github-actions bot added the labels Nvidia GPU (issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) on Feb 22, 2025
@hjc4869 (Author) commented Feb 22, 2025

Adding @IMbackK for review

@JohannesGaessler (Collaborator)
As of right now I do not have the hardware necessary to test this code. Is anyone pledging to maintain it long-term if it gets merged?

@hjc4869 (Author) commented Feb 23, 2025

> As of right now I do not have the hardware necessary to test this code. Is anyone pledging to maintain it long-term if it gets merged?

@IMbackK mentioned in hjc4869#1 that he'll take ownership of this implementation and maintain it here. That's why I'm pinging him in this thread.

@Headcrabed
@JohannesGaessler So will we keep this code path alongside @adelj88 and @thamwangjun's optimized code path in the future?

Comment on lines 76 to 88
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
#else
typedef rocwmma::fragment<rocwmma::matrix_a, frag_m, frag_n, 16, half, rocwmma::row_major> frag_a_K;
typedef rocwmma::fragment<rocwmma::matrix_a, frag_m, frag_n, 16, half, rocwmma::col_major> frag_a_V;
typedef rocwmma::fragment<rocwmma::matrix_b, frag_m, frag_n, 16, half, rocwmma::col_major> frag_b;
typedef rocwmma::fragment<rocwmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
typedef rocwmma::fragment<rocwmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
Collaborator:

Is there any reason to do it like this and not with something like using namespace nvcuda::wmma?
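
A namespace alias would be one way to collapse the duplication the question points at. A minimal sketch, assuming both libraries expose the same fragment vocabulary (rocWMMA deliberately mirrors the nvcuda::wmma API):

#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
namespace wmma = rocwmma;          // AMD path: rocWMMA provides the same names
#else
namespace wmma = nvcuda::wmma;     // NVIDIA path
#endif

// Each typedef is then written once, e.g.:
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;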

Comment on lines +257 to +263
#if defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
if (fp16_mma_available(cc) && dst->src[0]->ne[1] > 8) {
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
return;
}
#endif // defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)

Collaborator:

Move the comment above down since it is now confusing.
