Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 36 additions & 18 deletions src/kernels/gemm_kernel_ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,9 +296,7 @@ void small_gemm_transb_1xn_fixk(const TA *A, const TB *B, float *C, int N, int l
const TB *pB = B + i * ldb;

__m512 vc[BC];
compile_time_for<BC>::op([&](auto idx) {
vc[idx] = _mm512_set1_ps(0);
});
compile_time_for<BC>::op([&](auto idx) { vc[idx] = _mm512_set1_ps(0); });

for (int j = 0; j < vecs; ++j) {
__mmask16 m = (j == vecs - 1 ? mask : 0xffff);
Expand All @@ -309,16 +307,14 @@ void small_gemm_transb_1xn_fixk(const TA *A, const TB *B, float *C, int N, int l
}

// Store to C
compile_time_for<BC>::op([&](auto idx) {
C[i + idx] = _mm512_reduce_add_ps(vc[idx]);
});
compile_time_for<BC>::op([&](auto idx) { C[i + idx] = _mm512_reduce_add_ps(vc[idx]); });
}

// Remaining elements
for (; i < N; ++i) {
const TB *pB = B + i * ldb;
__m512 vc = _mm512_set1_ps(0);

for (int j = 0; j < vecs; ++j) {
__mmask16 m = (j == vecs - 1 ? mask : 0xffff);
__m512 vb = xft::load_avx512(m, pB + j * 16);
Expand All @@ -343,9 +339,7 @@ void small_gemm_transb_1xn_dynk(const TA *A, const TB *B, float *C, int N, int K
const TB *pB = B + i * ldb;

__m512 vc[BC];
compile_time_for<BC>::op([&](auto idx) {
vc[idx] = _mm512_set1_ps(0);
});
compile_time_for<BC>::op([&](auto idx) { vc[idx] = _mm512_set1_ps(0); });

for (int j = 0; j < vecs; ++j) {
__mmask16 m = (j == vecs - 1 ? mask : 0xffff);
Expand All @@ -357,17 +351,15 @@ void small_gemm_transb_1xn_dynk(const TA *A, const TB *B, float *C, int N, int K
}

// Store to C
compile_time_for<BC>::op([&](auto idx) {
C[i + idx] = _mm512_reduce_add_ps(vc[idx]);
});
compile_time_for<BC>::op([&](auto idx) { C[i + idx] = _mm512_reduce_add_ps(vc[idx]); });
}

// Remaining elements
for (; i < N; ++i) {
const TA *pA = A + i * ldb;
const TB *pB = B + i * ldb;
__m512 vc = _mm512_set1_ps(0);

for (int j = 0; j < vecs; ++j) {
__mmask16 m = (j == vecs - 1 ? mask : 0xffff);
__m512 va = xft::load_avx512(m, pA + j * 16);
Expand Down Expand Up @@ -426,9 +418,7 @@ void small_gemm_transb(const TA *A, const TB *B, float *C, int M, int N, int K,
constexpr int MB = 6;

// Special case for M = 1
if (M == 1) {
return small_gemm_transb_1xn(A, B, C, N, K, lda, ldb, ldc);
}
if (M == 1) { return small_gemm_transb_1xn(A, B, C, N, K, lda, ldb, ldc); }

for (i = 0; i + MB - 1 < M; i += MB) {
const TA *pA = A + i * lda;
Expand Down Expand Up @@ -542,7 +532,8 @@ void small_gemm_transb(const float *A, const float16_t *B, float *C, int M, int
small_gemm_transb<float, float16_t>(A, B, C, M, N, K, lda, ldb, ldc);
}

void small_gemm_transb(const bfloat16_t *A, const float16_t *B, float *C, int M, int N, int K, int lda, int ldb, int ldc) {
void small_gemm_transb(
const bfloat16_t *A, const float16_t *B, float *C, int M, int N, int K, int lda, int ldb, int ldc) {
small_gemm_transb<bfloat16_t, float16_t>(A, B, C, M, N, K, lda, ldb, ldc);
}

Expand All @@ -569,4 +560,31 @@ void small_gemm_transb(const float *attnMask, const bfloat16_t *A, const bfloat1
// Attention-mask overload for bfloat16_t A and float16_t B.
// Forwards every argument, unchanged and in the same order, to the templated
// small_gemm_transb<bfloat16_t, float16_t> implementation; no work is done here.
void small_gemm_transb(const float *attnMask, const bfloat16_t *A, const float16_t *B, float *C, int M, int N, int K,
int lda, int ldb, int ldc) {
small_gemm_transb<bfloat16_t, float16_t>(attnMask, A, B, C, M, N, K, lda, ldb, ldc);
}

////////////////////////////////////////////////////////////////////////////////

// Scales C in place: element C[i * ldc + j] is multiplied by scale[j] for every
// row i in [0, M) and column j in [0, N).
//
// Each row is walked in groups of 16 floats. A full group uses the all-ones lane
// mask; the trailing group of a row (fewer than 16 columns left) uses a mask with
// only the low (N - j) bits set, and that same mask is handed to both the loads
// and the store.
// NOTE(review): presumably xft::load_avx512 / xft::store_avx512 leave masked-off
// lanes untouched in memory — confirm against their definitions.
static void apply_scale(float *C, const float *scale, int M, int N, int ldc) {
    for (int row = 0; row < M; ++row) {
        float *rowPtr = &C[row * ldc];

        for (int col = 0; col < N; col += 16) {
            // Number of columns still to process in this row, including this group.
            const int left = N - col;

            __mmask16 lanes = 0xffff;
            if (left < 16) { lanes = (1 << left) - 1; }

            const __m512 factors = xft::load_avx512(lanes, scale + col);
            __m512 values = xft::load_avx512(lanes, rowPtr + col);
            values = values * factors;
            xft::store_avx512(rowPtr + col, lanes, values);
        }
    }
}

// float A / int8_t B overload with an optional per-column scale.
// Step 1: run the templated small_gemm_transb<float, int8_t> kernel, writing into C.
// Step 2: if bScale is non-null, multiply column j of the M x N result by bScale[j]
//         (via apply_scale); if it is null, C is left exactly as the kernel wrote it.
// NOTE(review): bScale presumably carries the dequantization scales of B — confirm with callers.
void small_gemm_transb(const float *A, const int8_t *B, const float *bScale, float *C, int M, int N, int K, int lda,
        int ldb, int ldc) {
    small_gemm_transb<float, int8_t>(A, B, C, M, N, K, lda, ldb, ldc);

    // No scale vector supplied: nothing more to do.
    if (bScale == nullptr) { return; }

    apply_scale(C, bScale, M, N, ldc);
}

// bfloat16_t A / int8_t B overload with an optional per-column scale.
// Step 1: run the templated small_gemm_transb<bfloat16_t, int8_t> kernel, writing into C.
// Step 2: if bScale is non-null, multiply column j of the M x N result by bScale[j]
//         (via apply_scale); if it is null, C is left exactly as the kernel wrote it.
// NOTE(review): bScale presumably carries the dequantization scales of B — confirm with callers.
void small_gemm_transb(const bfloat16_t *A, const int8_t *B, const float *bScale, float *C, int M, int N, int K,
        int lda, int ldb, int ldc) {
    small_gemm_transb<bfloat16_t, int8_t>(A, B, C, M, N, K, lda, ldb, ldc);

    // No scale vector supplied: nothing more to do.
    if (bScale == nullptr) { return; }

    apply_scale(C, bScale, M, N, ldc);
}
18 changes: 15 additions & 3 deletions src/kernels/gemm_kernel_ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,25 @@
// limitations under the License.
// ============================================================================
#pragma once
#include <cstdint>
#include <cstdio>
#include <typeinfo>

#include "bfloat16.h"
#include "float16.h"
#include "sgemm.h"
#include "sgemm_f32f16f32.h"
#include "sgemm_f32f16bf16.h"
#include "sgemm_f32f16f32.h"

// Single thread small gemm
void small_gemm_transb(const float *A, const float *B, float *C, int M, int N, int K, int lda, int ldb, int ldc);
void small_gemm_transb(const float *A, const float16_t *B, float *C, int M, int N, int K, int lda, int ldb, int ldc);
void small_gemm_transb(const bfloat16_t *A, const float16_t *B, float *C, int M, int N, int K, int lda, int ldb, int ldc);
void small_gemm_transb(
const bfloat16_t *A, const float16_t *B, float *C, int M, int N, int K, int lda, int ldb, int ldc);
void small_gemm_transb(const float *A, const int8_t *B, const float *bScale, float *C, int M, int N, int K,
int lda, int ldb, int ldc);
void small_gemm_transb(const bfloat16_t *A, const int8_t *B, const float *bScale, float *C, int M, int N, int K,
int lda, int ldb, int ldc);

// Single thread small gemm with attention mask (skip skippable computation according to attnMask)
void small_gemm_transb(const float *attnMask, const float *A, const float *B, float *C, int M, int N, int K, int lda,
Expand Down Expand Up @@ -61,7 +67,13 @@ inline void small_gemm(const float *A, const float16_t *B, float *C, int M, int
}

template <>
inline void small_gemm(const float *A, const float16_t *B, bfloat16_t *C, int M, int N, int K, int lda, int ldb, int ldc) {
inline void small_gemm(
const float *A, const float16_t *B, bfloat16_t *C, int M, int N, int K, int lda, int ldb, int ldc) {
small_sgemm_f32f16bf16(false, M, N, K, A, lda, (const XDNN_FP16 *)B, ldb, (XDNN_BF16 *)C, ldc);
}

void small_gemm(const float *A, const int8_t *B, const float *bScale, float *C, int M, int N, int K, int lda,
int ldb, int ldc);
void small_gemm(const float *A, const int8_t *B, const float *bScale, bfloat16_t *C, int M, int N, int K, int lda,
int ldb, int ldc);
} // namespace xft
Loading