#include <mma.h>

- static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
- #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
- #pragma unroll
-     for (int mask = 16; mask > 0; mask >>= 1) {
-         a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
-     }
-     return a;
- #else
-     GGML_UNUSED(a);
-     NO_DEVICE_CODE;
- #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
- }
-
- // static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
- // #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
- // #pragma unroll
- //     for (int mask = 16; mask > 0; mask >>= 1) {
- //         x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
- //     }
- //     return x;
- // #else
- //     GGML_UNUSED(x);
- //     NO_DEVICE_CODE;
- // #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
- // }
-
#define FATTN_KQ_STRIDE 256

template<int D, int parallel_blocks> // D == head size
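The warp_reduce_sum and warp_reduce_max helpers removed above implement the standard XOR-shuffle ("butterfly") warp reduction: halving the mask combines lanes 16, 8, 4, 2 and finally 1 apart, so after log2(32) = 5 steps every lane holds the full result. For illustration, the same pattern for a plain float (a sketch only, not part of this diff):

// Illustrative sketch, not code from this commit: the butterfly reduction the
// removed warp_reduce_sum performs on half2, written for a plain float.
// After the loop every lane of the warp holds the sum of all 32 lane values.
static __device__ __forceinline__ float warp_reduce_sum_f32(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        // Each lane adds the value held by the lane whose ID differs by `mask`.
        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
    }
    return x;
}

__shfl_xor_sync itself is broadly available; the CC_PASCAL and CUDART_HMAX conditions in the removed helpers guard the half2 intrinsics (__hadd2, __hmax2), which need a sufficiently new architecture and CUDA version.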
@@ -61,6 +35,7 @@ static __global__ void flash_attn_vec_ext_f16(
        const int ne1,
        const int ne2,
        const int ne3) {
+ #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
    const float2 * Q_f2 = (const float2 *) (Q + nb02*blockIdx.y);
@@ -201,6 +176,9 @@ static __global__ void flash_attn_vec_ext_f16(
            dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2(kqmax, kqsum);
        }
    }
+ #else
+    NO_DEVICE_CODE;
+ #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
}

template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks> // D == head size, VKQ_stride == num VKQ rows calculated in parallel
@@ -233,6 +211,7 @@ static __global__ void flash_attn_ext_f16(
        const int ne1,
        const int ne2,
        const int ne3) {
+ #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
    static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
    static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
@@ -491,6 +470,9 @@ static __global__ void flash_attn_ext_f16(
            __low2half(KQ_max[0]), __low2half(KQ_rowsum[0]) + __high2half(KQ_rowsum[0]));
        }
    }
+ #else
+    NO_DEVICE_CODE;
+ #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
}

template<int D, int parallel_blocks> // D == head size
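flash_attn_vec_ext_f16 and flash_attn_combine_results are guarded with CC_PASCAL, while flash_attn_ext_f16 above is guarded with CC_VOLTA: that requirement matches the nvcuda::wmma tensor-core API from <mma.h> (included at the top of this file), whose fragment types only exist for __CUDA_ARCH__ >= 700. A self-contained sketch of that API, purely for illustration (not code from this commit): one 16x16x16 half-precision tile multiply accumulated in float.

// Illustrative sketch, not code from this commit. Requires sm_70+ (CC_VOLTA);
// on older architectures the wmma fragment types do not exist, which is the
// situation the new #if / NO_DEVICE_CODE guards handle by compiling a stub.
#include <cuda_fp16.h>
#include <mma.h>

using namespace nvcuda;

// C (float, 16x16, row-major) = A (half, 16x16, row-major) * B (half, 16x16, col-major).
static __global__ void wmma_tile_example(const half * A, const half * B, float * C) {
    wmma::fragment<wmma::matrix_a,    16, 16, 16, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b,    16, 16, 16, half, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float>                 c_frag;

    wmma::fill_fragment(c_frag, 0.0f);
    wmma::load_matrix_sync(a_frag, A, 16); // 16 = leading dimension of A
    wmma::load_matrix_sync(b_frag, B, 16); // 16 = leading dimension of B
    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
    wmma::store_matrix_sync(C, c_frag, 16, wmma::mem_row_major);
}

Launched with exactly one warp, e.g. wmma_tile_example<<<1, 32>>>(A, B, C), the whole warp cooperatively owns the three fragments.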
@@ -499,6 +481,7 @@ static __global__ void flash_attn_combine_results(
        const float * __restrict__ VKQ_parts,
        const half2 * __restrict__ VKQ_meta,
        float * __restrict__ dst) {
+ #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL

    const int tid = threadIdx.x;
    __builtin_assume(tid < D);
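For context on the kernel guarded here: flash_attn_combine_results merges the partial results that the parallel_blocks instances of the kernels above produce for the same output row. Each instance writes its un-normalized numerator to VKQ_parts and its (kqmax, kqsum) pair to dst_meta, which is passed back in as VKQ_meta. The merge is the usual numerically stable softmax combination; a host-side sketch under those assumptions, with invented names (not code from this commit):

// Illustrative host-side sketch, not code from this commit; names are made up.
// m[l]   : running softmax maximum of K/V chunk l        (kqmax in dst_meta)
// s[l]   : running softmax denominator of chunk l        (kqsum in dst_meta)
// vkq[l] : un-normalized partial output value of chunk l (one element of VKQ_parts)
#include <math.h>

static float combine_partial_attention(const float * m, const float * s, const float * vkq, int nblocks) {
    float m_max = m[0];
    for (int l = 1; l < nblocks; ++l) {
        m_max = fmaxf(m_max, m[l]);
    }
    // Rescale every chunk to the global maximum so the exponentials stay bounded,
    // then divide the summed numerators by the summed denominators.
    float numerator = 0.0f, denominator = 0.0f;
    for (int l = 0; l < nblocks; ++l) {
        const float scale = expf(m[l] - m_max);
        numerator   += scale * vkq[l];
        denominator += scale * s[l];
    }
    return numerator / denominator;
}

Rescaling by exp(m[l] - m_max) before summing is what makes splitting the KV sequence across blocks safe.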
@@ -527,6 +510,9 @@ static __global__ void flash_attn_combine_results(
    }

    dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
+ #else
+    NO_DEVICE_CODE;
+ #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
}

constexpr int get_max_power_of_2(int x) {