@@ -258,6 +258,48 @@ TEST_F(VulkanComputeAPITest, calculate_tensor_strides_test) {
258258  }
259259}
260260
261+ TEST_F (VulkanComputeAPITest, virtual_transpose_test) { // Checks the metadata a vTensor reports after virtual_transpose(dim0, dim1), once for buffer storage and once for 3D texture storage.
262+   std::vector<int64_t > sizes = {7 , 9 , 11 , 13 }; // Starting sizes used for every case; all four values differ, so any swap shows up in sizes().
263+   //  Each case is five lists, in this order: (dim0, dim1), new_sizes, new_dim_order, new_axis_map, new_packed_dim_idx (single element)
264+   std::vector<std::vector<std::vector<int64_t >>> test_cases = {
265+       {{2 , 3 }, {7 , 9 , 13 , 11 }, {0 , 1 , 3 , 2 }, {1 , 0 , 2 , 2 }, {1 }},
266+       {{2 , 1 }, {7 , 11 , 9 , 13 }, {0 , 2 , 1 , 3 }, {0 , 2 , 1 , 2 }, {0 }},
267+       {{1 , 3 }, {7 , 13 , 11 , 9 }, {0 , 3 , 2 , 1 }, {2 , 1 , 0 , 2 }, {2 }},
268+   };
269+ 
270+   for  (const  auto & test_case : test_cases) {
271+     const  int  dim0 = test_case.at (0 ).at (0 ); // The table stores int64_t; the value is converted to int here.
272+     const  int  dim1 = test_case.at (0 ).at (1 );
273+ 
274+     const  auto & expected_sizes = test_case.at (1 );
275+     const  auto & expected_dim_order = test_case.at (2 );
276+     const  auto & expected_axis_map = test_case.at (3 );
277+     const  int  expected_packed_dim = test_case.at (4 ).at (0 ); // Stored as a one-element list so it fits the same table shape as the other fields.
278+ 
279+     { // Buffer storage: sizes() and dim_order() are compared against the table.
280+       vTensor a_buffer = vTensor (
281+           context (), sizes, vkapi::kFloat , utils::kBuffer , utils::kWidthPacked );
282+ 
283+       a_buffer.virtual_transpose (dim0, dim1); // Applied to a tensor constructed fresh in this iteration, so cases do not build on each other.
284+       EXPECT_TRUE (a_buffer.sizes () == expected_sizes);
285+       EXPECT_TRUE (a_buffer.dim_order () == expected_dim_order);
286+     }
287+ 
288+     { // Texture3D storage: sizes(), axis_map() and packed_dim_whcn_idx() are compared against the table.
289+       vTensor a_texture = vTensor (
290+           context (),
291+           sizes,
292+           vkapi::kFloat ,
293+           utils::kTexture3D ,
294+           utils::kWidthPacked );
295+       a_texture.virtual_transpose (dim0, dim1);
296+       EXPECT_TRUE (a_texture.sizes () == expected_sizes);
297+       EXPECT_TRUE (a_texture.axis_map () == expected_axis_map);
298+       EXPECT_TRUE (a_texture.packed_dim_whcn_idx () == expected_packed_dim);
299+     }
300+   }
301+ }
302+ 
261303TEST_F (VulkanComputeAPITest, vec_test) {
262304  utils::vec3 v3 ({1 , 2 , 3 });
263305  ASSERT_TRUE (v3[0 ] == 1 );
@@ -637,46 +679,58 @@ TEST_F(VulkanComputeAPITest, tensor_no_copy_transpose_test) {
637679  constexpr  int  N = 17 ;
638680  std::vector<int64_t > mat1_sizes = {M, K};
639681  std::vector<int64_t > mat2_sizes = {N, K};
640-   std::vector<int64_t > mat2_t_sizes = {K, N};
641682  std::vector<int64_t > out_sizes = {M, N};
642683
643-   std::vector<int64_t > transposed_dim_order = {1 , 0 };
644- 
645-   vTensor mat1 = CREATE_FLOAT_BUFFER (mat1_sizes, /* allocate_memory=*/ true );
646-   vTensor mat2 = CREATE_FLOAT_BUFFER (mat2_sizes, /* allocate_memory=*/ true );
647-   vTensor out = CREATE_FLOAT_BUFFER (out_sizes, /* allocate_memory=*/ true );
648- 
649-   //  Generate data
650-   std::vector<float > mat1_data =
651-       create_random_float_buffer (mat1.staging_buffer_numel ());
652-   std::vector<float > mat2_data =
653-       create_random_float_buffer (mat2.staging_buffer_numel ());
654- 
655-   //  Create direct view and modify sizes and strides later
656-   vTensor mat2_t  = vTensor (mat2);
657- 
658-   std::vector<float > mat2_t_data = transpose_matrix (mat2_data, N, K);
659-   std::vector<float > ref_out =
660-       compute_reference_matmul (mat1_data, mat2_t_data, M, K, N);
661- 
662-   //  Fill original tensor with some data
663-   fill_vtensor (mat1, mat1_data);
664-   fill_vtensor (mat2, mat2_data);
665- 
666-   record_reference_matmul (api::context (), out, mat1, mat2_t );
684+   for  (const  auto  storage_type : {utils::kTexture3D , utils::kBuffer }) {
685+     vTensor mat1 = vTensor (
686+         context (),
687+         mat1_sizes,
688+         vkapi::kFloat ,
689+         storage_type,
690+         utils::kWidthPacked );
691+     vTensor mat2 = vTensor (
692+         context (),
693+         mat2_sizes,
694+         vkapi::kFloat ,
695+         storage_type,
696+         utils::kWidthPacked );
697+     vTensor out = vTensor (
698+         context (), out_sizes, vkapi::kFloat , storage_type, utils::kWidthPacked );
699+ 
700+     //  Generate data
701+     std::vector<float > mat1_data =
702+         create_random_float_buffer (mat1.staging_buffer_numel ());
703+     std::vector<float > mat2_data =
704+         create_random_float_buffer (mat2.staging_buffer_numel ());
705+ 
706+     //  Create direct view and modify sizes and strides later
707+     vTensor mat2_t  = vTensor (mat2);
708+     //  Update sizes and strides of mat2_t to be that of a transposed tensor
709+     mat2_t .virtual_transpose (0 , 1 );
710+ 
711+     EXPECT_TRUE (mat2_t .gpu_memory_layout () == utils::kHeightPacked );
712+ 
713+     std::vector<float > mat2_t_data = transpose_matrix (mat2_data, N, K);
714+     std::vector<float > ref_out =
715+         compute_reference_matmul (mat1_data, mat2_t_data, M, K, N);
667716
668-   //  Update sizes and strides of mat2_t to be that of a transposed tensor 
669-   mat2_t . virtual_reconfigure (mat2_t_sizes, transposed_dim_order );
670-   EXPECT_TRUE ( mat2_t . gpu_memory_layout () == utils:: kHeightPacked );
717+      //  Fill original tensor with some data 
718+      fill_vtensor (mat1, mat1_data );
719+      fill_vtensor (mat2, mat2_data );
671720
672-   std::vector<float > data_out (out.staging_buffer_numel ());
673-   //  Extract the copy tensor; should contain the data of the original tensor
674-   extract_vtensor (out, data_out);
721+     if  (storage_type == utils::kTexture3D ) {
722+       record_matmul_texture3d (context (), out, mat1, mat2_t );
723+     } else  {
724+       record_reference_matmul (context (), out, mat1, mat2_t );
725+     }
675726
676-   EXPECT_TRUE (data_out.size () == ref_out.size ());
727+     std::vector<float > data_out (out.staging_buffer_numel ());
728+     //  Extract the copy tensor; should contain the data of the original tensor
729+     extract_vtensor (out, data_out);
677730
678-   for  (size_t  i = 0 ; i < data_out.size (); ++i) {
679-     EXPECT_TRUE (check_close (data_out[i], ref_out[i]));
731+     for  (size_t  i = 0 ; i < ref_out.size (); ++i) {
732+       EXPECT_TRUE (check_close (data_out[i], ref_out[i]));
733+     }
680734  }
681735}
682736
0 commit comments