@@ -1000,12 +1000,13 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
 
-        const auto n_kv = inp->mctx->get_attn()->get_n_kv();
+        const auto n_kv   = inp->mctx->get_attn()->get_n_kv();
+        const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
 
         inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
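For reference, a minimal standalone sketch of the mask bookkeeping above, assuming GGML_PAD rounds its first argument up to a multiple of its second; kq_mask_pad and all sizes below are illustrative stand-ins (the real GGML_KQ_MASK_PAD lives in ggml.h). With cparams.n_seq_virt > 1, the mask goes from one [n_kv, PAD(n_tokens)] slab shared by all sequences to n_seqs slabs of [n_kv, PAD(n_tokens/n_seqs)], one per virtual sequence:

    // Standalone sketch (not from the patch): KQ mask shape before/after the
    // per-sequence split. kq_mask_pad is an assumed stand-in for GGML_KQ_MASK_PAD;
    // pad_up mirrors GGML_PAD (round up to a multiple of n).
    #include <cstdint>
    #include <cstdio>

    constexpr int64_t kq_mask_pad = 32; // assumption: illustrative value only

    constexpr int64_t pad_up(int64_t x, int64_t n) { return ((x + n - 1) / n) * n; }

    int main() {
        const int64_t n_kv     = 1024; // KV cells visible to this ubatch (illustrative)
        const int64_t n_tokens = 64;   // tokens in the ubatch (illustrative)
        const int64_t n_seqs   = 4;    // ubatch.n_seqs_unq when cparams.n_seq_virt > 1

        // before: one mask slab shared by every sequence
        std::printf("old mask: [%lld, %lld, 1, 1]\n",
                    (long long) n_kv, (long long) pad_up(n_tokens, kq_mask_pad));

        // after: one slab per virtual sequence, covering only that sequence's tokens
        std::printf("new mask: [%lld, %lld, 1, %lld]\n",
                    (long long) n_kv, (long long) pad_up(n_tokens/n_seqs, kq_mask_pad),
                    (long long) n_seqs);
        return 0;
    }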
@@ -1032,6 +1033,10 @@ ggml_tensor * llm_graph_context::build_attn_mha(
             float     kq_scale) const {
     const bool v_trans = v->nb[1] > v->nb[2];
 
+    const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
+
+    q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_seqs, n_seqs);
+
     q = ggml_permute(ctx0, q, 0, 2, 1, 3);
     k = ggml_permute(ctx0, k, 0, 2, 1, 3);
     v = ggml_permute(ctx0, v, 0, 2, 1, 3);
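The added ggml_reshape_4d is a pure view over contiguous data, so it assumes the ubatch lays tokens out grouped by sequence with an equal split (n_tokens divisible by n_seqs). A small standalone sketch of that index bookkeeping, with illustrative sizes and a hypothetical tps (tokens per sequence) variable:

    // Standalone sketch (not from the patch): how the flat token axis splits into
    // (position within sequence, sequence slot) under the view reshape.
    #include <cstdio>

    int main() {
        const int n_tokens = 8;
        const int n_seqs   = 2;
        const int tps      = n_tokens / n_seqs; // requires n_tokens % n_seqs == 0

        for (int t = 0; t < n_tokens; ++t) {
            // tokens 0..tps-1 land in sequence slot 0, the next tps in slot 1, ...
            std::printf("token %d -> (pos %d, seq %d)\n", t, t % tps, t / tps);
        }
        return 0;
    }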
@@ -1080,7 +1085,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 #endif
         }
 
-        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_seqs);
     } else {
         ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
 
@@ -1125,7 +1130,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 
         cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
 
-        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
+        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_seqs);
 
         if (!cparams.offload_kqv) {
             // all nodes between the KV store and the attention output are run on the CPU
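Both output paths (flash attention and the ggml_mul_mat fallback) now flatten a per-sequence 4D result back to 2D; multiplying by n_seqs restores the full ubatch width, assuming n_tokens inside build_attn_mha is the per-sequence count taken from q after the added reshape and the result is laid out as [n_embd_head, n_head, n_tokens, n_seqs]. A small sketch with illustrative sizes:

    // Standalone sketch (not from the patch): the final 2D flatten. Sizes are
    // illustrative; the point is that n_tokens*n_seqs equals the ubatch width.
    #include <cstdio>

    int main() {
        const int n_tokens_ubatch = 64;                       // tokens in the whole ubatch
        const int n_seqs          = 4;                        // virtual sequences
        const int n_tokens        = n_tokens_ubatch / n_seqs; // per-sequence count
        const int n_head          = 32;
        const int n_embd_head     = 128;

        // per-sequence 4D result -> 2D view handed to the output projection
        std::printf("4D out: [%d, %d, %d, %d]\n", n_embd_head, n_head, n_tokens, n_seqs);
        std::printf("2D out: [%d, %d]  (%d == full ubatch width)\n",
                    n_embd_head*n_head, n_tokens*n_seqs, n_tokens*n_seqs);
        return 0;
    }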
@@ -1202,12 +1207,13 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
-        const auto n_kv = mctx_cur->get_n_kv();
+        const auto n_kv   = mctx_cur->get_n_kv();
+        const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
 
         inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1449,13 +1455,15 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
 
+    const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
+
     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
         inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1469,7 +1477,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
         inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
         ggml_set_input(inp->self_kq_mask_swa);
 
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;