Add GGML_HIP_ROCWMMA_FATTN to enable rocWMMA for FlashAttention #12032
base: master
Adding @IMbackK for review
As of right now I do not have the hardware necessary to test this code. Is anyone pledging to maintain it long-term if it gets merged?
@JohannesGaessler So will we keep both this code path and the optimized code path from @adelj88 and @thamwangjun in the future?
ggml/src/ggml-cuda/fattn-wmma-f16.cu
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
#else
typedef rocwmma::fragment<rocwmma::matrix_a, frag_m, frag_n, 16, half, rocwmma::row_major> frag_a_K;
typedef rocwmma::fragment<rocwmma::matrix_a, frag_m, frag_n, 16, half, rocwmma::col_major> frag_a_V;
typedef rocwmma::fragment<rocwmma::matrix_b, frag_m, frag_n, 16, half, rocwmma::col_major> frag_b;
typedef rocwmma::fragment<rocwmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
typedef rocwmma::fragment<rocwmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
Is there any reason to do it like this and not with something like using namespace nvcuda::wmma?
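For illustration, the suggestion above could look roughly like the following. This is only a sketch under stated assumptions, not the code that was merged: the two vendor namespaces are replaced by hypothetical stand-in structs so the snippet compiles without CUDA or ROCm headers (the real ones come from <mma.h> and <rocwmma/rocwmma.hpp>), the tile sizes are fixed at 16, and float stands in for half. It uses a namespace alias rather than a using-directive; the alias name wmma and the helper uses_rocwmma are made up for this example.

```cpp
// Hypothetical stand-ins for the vendor namespaces (NOT the real headers).
namespace nvcuda { namespace wmma {
    struct matrix_a {};
    struct row_major {};
    template <typename Use, int M, int N, int K, typename T, typename Layout = void>
    struct fragment { static constexpr bool is_rocwmma = false; };
}}
namespace rocwmma {
    struct matrix_a {};
    struct row_major {};
    template <typename Use, int M, int N, int K, typename T, typename Layout = void>
    struct fragment { static constexpr bool is_rocwmma = true; };
}

// Select the backend once, using the same condition as the diff above.
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
namespace wmma = rocwmma;
#else
namespace wmma = nvcuda::wmma;
#endif

// A single typedef then serves both backends instead of one per #if branch.
typedef wmma::fragment<wmma::matrix_a, 16, 16, 16, float, wmma::row_major> frag_a_K;

// Reports which stand-in backend the alias resolved to.
constexpr bool uses_rocwmma() { return frag_a_K::is_rocwmma; }
```

Whether this works with the real headers depends on both libraries exposing the same names with compatible template parameters, which the duplicated typedefs in the diff suggest but which this sketch does not verify.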
#if defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
    if (fp16_mma_available(cc) && dst->src[0]->ne[1] > 8) {
        ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
        return;
    }
#endif // defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
Move the comment above down since it is now confusing.
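To make the guard in the hunk above easier to follow, here is a minimal standalone sketch of its two-level check: a compile-time gate on the two macros, then a runtime gate on hardware support and on the second dimension of the first source tensor. The names fattn_kernel and pick_fattn_kernel and the parameter q_ne1 are hypothetical, and the two macros are defined by hand here only to simulate a build with the option enabled.

```cpp
#include <cstdint>

// Assumption: simulate a build where the option is ON and the target qualifies.
// In a real build these would come from the build system and headers.
#define GGML_HIP_ROCWMMA_FATTN
#define FP16_MMA_AVAILABLE

enum class fattn_kernel { wmma_f16, fallback };

// Hypothetical helper mirroring the condition in the diff:
//   fp16_mma_ok stands in for fp16_mma_available(cc)
//   q_ne1       stands in for dst->src[0]->ne[1]
fattn_kernel pick_fattn_kernel(bool fp16_mma_ok, int64_t q_ne1) {
#if defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
    if (fp16_mma_ok && q_ne1 > 8) {
        return fattn_kernel::wmma_f16;
    }
#endif
    (void) fp16_mma_ok; // silence unused warnings when the gate is compiled out
    (void) q_ne1;
    return fattn_kernel::fallback;
}
```

With the macros defined, pick_fattn_kernel(true, 9) selects the WMMA path while pick_fattn_kernel(true, 8) does not, because the comparison is strict. Without the macros, every call falls through to the other kernels.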
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
- Add GGML_HIP_ROCWMMA_FATTN, which defaults to OFF
- Use rocWMMA for FlashAttention when GGML_HIP_ROCWMMA_FATTN is enabled
- Define FP16_MMA_AVAILABLE when GGML_HIP_ROCWMMA_FATTN is enabled and the target is supported by rocWMMA (CDNA / RDNA3)

Related issue: #10439
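The relationship between the option and FP16_MMA_AVAILABLE described above can be sketched as a preprocessor condition. This is an assumption about the shape of the logic, not a quote from the PR: the architecture macros CDNA and RDNA3 are hypothetical names for "target is supported by rocWMMA", and both build inputs are defined by hand to simulate an RDNA3 build with the option ON.

```cpp
// Simulated build inputs (assumptions for this sketch only).
#define GGML_HIP_ROCWMMA_FATTN  // the option from this PR, switched ON
#define RDNA3                   // hypothetical macro marking the target architecture

// FP16 MMA is only advertised when the option is ON and the target is CDNA or RDNA3.
#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3))
#define FP16_MMA_AVAILABLE
#endif

// Hypothetical helper exposing the result as a compile-time constant.
constexpr bool fp16_mma_compiled_in() {
#ifdef FP16_MMA_AVAILABLE
    return true;
#else
    return false;
#endif
}

static_assert(fp16_mma_compiled_in(), "expected FP16 MMA with the option ON on RDNA3");
```

Removing either simulated define makes the static_assert fail, which matches the description: the option alone is not enough on an unsupported target.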