 #include "fattn-common.cuh"
 
 template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#ifndef GGML_USE_HIP
 __launch_bounds__(D, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#endif // GGML_USE_HIP
 static __global__ void flash_attn_vec_ext_f16(
         const char * __restrict__ Q,
         const char * __restrict__ K,
@@ -48,6 +48,12 @@ static __global__ void flash_attn_vec_ext_f16(
         NO_DEVICE_CODE;
         return;
     }
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+    if (ncols > 1) {
+        NO_DEVICE_CODE;
+        return;
+    }
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
 
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
@@ -91,6 +97,13 @@ static __global__ void flash_attn_vec_ext_f16(
             kqsum_shared[j][threadIdx.x] = 0.0f;
         }
     }
+
+    __shared__ half maskh_shared[ncols*D];
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        maskh_shared[j*D + tid] = 0.0f;
+    }
+
     __syncthreads();
 
     // Convert Q to half2 (f16 K) or q8_1 (quantized K) and store in registers:
@@ -175,6 +188,35 @@ static __global__ void flash_attn_vec_ext_f16(
     for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
         // Calculate KQ tile and keep track of new maximum KQ values:
 
+        if (mask) {
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+                maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + k_VKQ_0 + tid];
+            }
+
+            __syncthreads();
+
+            // When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out.
+            // In such cases, skip the KV slice.
+            // On AMD __all_sync would not work correctly because it assumes a warp size of 64.
+#ifndef GGML_USE_HIP
+            bool skip = true;
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+#pragma unroll
+                for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                    const int i = i0 + threadIdx.x;
+
+                    const float2 tmp = __half22float2(((const half2 *) maskh_shared)[j*(D/2) + i]);
+                    skip = skip && isinf(tmp.x) && isinf(tmp.y);
+                }
+            }
+            if (__all_sync(0xFFFFFFFF, skip)) {
+                continue;
+            }
+#endif // GGML_USE_HIP
+        }
+
         // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
         // see https://github.com/ggerganov/llama.cpp/pull/7061 .
         // Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable).
@@ -202,7 +244,7 @@ static __global__ void flash_attn_vec_ext_f16(
                     sum = logit_softcap*tanhf(sum);
                 }
 
-                sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
+                sum += maskh_shared[j*D + i_KQ];
 
                 if (ncols == 1) {
                     kqmax_new        = ggml_cuda_hmax(kqmax_new,        sum);
@@ -335,7 +377,9 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml
     float logit_softcap;
    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 
-    if (Q->ne[1] == 1) {
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+
+    if (Q->ne[1] == 1 || GGML_CUDA_CC_IS_NVIDIA(cc)) {
         constexpr int cols_per_block = 1;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
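The core of the change is in the middle hunks: the ALiBi-scaled mask values for the current KV slice are staged in shared memory once per tile, a warp-wide __all_sync vote skips slices that are entirely -inf (which can happen with multiple parallel sequences), and the KQ loop then reads the mask from shared memory instead of global memory. The standalone sketch below shows the same staging-plus-vote pattern in a simplified setting; it is illustrative only and not code from this PR. The name tile_max_masked, the plain float mask, the TILE parameter (assumed to be a multiple of 32), and the one-warp-per-block launch are all assumptions made for brevity.

#include <math.h>

// Sketch: compute, for each tile of `scores`, the maximum of score + slope*mask,
// skipping tiles whose mask is entirely -INFINITY via a warp-wide vote.
template <int TILE> // TILE assumed to be a multiple of 32
__global__ void tile_max_masked(const float * __restrict__ scores, const float * __restrict__ mask,
                                const float slope, const int n, float * __restrict__ tile_max) {
    __shared__ float mask_shared[TILE];
    const int lane = threadIdx.x; // assumes blockDim.x == 32, i.e. one warp per block

    for (int tile0 = blockIdx.x*TILE; tile0 < n; tile0 += gridDim.x*TILE) {
        const int tile_id = tile0 / TILE;

        // Stage the scaled mask for this tile once (mirrors maskh_shared in the kernel above).
        bool skip = true;
        for (int i = lane; i < TILE; i += 32) {
            const float m = tile0 + i < n ? slope*mask[tile0 + i] : -INFINITY;
            mask_shared[i] = m;
            skip = skip && isinf(m);
        }
        __syncwarp(); // make the staged values visible to the other lanes

        // Warp-wide vote: the tile is skipped only if every lane saw nothing but -inf.
        // The result is uniform across the warp, so the continue keeps the warp converged.
        if (__all_sync(0xFFFFFFFF, skip)) {
            if (lane == 0) {
                tile_max[tile_id] = -INFINITY;
            }
            continue;
        }

        // Reuse the staged mask with a different access pattern than the staging loop;
        // cross-lane reuse like this is what makes the shared-memory copy worthwhile
        // (mirrors the "sum += maskh_shared[j*D + i_KQ]" change above).
        const int chunk = TILE/32;
        float vmax = -INFINITY;
        for (int i = lane*chunk; i < (lane + 1)*chunk; ++i) {
            if (tile0 + i < n) {
                vmax = fmaxf(vmax, scores[tile0 + i] + mask_shared[i]);
            }
        }
        __syncwarp(); // all reads done before the next iteration overwrites mask_shared

        // Warp reduction of the per-lane maxima.
#pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1) {
            vmax = fmaxf(vmax, __shfl_xor_sync(0xFFFFFFFF, vmax, offset));
        }
        if (lane == 0) {
            tile_max[tile_id] = vmax;
        }
    }
}

A launch such as tile_max_masked<128><<<(n + 127)/128, 32>>>(scores, mask, slope, n, tile_max) gives each warp one tile per outer-loop iteration. Because the vote result is identical for every lane, the continue causes no divergence; the real kernel additionally guards the vote with #ifndef GGML_USE_HIP because, as the comment in the diff notes, __all_sync does not behave as intended with AMD's warp size of 64.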
 | 