Commit 11c11da

Revert "HIP: Disable ROCWMMA fattn on CDNA when compiled against ROCWMMA 2.0.0 (ggml-org#16221)"
This reverts commit e95fec6.
1 parent: 4da0bbe

File tree: 6 files changed, +39 -61 lines


ggml/src/ggml-cuda/common.cuh

Lines changed: 29 additions & 0 deletions
@@ -225,6 +225,14 @@ static const char * cu_get_error_str(CUresult err) {
 #define FAST_FP16_AVAILABLE
 #endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
 
+#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
+#define FP16_MMA_AVAILABLE
+#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
+
+#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
+#define FP16_MMA_AVAILABLE
+#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
+
 #if defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
 #define AMD_MFMA_AVAILABLE
 #endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
@@ -259,6 +267,27 @@ static bool fast_fp16_hardware_available(const int cc) {
            (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
 }
 
+// Any FP16 tensor core instructions are available for ggml code.
+static bool fp16_mma_available(const int cc) {
+#if defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
+    return false;
+#else
+    if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
+        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) ||
+        GGML_CUDA_CC_IS_MTHREADS(cc)) {
+        return true;
+    } else if (GGML_CUDA_CC_IS_RDNA4(cc)) {
+#if defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
+        return true;
+#else
+        return false;
+#endif // defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
+    } else {
+        return false;
+    }
+#endif // defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
+}
+
 // To be used for feature selection of external libraries, e.g. cuBLAS.
 static bool fp16_mma_hardware_available(const int cc) {
     return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||
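For readers following the revert, the restored design in common.cuh pairs a compile-time gate (the FP16_MMA_AVAILABLE define, checked inside device code) with a run-time gate (fp16_mma_available(cc), checked on the host when choosing kernels). Below is a minimal, self-contained sketch of that pattern only; HAVE_FP16_MMA_SKETCH, the helper name, and the numeric thresholds are invented stand-ins for illustration, not ggml's actual configuration macros.

// Illustrative only: a standalone mock of the compile-time / run-time split above.
// HAVE_FP16_MMA_SKETCH and the thresholds are hypothetical, not ggml's real flags.
#include <cstdio>

#define HAVE_FP16_MMA_SKETCH 1   // analogous to FP16_MMA_AVAILABLE being defined at compile time

// Analogous to fp16_mma_available(cc): the host asks whether FP16 tensor-core
// (MMA/WMMA) kernels can be used for a device with compute capability cc.
static bool fp16_mma_available_sketch(const int cc) {
#if !HAVE_FP16_MMA_SKETCH
    (void) cc;
    return false;      // build configured without FP16 MMA support
#else
    return cc >= 700;  // Volta-or-newer style check on the compute capability
#endif
}

int main() {
    const int ccs[] = {610, 700, 890};
    for (const int cc : ccs) {
        std::printf("cc %d -> FP16 MMA %s\n", cc, fp16_mma_available_sketch(cc) ? "available" : "unavailable");
    }
    return 0;
}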

ggml/src/ggml-cuda/fattn-tile.cu

Lines changed: 2 additions & 3 deletions
@@ -1,7 +1,6 @@
 #include "common.cuh"
 #include "fattn-common.cuh"
 #include "fattn-tile.cuh"
-#include "fattn-wmma-f16.cuh"
 
 // kq_stride == number of KQ rows to process per iteration
 // kq_nbatch == number of K columns to load in parallel for KQ calculation
@@ -191,10 +190,10 @@ static __global__ void flash_attn_tile(
 #ifdef FLASH_ATTN_AVAILABLE
 
     // Skip unused kernel variants for faster compilation:
-#ifdef GGML_USE_WMMA_FATTN
+#ifdef FP16_MMA_AVAILABLE
     NO_DEVICE_CODE;
     return;
-#endif // GGML_USE_WMMA_FATTN
+#endif // FP16_MMA_AVAILABLE
 
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,

ggml/src/ggml-cuda/fattn-wmma-f16.cu

Lines changed: 6 additions & 6 deletions
@@ -6,19 +6,19 @@
 #include "fattn-common.cuh"
 #include "fattn-wmma-f16.cuh"
 
-#ifdef GGML_USE_WMMA_FATTN
+#ifdef FP16_MMA_AVAILABLE
 #if !defined(GGML_USE_HIP)
 #include <mma.h>
-#if defined(GGML_USE_MUSA)
+#ifdef GGML_USE_MUSA
 namespace wmma = mtmusa::wmma;
 #else // GGML_USE_MUSA
 namespace wmma = nvcuda::wmma;
 #endif // GGML_USE_MUSA
-#elif defined(GGML_USE_HIP)
+#elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
 #include <rocwmma/rocwmma.hpp>
 namespace wmma = rocwmma;
 #endif // !defined(GGML_USE_HIP)
-#endif // GGML_USE_WMMA_FATTN
+#endif // FP16_MMA_AVAILABLE
 
 // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
 template<int D, int ncols, int nwarps, int VKQ_stride, typename KQ_acc_t, bool use_logit_softcap>
@@ -45,7 +45,7 @@ static __global__ void flash_attn_ext_f16(
         const int32_t nb21, const int32_t nb22, const int64_t nb23,
         const int32_t ne31, const int32_t ne32, const int32_t ne33,
         const int32_t nb31, const int32_t nb32, const int64_t nb33) {
-#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)))
+#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
@@ -481,7 +481,7 @@ static __global__ void flash_attn_ext_f16(
         ne31, ne32, ne33,
         nb31, nb32, nb33);
     NO_DEVICE_CODE;
-#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)))
+#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
 }
 
 constexpr int get_max_power_of_2(int x) {
ggml/src/ggml-cuda/fattn-wmma-f16.cuh

Lines changed: 0 additions & 46 deletions
@@ -1,49 +1,3 @@
 #include "common.cuh"
 
-#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
-#define GGML_USE_WMMA_FATTN
-#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
-
-#if defined(GGML_HIP_ROCWMMA_FATTN)
-#if defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
-#define GGML_USE_WMMA_FATTN
-#elif defined(CDNA)
-#warning "rocwmma fattn on CDNA is broken on rocwmma v2.0.0, expect degraded performance"
-#endif // defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
-#if defined(RDNA3)
-#define GGML_USE_WMMA_FATTN
-#endif // defined(RDNA3)
-#if defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1
-#define GGML_USE_WMMA_FATTN
-#elif defined(RDNA4)
-#warning "rocwmma fattn is not suported on RDNA4 on rocwmma < v2.0.0, expect degraded performance"
-#endif // defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1
-#endif // defined(GGML_HIP_ROCWMMA_FATTN)
-
-// WMMA flash attention requires FP16 matrix instructions to be available for ggml code.
-static bool ggml_cuda_should_use_wmma_fattn(const int cc) {
-#if defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
-    return false;
-#else
-    if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA) ||
-        GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_MTHREADS(cc)) {
-        return true;
-    } else if (GGML_CUDA_CC_IS_CDNA(cc)){
-#if defined(GGML_HIP_ROCWMMA_FATTN) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
-        return true;
-#else
-        return false;
-#endif // defined(GGML_HIP_ROCWMMA_FATTN) (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
-    } else if (GGML_CUDA_CC_IS_RDNA4(cc)) {
-#if defined(GGML_HIP_ROCWMMA_FATTN) && ROCWMMA_VERSION_MAJOR > 1
-        return true;
-#else
-        return false;
-#endif // defined(GGML_HIP_ROCWMMA_FATTN) && ROCWMMA_VERSION_MAJOR > 1
-    } else {
-        return false;
-    }
-#endif // defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
-}
-
 void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

ggml/src/ggml-cuda/fattn.cu

Lines changed: 2 additions & 2 deletions
@@ -228,7 +228,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
             if (V->ne[0] != K->ne[0]) {
                 return BEST_FATTN_KERNEL_NONE;
             }
-            if (!ggml_cuda_should_use_wmma_fattn(cc) && !turing_mma_available(cc)) {
+            if (!fp16_mma_available(cc) && !turing_mma_available(cc)) {
                 return BEST_FATTN_KERNEL_NONE;
             }
             break;
@@ -311,7 +311,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
     }
 
     // For large batch sizes, use the WMMA kernel if possible:
-    if (ggml_cuda_should_use_wmma_fattn(cc)) {
+    if (fp16_mma_available(cc)) {
        return BEST_FATTN_KERNEL_WMMA_F16;
     }
 
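The fattn.cu hunks above are where the run-time check feeds kernel selection. A condensed, compilable sketch of that gating pattern follows; the enum values mirror the diff, while the capability predicates and the helper name are hypothetical stand-ins so the sketch builds on its own and is not ggml's actual selection function.

// Condensed sketch of the gating in ggml_cuda_get_best_fattn_kernel (illustrative only).
// fp16_mma_available_stub / turing_mma_available_stub are invented stand-ins for the
// real ggml predicates; the enum names are taken from the diff above.
#include <cstdio>

enum best_fattn_kernel {
    BEST_FATTN_KERNEL_NONE     = 0,
    BEST_FATTN_KERNEL_WMMA_F16 = 1,
};

static bool fp16_mma_available_stub(const int cc)   { return cc >= 700; } // stand-in
static bool turing_mma_available_stub(const int cc) { return cc >= 750; } // stand-in

static best_fattn_kernel pick_fattn_kernel_sketch(const int cc) {
    // No FP16 tensor-core path at all -> no suitable flash-attention kernel here.
    if (!fp16_mma_available_stub(cc) && !turing_mma_available_stub(cc)) {
        return BEST_FATTN_KERNEL_NONE;
    }
    // For large batch sizes, use the WMMA kernel if possible (mirrors the diff above).
    if (fp16_mma_available_stub(cc)) {
        return BEST_FATTN_KERNEL_WMMA_F16;
    }
    return BEST_FATTN_KERNEL_NONE;
}

int main() {
    const int ccs[] = {610, 700, 890};
    for (const int cc : ccs) {
        std::printf("cc %d -> kernel %d\n", cc, (int) pick_fattn_kernel_sketch(cc));
    }
    return 0;
}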
ggml/src/ggml-cuda/vendors/hip.h

Lines changed: 0 additions & 4 deletions
@@ -6,10 +6,6 @@
 #include <hip/hip_fp16.h>
 #include <hip/hip_bf16.h>
 
-#if defined(GGML_HIP_ROCWMMA_FATTN)
-#include <rocwmma/rocwmma-version.hpp>
-#endif // defined(GGML_HIP_ROCWMMA_FATTN)
-
 #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
 #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
 #define CUBLAS_OP_N HIPBLAS_OP_N
