@@ -3707,6 +3707,7 @@ struct test_im2col : public test_case {
 struct test_conv_2d : public test_case {
     const std::array<int64_t, 4> ne_input;
     const std::array<int64_t, 4> ne_kernel;
+    const ggml_type type_kernel;
     const int stride0;
     const int stride1;
     const int padding0;
@@ -3724,7 +3725,7 @@ struct test_conv_2d : public test_case {
     // IM2COL -> MUL_MM graph will be built.

     std::string vars() override {
-        return VARS_TO_STR9(ne_input, ne_kernel, stride0, stride1, padding0, padding1, dilation0, dilation1, cwhn);
+        return VARS_TO_STR10(ne_input, ne_kernel, type_kernel, stride0, stride1, padding0, padding1, dilation0, dilation1, cwhn);
     }

     uint64_t op_flops(ggml_tensor * t) override {
@@ -3755,10 +3756,11 @@ struct test_conv_2d : public test_case {
     }

     test_conv_2d(std::array<int64_t, 4> ne_input = { 64, 64, 16, 1 },
-                 std::array<int64_t, 4> ne_kernel = { 3, 3, 1, 16 }, int stride0 = 1, int stride1 = 1, int padding0 = 0,
-                 int padding1 = 0, int dilation0 = 1, int dilation1 = 1, bool cwhn = false) :
+                 std::array<int64_t, 4> ne_kernel = { 3, 3, 1, 16 }, ggml_type type_kernel = GGML_TYPE_F32, int stride0 = 1,
+                 int stride1 = 1, int padding0 = 0, int padding1 = 0, int dilation0 = 1, int dilation1 = 1, bool cwhn = false) :
         ne_input(ne_input),
         ne_kernel(ne_kernel),
+        type_kernel(type_kernel),
         stride0(stride0),
         stride1(stride1),
         padding0(padding0),
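Usage sketch (illustrative, not part of the patch): the new type_kernel parameter sits between ne_kernel and stride0, so a test case with an F16 kernel is constructed as below. The shapes are hypothetical, chosen only to keep Cin consistent between input and kernel.

    // Hypothetical call against the updated constructor.
    test_cases.emplace_back(new test_conv_2d(
        { 64, 64, 16, 1 },   // ne_input:  W, H, Cin, N
        {  3,  3, 16, 32 },  // ne_kernel: KW, KH, Cin, Cout
        GGML_TYPE_F16,       // type_kernel: exercises the F16 kernel path
        1, 1,                // stride0, stride1
        0, 0,                // padding0, padding1
        1, 1,                // dilation0, dilation1
        false));             // cwhn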
@@ -3771,7 +3773,7 @@ struct test_conv_2d : public test_case {
         ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());
         ggml_set_name(input, "input");

-        ggml_tensor * kernel = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_kernel.data());
+        ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data());
         ggml_set_name(kernel, "kernel");

         if (cwhn) {
@@ -5138,10 +5140,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         { 16, 3, 256, 128, 8 }
     };

-    for (auto act_case : cases) {
-        test_cases.emplace_back(new test_conv_2d(
-            { act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] },
-            { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] }, 1, 1, 0, 0, 1, 1, false));
+    for (auto kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+        for (auto act_case : cases) {
+            test_cases.emplace_back(new test_conv_2d(
+                { act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] },
+                { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] },
+                kernel_type, 1, 1, 0, 0, 1, 1, false));
+        }
     }
 #endif

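For context, a minimal standalone sketch of the graph these eval cases exercise, assuming the public ggml C API (ggml_new_tensor_4d, ggml_conv_2d) and hypothetical shapes; depending on backend support this runs either as a direct CONV_2D op or as the IM2COL -> MUL_MM fallback mentioned above.

    #include "ggml.h"

    // Build an F32 input convolved with an F16 kernel (the new code path under test).
    static ggml_tensor * build_f16_conv2d(ggml_context * ctx) {
        // input:  W=64, H=64, Cin=16, N=1 -- always F32, matching build_graph() above
        ggml_tensor * input  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 64, 64, 16, 1);
        // kernel: KW=3, KH=3, Cin=16, Cout=32 -- stored as F16
        ggml_tensor * kernel = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, 16, 32);
        // ggml_conv_2d(ctx, kernel, input, s0, s1, p0, p1, d0, d1)
        return ggml_conv_2d(ctx, kernel, input, /*s0*/ 1, /*s1*/ 1, /*p0*/ 0, /*p1*/ 0, /*d0*/ 1, /*d1*/ 1);
    }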
@@ -5167,8 +5172,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                 for (uint32_t W : { 1, 141 }) {
                     if (calc_conv_output_size(W, KW, s0, p0, d0) > 0 &&
                         calc_conv_output_size(H, KH, s1, p1, d1) > 0) {
-                        test_cases.emplace_back(new test_conv_2d(
-                            { W, H, Cin, 2 }, { KW, KH, Cin, Cout }, s0, s1, p0, p1, d0, d1, false));
+                        for (auto kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+                            test_cases.emplace_back(new test_conv_2d(
+                                { W, H, Cin, 2 }, { KW, KH, Cin, Cout }, kernel_type, s0, s1, p0, p1, d0, d1, false));
+                        }
                     }
                 }
             }
@@ -5813,11 +5820,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
         { 16, 3, 512, 128, 8 },
     };

-    for (auto act_case : cases) {
-        // Direct CONV_2D
-        test_cases.emplace_back(new test_conv_2d(
-            { act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] },
-            { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] }, 1, 1, 0, 0, 1, 1, false));
+    for (auto kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+        for (auto act_case : cases) {
+            // Direct CONV_2D
+            test_cases.emplace_back(new test_conv_2d(
+                { act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] },
+                { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] },
+                kernel_type, 1, 1, 0, 0, 1, 1, false));
+        }
     }

     test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1}));