@@ -4366,26 +4366,32 @@ struct test_flash_attn_ext : public test_case {
43664366 const int64_t hsk_padded = GGML_PAD (hsk, ggml_blck_size (type_KV));
43674367 const int64_t hsv_padded = GGML_PAD (hsv, ggml_blck_size (type_KV));
43684368
// create_permuted: helper lambda that builds a 4-D tensor of the given `type` from the
// logical extents ne0..ne3, laid out according to the captured `permute` array.
// In this hunk the `-` line is the previous signature and the `+` line the new one,
// which adds the `is_view` flag:
//   is_view == false -> the tensor is allocated directly with the permuted extents
//   is_view == true  -> the tensor is a view into a parent allocated with twice the
//                       extent along permuted dimension 1
// Returns the (possibly permuted) tensor; everything is created in the captured `ctx`.
4369- auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) -> ggml_tensor * {
4369+ auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, bool is_view ) -> ggml_tensor * {
43704370 int64_t ne[4 ] = {ne0, ne1, ne2, ne3};
43714371 int64_t ne_perm[4 ];
// Scatter each logical extent ne[i] into slot permute[i], giving the extents used for allocation.
43724372 for (int i = 0 ; i < 4 ; ++i) {
43734373 ne_perm[permute[i]] = ne[i];
43744374 }
4375- ggml_tensor * t = ggml_new_tensor_4d (ctx, type, ne_perm[0 ], ne_perm[1 ], ne_perm[2 ], ne_perm[3 ]);
4375+ ggml_tensor * t;
4376+ if (is_view) {
// Allocate a parent t0 that is twice as large along dimension 1, then take a view with the
// requested extents at byte offset 0 that reuses t0's strides (nb[1], nb[2], nb[3]).
// NOTE(review): presumably this makes the result non-contiguous across dims 2/3, mimicking a
// slice of a larger cache buffer - confirm against ggml_view_4d semantics.
4377+ ggml_tensor * t0 = ggml_new_tensor_4d (ctx, type, ne_perm[0 ], 2 *ne_perm[1 ], ne_perm[2 ], ne_perm[3 ]);
4378+ t = ggml_view_4d (ctx, t0, ne_perm[0 ], ne_perm[1 ], ne_perm[2 ], ne_perm[3 ], t0->nb [1 ], t0->nb [2 ], t0->nb [3 ], 0 );
4379+ } else {
// Plain allocation with exactly the permuted extents (same call as the removed line above).
4380+ t = ggml_new_tensor_4d (ctx, type, ne_perm[0 ], ne_perm[1 ], ne_perm[2 ], ne_perm[3 ]);
4381+ }
// Apply ggml_permute only when `permute` differs from the identity order {0, 1, 2, 3}.
43764382 if (permute != std::array<int32_t , 4 >{0 , 1 , 2 , 3 }) {
43774383 t = ggml_permute (ctx, t, permute[0 ], permute[1 ], permute[2 ], permute[3 ]);
43784384 }
43794385 return t;
43804386 };
43814387
// Create and name the q, k and v inputs via create_permuted. Each `-` line is the old call;
// the `+` line adds the new trailing is_view argument: false for q (F32), true for k and v
// (both of type_KV), so k and v are built as views into larger parent tensors.
4382- ggml_tensor * q = create_permuted (GGML_TYPE_F32, hsk_padded, nb, nh*nr23[0 ], nr23[1 ]);
4388+ ggml_tensor * q = create_permuted (GGML_TYPE_F32, hsk_padded, nb, nh*nr23[0 ], nr23[1 ], false );
43834389 ggml_set_name (q, " q" );
43844390
4385- ggml_tensor * k = create_permuted (type_KV, hsk_padded, kv, nh, nr23[1 ]);
4391+ ggml_tensor * k = create_permuted (type_KV, hsk_padded, kv, nh, nr23[1 ], true ); // the K tensor is usually a view of the K cache
43864392 ggml_set_name (k, " k" );
43874393
4388- ggml_tensor * v = create_permuted (type_KV, hsv_padded, kv, nh, nr23[1 ]);
4394+ ggml_tensor * v = create_permuted (type_KV, hsv_padded, kv, nh, nr23[1 ], true ); // the V tensor is usually a view of the V cache
43894395 ggml_set_name (v, " v" );
43904396
43914397 ggml_tensor * m = nullptr ;
0 commit comments