Skip to content

Commit 15451f2

Browse files
authored
[Kernel] Bug fix for small_gemm_transb (#318)
1 parent 5349b3b commit 15451f2

File tree

2 files changed

+18
-10
lines changed

2 files changed

+18
-10
lines changed

src/kernels/gemm_kernel_ext.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -335,7 +335,7 @@ void small_gemm_transb_1xn_dynk(const TA *A, const TB *B, float *C, int N, int K
335335
// Each loop iteration computes 'BC' elements of C
336336
int i = 0;
337337
for (; i + BC - 1 < N; i += BC) {
338-
const TA *pA = A + i * ldb;
338+
const TA *pA = A;
339339
const TB *pB = B + i * ldb;
340340

341341
__m512 vc[BC];
@@ -356,7 +356,7 @@ void small_gemm_transb_1xn_dynk(const TA *A, const TB *B, float *C, int N, int K
356356

357357
// Remaining elements
358358
for (; i < N; ++i) {
359-
const TA *pA = A + i * ldb;
359+
const TA *pA = A;
360360
const TB *pB = B + i * ldb;
361361
__m512 vc = _mm512_set1_ps(0);
362362

tests/ut/gemm_kernel_ext_test.cpp

Lines changed: 16 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -19,16 +19,17 @@
1919

2020
#include "gtest/gtest.h"
2121

22+
template <typename TA, typename TB, typename TC>
2223
static void small_gemm_tranb_ref(
23-
const float *A, const float *B, float *C, int M, int N, int K, int lda, int ldb, int ldc) {
24+
const TA *A, const TB *B, TC *C, int M, int N, int K, int lda, int ldb, int ldc) {
2425
// Loop over the rows of A
2526
for (int i = 0; i < M; i++) {
2627
// Loop over the rows of B (i.e. the columns of B^T, since B is stored transposed)
2728
for (int j = 0; j < N; j++) {
2829
// Compute the dot product of row i of A with row j of B (column j of B^T)
2930
float dot_product = 0;
3031
for (int k = 0; k < K; k++) {
31-
dot_product += A[i * lda + k] * B[j * ldb + k];
32+
dot_product += (float)A[i * lda + k] * (float)B[j * ldb + k];
3233
}
3334
// Store the result in C[i][j]
3435
C[i * ldc + j] = dot_product;
@@ -54,13 +55,14 @@ static void small_gemm_tranb_ref(
5455
}
5556

5657
// Test function to compare reference and optimized implementations
58+
template <typename TA = float, typename TB = float, typename TC = float>
5759
void test_small_gemm_tranb(int M, int N, int K) {
58-
float *A_ref = new float[M * K];
59-
float *B_ref = new float[K * N];
60-
float *C_ref = new float[M * N];
61-
float *A_opt = new float[M * K];
62-
float *B_opt = new float[K * N];
63-
float *C_opt = new float[M * N];
60+
TA *A_ref = new TA[M * K];
61+
TB *B_ref = new TB[K * N];
62+
TC *C_ref = new TC[M * N];
63+
TA *A_opt = new TA[M * K];
64+
TB *B_opt = new TB[K * N];
65+
TC *C_opt = new TC[M * N];
6466

6567
// Generate random matrices A and B
6668
std::random_device dev;
@@ -262,6 +264,12 @@ TEST(small_gemm_tranb, small_gemm_tranb_f32) {
262264
test_bigger_kernel();
263265
}
264266

267+
TEST(small_gemm_tranb, small_gemm_tranb_bf16fp16f32) {
268+
test_small_gemm_tranb<bfloat16_t, float16_t, float>(1, 2, 16);
269+
test_small_gemm_tranb<bfloat16_t, float16_t, float>(1, 4, 128);
270+
test_small_gemm_tranb<bfloat16_t, float16_t, float>(1, 4, 256);
271+
}
272+
265273
TEST(small_gemm_tranb, small_gemm_tranb_int8) {
266274
test_small_gemm_tranb_int8(1, 100, 128);
267275
test_small_gemm_tranb_int8(2, 101, 256);

0 commit comments

Comments
 (0)