Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CUDA: Faster FlashAttention kernel #6374

Merged
merged 10 commits into from
Apr 2, 2024
Prev Previous commit
Next Next commit
fix excessive KQ_b loads
  • Loading branch information
JohannesGaessler committed Apr 2, 2024
commit 46968c93dc755cf382c7f38bfd4ba1edad24a149
12 changes: 8 additions & 4 deletions ggml-cuda/fattn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -387,12 +387,16 @@ static __global__ void flash_attn_ext_f16(

__syncthreads();

frag_b KQ_b[FATTN_KQ_STRIDE/16][ncols/frag_n];
frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += 16) {
nvcuda::wmma::load_matrix_sync(KQ_b[k0/16][j0/frag_n], KQ + j0*kqs_padded + k0, kqs_padded);
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
nvcuda::wmma::load_matrix_sync(
KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
KQ + j0*kqs_padded + k,
kqs_padded);
}
}

Expand All @@ -412,7 +416,7 @@ static __global__ void flash_attn_ext_f16(
nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k/16][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
}
}
}
Expand Down