@@ -3377,11 +3377,11 @@ struct test_mul_mat : public test_case {
33773377 const std::array<int64_t , 2 > bs; // dims 3 and 4
33783378 const std::array<int64_t , 2 > nr; // repeat in dims 3 and 4
33793379 const std::array<int64_t , 4 > per; // permutation of dimensions
3380- const bool v ; // whether a and b are non-contiguous views
3380+ const int64_t k_v ; // size of k in memory, resulting in a non-contiguous view for k_v > k, no view for k_v == 0
33813381 const uint32_t o; // number of outputs
33823382
33833383 std::string vars () override {
3384- return VARS_TO_STR10 (type_a, type_b, m, n, k, bs, nr, per, v , o);
3384+ return VARS_TO_STR10 (type_a, type_b, m, n, k, bs, nr, per, k_v , o);
33853385 }
33863386
33873387 double max_nmse_err () override {
@@ -3402,8 +3402,8 @@ struct test_mul_mat : public test_case {
34023402 std::array<int64_t , 2 > bs = {10 , 10 },
34033403 std::array<int64_t , 2 > nr = {2 , 2 },
34043404 std::array<int64_t , 4 > per = {0 , 1 , 2 , 3 },
3405- bool v = false , uint32_t o = 1 )
3406- : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), per(per), v(v ), o(o) {}
3405+ int64_t k_v = 0 , uint32_t o = 1 )
3406+ : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), per(per), k_v(k_v ), o(o) {}
34073407
34083408 ggml_tensor * build_graph (ggml_context * ctx) override {
34093409 // C^T = A * B^T: (k, m) * (k, n) => (m, n)
@@ -3413,7 +3413,7 @@ struct test_mul_mat : public test_case {
34133413 const int npermuted = (per[0 ] != 0 ) + (per[1 ] != 1 ) + (per[2 ] != 2 ) + (per[3 ] != 3 );
34143414 if (npermuted > 0 ) {
34153415 GGML_ASSERT (npermuted == 2 );
3416- GGML_ASSERT (!v ); // not handled
3416+ GGML_ASSERT (k_v == 0 ); // not handled
34173417 GGML_ASSERT (!ggml_is_quantized (type_a) || per[0 ] == 0 );
34183418 GGML_ASSERT (!ggml_is_quantized (type_b) || per[0 ] == 0 );
34193419
@@ -3437,29 +3437,21 @@ struct test_mul_mat : public test_case {
34373437 ggml_set_name (a, " a_permuted" );
34383438 ggml_set_name (b, " b_permuted" );
34393439 } else {
3440- if (v) {
3441- a = ggml_new_tensor_4d (ctx, type_a, k* 2 , m, bs[0 ], bs[1 ]);
3442- b = ggml_new_tensor_4d (ctx, type_b, k* 2 , n, bs[0 ]*nr[0 ], bs[1 ]*nr[1 ]);
3440+ const int64_t k_physical = k_v == 0 ? k : k_v;
3441+ a = ggml_new_tensor_4d (ctx, type_a, k_physical , m, bs[0 ], bs[1 ]);
3442+ b = ggml_new_tensor_4d (ctx, type_b, k_physical , n, bs[0 ]*nr[0 ], bs[1 ]*nr[1 ]);
34433443
3444- if (!ggml_is_quantized (type_a)) {
3445- if (bs[1 ] == 1 && nr[1 ] == 1 ) {
3446- ggml_set_param (a);
3447- }
3448- ggml_set_param (b);
3444+ if (!ggml_is_quantized (type_a)) {
3445+ if (bs[1 ] == 1 && nr[1 ] == 1 ) {
3446+ ggml_set_param (a);
34493447 }
3448+ ggml_set_param (b);
3449+ }
34503450
3451+ if (k_v != 0 ) {
3452+ GGML_ASSERT (k_v > k);
34513453 a = ggml_view_4d (ctx, a, k, m, bs[0 ], bs[1 ], a->nb [1 ], a->nb [2 ], a->nb [3 ], 0 );
34523454 b = ggml_view_4d (ctx, b, k, n, bs[0 ]*nr[0 ], bs[1 ]*nr[1 ], b->nb [1 ], b->nb [2 ], b->nb [3 ], 0 );
3453- } else {
3454- a = ggml_new_tensor_4d (ctx, type_a, k, m, bs[0 ], bs[1 ]);
3455- b = ggml_new_tensor_4d (ctx, type_b, k, n, bs[0 ]*nr[0 ], bs[1 ]*nr[1 ]);
3456-
3457- if (!ggml_is_quantized (type_a)) {
3458- if (bs[1 ] == 1 && nr[1 ] == 1 ) {
3459- ggml_set_param (a);
3460- }
3461- ggml_set_param (b);
3462- }
34633455 }
34643456 ggml_set_name (a, " a" );
34653457 ggml_set_name (b, " b" );
@@ -6886,7 +6878,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
68866878 test_cases.emplace_back (new test_mul_mat (GGML_TYPE_F16, GGML_TYPE_F32, 128 , 45 , 64 , { 8 , 1 }, {4 , 1 }));
68876879 test_cases.emplace_back (new test_mul_mat (GGML_TYPE_F16, GGML_TYPE_F32, 1056 , 1 , 193 , {1 , 1 }, {4 , 1 }, {0 , 2 , 1 , 3 }));
68886880 test_cases.emplace_back (new test_mul_mat (GGML_TYPE_F16, GGML_TYPE_F32, 1056 , 1 , 67 , {1 , 1 }, {4 , 1 }, {0 , 2 , 1 , 3 }));
6889- test_cases.emplace_back (new test_mul_mat (GGML_TYPE_F32, GGML_TYPE_F32, 16 , 32 , 32 , { 1 , 1 }, {1 , 1 }, {0 , 1 , 2 , 3 }, true , 3 ));
6881+ test_cases.emplace_back (new test_mul_mat (GGML_TYPE_F32, GGML_TYPE_F32, 16 , 32 , 32 , { 1 , 1 }, {1 , 1 }, {0 , 1 , 2 , 3 }, 64 , 3 ));
68906882 test_cases.emplace_back (new test_mul_mat (GGML_TYPE_F32, GGML_TYPE_F32, 64 , 77 , 77 , {12 ,1 }, {1 ,1 }));
68916883
68926884#if 0
@@ -6912,7 +6904,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
69126904 for (uint32_t k = 0 ; k < 2 ; ++k) {
69136905 for (ggml_type type: {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_F32}) {
69146906 test_cases.emplace_back (new test_mul_mat (type, GGML_TYPE_F32, 1056 + m, 1 , 128 + k, {bs, bs2}, {nr, 1 }, {0 , 2 , 1 , 3 }));
6915- test_cases.emplace_back (new test_mul_mat (type, GGML_TYPE_F32, 128 + m, 1 , 1056 + k, {bs, bs2}, {nr, 1 }, {0 , 1 , 2 , 3 }, true ));
6907+ test_cases.emplace_back (new test_mul_mat (type, GGML_TYPE_F32, 128 + m, 1 , 1056 + k, {bs, bs2}, {nr, 1 }, {0 , 1 , 2 , 3 }, 2 * 1056 + k ));
69166908 }
69176909 }
69186910 }
@@ -7405,7 +7397,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
74057397 test_cases.emplace_back (new test_pad_reflect_1d (GGML_TYPE_F32, {3000 , 384 , 4 , 1 }));
74067398
74077399 test_cases.emplace_back (new test_mul_mat (GGML_TYPE_F16, GGML_TYPE_F32, 16416 , 1 , 128 , {8 , 1 }, {4 , 1 }, {0 , 2 , 1 , 3 }));
7408- test_cases.emplace_back (new test_mul_mat (GGML_TYPE_F16, GGML_TYPE_F32, 128 , 1 , 16416 , {8 , 1 }, {4 , 1 }, {0 , 1 , 2 , 3 }, true ));
7400+ test_cases.emplace_back (new test_mul_mat (GGML_TYPE_F16, GGML_TYPE_F32, 128 , 1 , 16416 , {8 , 1 }, {4 , 1 }, {0 , 1 , 2 , 3 }, 2 * 16416 ));
74097401
74107402 for (int bs : {1 , 2 , 3 , 4 , 5 , 8 , 512 }) {
74117403 for (ggml_type type_a : all_types) {
0 commit comments