@@ -1000,13 +1000,13 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
 
-        const auto n_kv = inp->mctx->get_attn()->get_n_kv();
+        const auto n_kv   = inp->mctx->get_attn()->get_n_kv();
+        const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
 
         inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
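
Note: the KQ mask changes from a 2-D tensor covering every token of the ubatch to a 4-D tensor whose second dimension holds only the tokens of one sequence, with the sequences stacked along the last dimension. A minimal, dependency-free sketch of that shape arithmetic follows; the concrete sizes, the padding value of 64, and the local pad() helper (mirroring ggml's GGML_PAD macro) are assumptions for illustration, not values taken from the diff.

#include <cstdint>
#include <cstdio>

// Assumed stand-ins: llama.cpp defines GGML_KQ_MASK_PAD and GGML_PAD itself.
constexpr int64_t kq_mask_pad = 64;
constexpr int64_t pad(int64_t x, int64_t n) { return ((x + n - 1)/n)*n; }

int main() {
    const int64_t n_kv     = 4096; // KV cells visible to the attention
    const int64_t n_tokens = 96;   // tokens in the ubatch
    const int64_t n_seqs   = 4;    // with cparams.n_seq_virt > 1 this is ubatch.n_seqs_unq

    // before: [n_kv, pad(n_tokens)]
    printf("2d mask: [%lld, %lld]\n",
           (long long) n_kv, (long long) pad(n_tokens, kq_mask_pad));

    // after:  [n_kv, pad(n_tokens/n_seqs), 1, n_seqs]
    printf("4d mask: [%lld, %lld, 1, %lld]\n",
           (long long) n_kv, (long long) pad(n_tokens/n_seqs, kq_mask_pad), (long long) n_seqs);

    return 0;
}
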
@@ -1033,6 +1033,10 @@ ggml_tensor * llm_graph_context::build_attn_mha(
                   float   kq_scale) const {
     const bool v_trans = v->nb[1] > v->nb[2];
 
+    const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
+
+    q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_seqs, n_seqs);
+
     q = ggml_permute(ctx0, q, 0, 2, 1, 3);
     k = ggml_permute(ctx0, k, 0, 2, 1, 3);
     v = ggml_permute(ctx0, v, 0, 2, 1, 3);
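
The new ggml_reshape_4d splits the token dimension of Q per sequence before the existing permute, so the rest of the attention graph carries a trailing sequence dimension. Below is a shape-only sketch using the ggml API (graph construction only, nothing is computed); the head size, head count, and token counts are made up, and the incoming Q layout [n_embd_head, n_head, n_tokens] is assumed to match what build_attn_mha receives.

#include "ggml.h"
#include <cstdio>

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ true, // metadata only, we just inspect shapes
    };
    struct ggml_context * ctx = ggml_init(params);

    const int64_t n_embd_head = 128; // assumed head size
    const int64_t n_head      = 8;   // assumed number of heads
    const int64_t n_tokens    = 64;  // tokens in the ubatch
    const int64_t n_seqs      = 4;   // assumed cparams.n_seq_virt > 1 case

    // Q as passed to build_attn_mha: [n_embd_head, n_head, n_tokens]
    ggml_tensor * q = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, n_tokens);

    // the added reshape: split the token dimension per sequence
    q = ggml_reshape_4d(ctx, q, q->ne[0], q->ne[1], q->ne[2]/n_seqs, n_seqs);

    // the existing permute now yields [n_embd_head, n_tokens/n_seqs, n_head, n_seqs]
    q = ggml_permute(ctx, q, 0, 2, 1, 3);

    printf("q: [%lld, %lld, %lld, %lld]\n",
           (long long) q->ne[0], (long long) q->ne[1],
           (long long) q->ne[2], (long long) q->ne[3]);

    ggml_free(ctx);
    return 0;
}
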
@@ -1081,7 +1085,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 #endif
         }
 
-        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_seqs);
     } else {
         ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
@@ -1126,7 +1130,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 
         cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
 
-        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
+        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_seqs);
 
         if (!cparams.offload_kqv) {
             // all nodes between the KV store and the attention output are run on the CPU
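
After the per-sequence split of Q, the local n_tokens used in these two reshapes appears to count tokens per sequence rather than per ubatch, which is why both flatten-back-to-2-D calls now multiply by n_seqs to recover the full ubatch. A matching shape-only sketch (assumed sizes, graph construction only) of the non-flash-attn path:

#include "ggml.h"
#include <cstdio>

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ true, // shapes only
    };
    struct ggml_context * ctx = ggml_init(params);

    const int64_t n_embd_head = 128; // assumed
    const int64_t n_head      = 8;   // assumed
    const int64_t n_tokens    = 16;  // per-sequence token count at this point
    const int64_t n_seqs      = 4;   // assumed

    // stand-in for kqv: [n_embd_head, n_tokens, n_head, n_seqs]
    ggml_tensor * kqv = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd_head, n_tokens, n_head, n_seqs);

    // as in the diff above
    ggml_tensor * cur = ggml_permute(ctx, kqv, 0, 2, 1, 3);           // [n_embd_head, n_head, n_tokens, n_seqs]
    cur = ggml_cont_2d(ctx, cur, cur->ne[0]*n_head, n_tokens*n_seqs); // [n_embd_head*n_head, n_tokens*n_seqs]

    printf("cur: [%lld, %lld]\n", (long long) cur->ne[0], (long long) cur->ne[1]);

    ggml_free(ctx);
    return 0;
}
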
@@ -1204,12 +1208,13 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
-        const auto n_kv = mctx_cur->get_n_kv();
+        const auto n_kv   = mctx_cur->get_n_kv();
+        const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
 
         inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1451,13 +1456,15 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa()
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
 
+    const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
+
     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
         inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1471,7 +1478,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa()
         inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
         ggml_set_input(inp->self_kq_mask_swa);
 
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;