@@ -6,19 +6,19 @@
 #include "fattn-common.cuh"
 #include "fattn-wmma-f16.cuh"

-#ifdef GGML_USE_WMMA_FATTN
+#ifdef FP16_MMA_AVAILABLE
 #if !defined(GGML_USE_HIP)
 #include <mma.h>
-#if defined(GGML_USE_MUSA)
+#ifdef GGML_USE_MUSA
 namespace wmma = mtmusa::wmma;
 #else // GGML_USE_MUSA
 namespace wmma = nvcuda::wmma;
 #endif // GGML_USE_MUSA
-#elif defined(GGML_USE_HIP)
+#elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
 #include <rocwmma/rocwmma.hpp>
 namespace wmma = rocwmma;
 #endif // !defined(GGML_USE_HIP)
-#endif // GGML_USE_WMMA_FATTN
+#endif // FP16_MMA_AVAILABLE

 // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
 template <int D, int ncols, int nwarps, int VKQ_stride, typename KQ_acc_t, bool use_logit_softcap>
@@ -45,7 +45,7 @@ static __global__ void flash_attn_ext_f16(
         const int32_t nb21, const int32_t nb22, const int64_t nb23,
         const int32_t ne31, const int32_t ne32, const int32_t ne33,
         const int32_t nb31, const int32_t nb32, const int64_t nb33) {
-#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)))
+#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
@@ -481,7 +481,7 @@ static __global__ void flash_attn_ext_f16(
                      ne31, ne32, ne33,
                      nb31, nb32, nb33);
     NO_DEVICE_CODE;
-#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)))
+#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
 }

 constexpr int get_max_power_of_2(int x) {
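For context, both hunks rename the macro gating the WMMA flash-attention path; the pattern itself is unchanged: the kernel body is compiled only when the feature macro is defined for the current target, and otherwise collapses to NO_DEVICE_CODE so the symbol still exists on unsupported architectures. Below is a minimal, self-contained sketch of that gate. gated_kernel and the NO_DEVICE_CODE stand-in are illustrative assumptions, not the real implementation; the actual macro is defined in ggml's common.cuh, and the actual guard also checks FLASH_ATTN_AVAILABLE, __CUDA_ARCH__, and GGML_HIP_ROCWMMA_FATTN as shown in the diff.

#include <cstdio>
#include <cuda_runtime.h>

// Stand-in for ggml's NO_DEVICE_CODE (the real definition lives in
// common.cuh): on the device side it traps, so a kernel compiled for an
// unsupported target can never run to completion silently.
#define NO_DEVICE_CODE __trap()

// Hypothetical kernel mirroring the guard structure from the diff: the
// real body is compiled only when the feature macro is defined, otherwise
// the kernel degenerates to a trapping stub.
static __global__ void gated_kernel(float * dst) {
#if defined(FP16_MMA_AVAILABLE)
    dst[threadIdx.x] = 1.0f; // the real work would go here
#else
    (void) dst;              // silence the unused-parameter warning
    NO_DEVICE_CODE;          // unsupported target: trap instead of running
#endif // defined(FP16_MMA_AVAILABLE)
}

int main() {
    float * dst = nullptr;
    cudaMalloc(&dst, 32*sizeof(float));
    gated_kernel<<<1, 32>>>(dst);
    printf("kernel status: %s\n", cudaGetErrorString(cudaDeviceSynchronize()));
    cudaFree(dst);
    return 0;
}

Compiling with, e.g., `nvcc -DFP16_MMA_AVAILABLE sketch.cu` takes the real branch; without the define the launch traps, which is exactly the failure mode NO_DEVICE_CODE exists to make loud.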