diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 0be20303141239..f798f4193ec13b 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -7552,9 +7552,10 @@ static __global__ void flash_attn_ext_f16(
     __builtin_assume(tid < nthreads);
 
     constexpr int D_padded = D + 8; // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts.
-    const float * Q_f = (const float *) (Q + nb02*blockIdx.y + ncols*nb01*blockIdx.x);
-    const half * K_h = (const half *) (K + nb12*blockIdx.y);
-    const half * V_h = (const half *) (V + nb12*blockIdx.y); // K and V have same shape
+    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+    const float * Q_f = (const float *) (Q + nb02* blockIdx.y + ncols*nb01*blockIdx.x);
+    const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio));
+    const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
     const half2 * mask2 = (half2 *) mask + ncols*ne11*blockIdx.x/2;
 
     const int stride_Q = nb01 / sizeof(float);
diff --git a/llama.cpp b/llama.cpp
index b80080daf7506e..18f09d49e3873d 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -9166,7 +9166,7 @@ static int llama_decode_internal(
             // a heuristic, to avoid attending the full cache if it is not yet utilized
             // after enough generations, the benefit from this heuristic disappears
             // if we start defragmenting the cache, the benefit from this will be more important
-            kv_self.n = std::min(kv_self.size, std::max(128u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 128)));
+            kv_self.n = std::min(kv_self.size, std::max(256u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 256)));
             //kv_self.n = llama_kv_cache_cell_max(kv_self);
         }
     }
@@ -13083,7 +13083,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
 
     // this is necessary due to kv_self.n being padded later during inference
-    cparams.n_ctx = GGML_PAD(cparams.n_ctx, 32);
+    cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256);
 
     // with causal attention, the batch size is limited by the context size
     cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
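
For reference: the indexing change in `flash_attn_ext_f16` maps several query heads onto one shared key/value head, which is what `blockIdx.y / gqa_ratio` computes. Below is a minimal host-side sketch of the same arithmetic; the head counts are hypothetical and stand in for `ne02`/`ne12`.

```cpp
// Sketch of the grouped-query-attention head mapping used in the kernel change:
// with ne02 query heads and ne12 key/value heads (ne02 % ne12 == 0),
// query head q reads the K/V data of head q / gqa_ratio.
// The head counts below are hypothetical, chosen only for illustration.
#include <cstdio>

int main() {
    const int ne02 = 32; // number of Q heads (hypothetical)
    const int ne12 = 8;  // number of K/V heads (hypothetical)
    const int gqa_ratio = ne02 / ne12; // > 1 Q matrices per K, V matrix

    for (int q_head = 0; q_head < ne02; ++q_head) {
        // same division as blockIdx.y / gqa_ratio in the kernel
        const int kv_head = q_head / gqa_ratio;
        printf("Q head %2d -> K/V head %d\n", q_head, kv_head);
    }
    return 0;
}
```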
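The two llama.cpp hunks keep the padding consistent: `kv_self.n` is now rounded up to a multiple of 256 at decode time, so `n_ctx` has to be padded to the same multiple, otherwise the padded `kv_self.n` could exceed the allocated cache. A small sketch of the rounding, assuming the `GGML_PAD` definition from ggml.h (round x up to the next multiple of n); the cache sizes used are hypothetical.

```cpp
// Sketch of the kv_self.n heuristic after the change. cell_max stands in for
// llama_kv_cache_cell_max() and kv_size for kv_self.size (hypothetical values).
#include <algorithm>
#include <cstdint>
#include <cstdio>

// assumed to match the macro in ggml.h: round x up to a multiple of n
#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

int main() {
    const uint32_t kv_size  = 4096; // allocated KV cache size (hypothetical)
    const uint32_t cell_max = 300;  // highest used cell + 1 (hypothetical)

    // same expression as in llama_decode_internal after the patch:
    const uint32_t n = std::min(kv_size, std::max(256u, GGML_PAD(cell_max, 256)));
    printf("kv_self.n = %u\n", n); // 300 rounded up -> 512

    // llama_new_context_with_model pads n_ctx to the same multiple, so the
    // padded kv_self.n can never exceed the allocated context.
    const uint32_t n_ctx = GGML_PAD(1000u, 256); // -> 1024
    printf("n_ctx     = %u\n", n_ctx);
    return 0;
}
```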