1919
2020#include " gtest/gtest.h"
2121
22+ template <typename TA, typename TB, typename TC>
2223static void small_gemm_tranb_ref (
23- const float *A, const float *B, float *C, int M, int N, int K, int lda, int ldb, int ldc) {
24+ const TA *A, const TB *B, TC *C, int M, int N, int K, int lda, int ldb, int ldc) {
2425 // Loop over the rows of A
2526 for (int i = 0 ; i < M; i++) {
2627 // Loop over the columns of B
2728 for (int j = 0 ; j < N; j++) {
2829 // Compute the dot product of row i of A with column j of B
2930 float dot_product = 0 ;
3031 for (int k = 0 ; k < K; k++) {
31- dot_product += A[i * lda + k] * B[j * ldb + k];
32+ dot_product += ( float ) A[i * lda + k] * ( float ) B[j * ldb + k];
3233 }
3334 // Store the result in C[i][j]
3435 C[i * ldc + j] = dot_product;
@@ -54,13 +55,14 @@ static void small_gemm_tranb_ref(
5455}
5556
5657// Test function to compare reference and optimized implementations
58+ template <typename TA = float , typename TB = float , typename TC = float >
5759void test_small_gemm_tranb (int M, int N, int K) {
58- float *A_ref = new float [M * K];
59- float *B_ref = new float [K * N];
60- float *C_ref = new float [M * N];
61- float *A_opt = new float [M * K];
62- float *B_opt = new float [K * N];
63- float *C_opt = new float [M * N];
60+ TA *A_ref = new TA [M * K];
61+ TB *B_ref = new TB [K * N];
62+ TC *C_ref = new TC [M * N];
63+ TA *A_opt = new TA [M * K];
64+ TB *B_opt = new TB [K * N];
65+ TC *C_opt = new TC [M * N];
6466
6567 // Generate random matrices A and B
6668 std::random_device dev;
@@ -262,6 +264,12 @@ TEST(small_gemm_tranb, small_gemm_tranb_f32) {
262264 test_bigger_kernel ();
263265}
264266
267+ TEST (small_gemm_tranb, small_gemm_tranb_bf16fp16f32) {
268+ test_small_gemm_tranb<bfloat16_t , float16_t , float >(1 , 2 , 16 );
269+ test_small_gemm_tranb<bfloat16_t , float16_t , float >(1 , 4 , 128 );
270+ test_small_gemm_tranb<bfloat16_t , float16_t , float >(1 , 4 , 256 );
271+ }
272+
265273TEST (small_gemm_tranb, small_gemm_tranb_int8) {
266274 test_small_gemm_tranb_int8 (1 , 100 , 128 );
267275 test_small_gemm_tranb_int8 (2 , 101 , 256 );
0 commit comments