Commit d732874

tinyblas dynamic dispatching

1 parent 3f2bc65 commit d732874
3 files changed: +72 -80 lines changed

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 4 additions & 4 deletions
@@ -7419,14 +7419,14 @@ static void ggml_compute_forward_mul_mat(
         if (src1_cont) {
             for (int64_t i13 = 0; i13 < ne13; i13++)
                 for (int64_t i12 = 0; i12 < ne12; i12++)
-                    if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
+                    if (!llamafile_sgemm(params,
+                                         ne01, ne11, ne00/ggml_blck_size(src0->type),
                                          (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
                                          nb01/ggml_type_size(src0->type),
                                          (const char *)src1->data + i12*nb12 + i13*nb13,
                                          nb11/ggml_type_size(src1->type),
                                          (char *)dst->data + i12*nb2 + i13*nb3,
                                          nb1/ggml_type_size(dst->type),
-                                         ith, nth,
                                          src0->type,
                                          src1->type,
                                          dst->type))
@@ -7471,14 +7471,14 @@ UseGgmlGemm1:;

         for (int64_t i13 = 0; i13 < ne13; i13++)
             for (int64_t i12 = 0; i12 < ne12; i12++)
-                if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
+                if (!llamafile_sgemm(params,
+                                     ne01, ne11, ne00/ggml_blck_size(src0->type),
                                      (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
                                      nb01/ggml_type_size(src0->type),
                                      (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size,
                                      row_size/ggml_type_size(vec_dot_type),
                                      (char *)dst->data + i12*nb2 + i13*nb3,
                                      nb1/ggml_type_size(dst->type),
-                                     ith, nth,
                                      src0->type,
                                      vec_dot_type,
                                      dst->type))
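
Both call sites now hand llamafile_sgemm the whole ggml_compute_params instead of unpacking ith/nth at the boundary; the extra threadpool handle is what the new barriers below need. For orientation, a sketch of the fields the diff relies on. This is an assumption beyond what the diff shows: the real definition lives in ggml-cpu-impl.h and has more members.

// Sketch only -- not the authoritative definition. This commit uses
// three members of ggml_compute_params (ith, nth, threadpool); anything
// else in the real struct (see ggml-cpu-impl.h) is omitted here.
struct ggml_threadpool;  // opaque here; ggml_barrier() takes a pointer to it

struct ggml_compute_params {
    int ith;                               // index of the calling thread
    int nth;                               // total number of threads
    struct ggml_threadpool * threadpool;   // needed for ggml_barrier()
    // ... other members omitted ...
};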

ggml/src/ggml-cpu/llamafile/sgemm.cpp

Lines changed: 66 additions & 74 deletions
@@ -53,6 +53,8 @@
 #include "ggml-cpu-impl.h"
 #include "ggml-quants.h"

+#include <atomic>
+
 #ifdef _MSC_VER
 #define NOINLINE __declspec(noinline)
 #else
@@ -298,23 +300,18 @@ static int64_t BLOCK_SIZE(size_t m) {
 template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
 class tinyBLAS {
   public:
-    tinyBLAS(int64_t k,
+    tinyBLAS(const ggml_compute_params * params, int64_t k,
              const TA *A, int64_t lda,
              const TB *B, int64_t ldb,
-             TC *C, int64_t ldc,
-             int ith, int nth)
-        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
+             TC *C, int64_t ldc)
+        : params(params), A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc) {
     }

     bool matmul(int64_t m, int64_t n) {
         if (k % KN != 0)
             return false;
         // compute RN/RM for only tile with size RN&RN-1/RM&RM-1
 #if VECTOR_REGISTERS == 32
-        if (m % 8 == 0 && n < 4) {
-            mnpack<8, 3, 1>(m, n, n);
-            return true;
-        }
         if (m % 16 == 0) {
             const int64_t SIZE_N = BLOCK_SIZE<6>(n);
             mnpack<4, 6, 4>(m, n, SIZE_N);
@@ -331,10 +328,6 @@ class tinyBLAS {
             return true;
         }
 #else  // VECTOR_REGISTERS == 16
-        if (m % 8 == 0 && n == 1) {
-            gemm<8, 1, 1>(m, n);
-            return true;
-        }
         if (m % 8 == 0) {
             const int64_t SIZE_N = BLOCK_SIZE<3>(n);
             mnpack<4, 3, 2>(m, n, SIZE_N);
@@ -400,39 +393,47 @@
     template <int RM, int RN, int BM>
     NOINLINE void gemm(int64_t m, int64_t n) {
         GGML_ASSERT(m % (RM * BM) == 0);
-        const int64_t ytiles = m / (RM * BM);
+        // const int64_t ytiles = m / (RM * BM);
         const int64_t xtiles = (n + RN -1) / RN;
-        const int64_t jj_RN = (xtiles - (xtiles * RN - n));
-        GGML_ASSERT(jj_RN * RN + (xtiles - jj_RN) * (RN - 1) == n);
+        const int64_t jj_RN = (xtiles - (xtiles * RN - n)) * RN;

-        const int64_t tiles = xtiles * ytiles;
-        const int64_t duty = (tiles + nth - 1) / nth;
-        const int64_t start = duty * ith;
-        int64_t end = start + duty;
-        if (end > tiles)
-            end = tiles;
-        for (int64_t job = start; job < end; ++job) {
-            const int64_t ii = job / xtiles;
-            const int64_t jj = job % xtiles;
-            for (int64_t bi = 0; bi < BM; ++bi) {
-                if (jj < jj_RN) {
-                    gemm_bloc<RM, RN>((ii * BM + bi) * RM, jj * RN);
-                } else if constexpr (RN > 1) {
-                    gemm_bloc<RM, RN - 1>((ii * BM + bi) * RM, jj_RN * RN + (jj - jj_RN) * (RN - 1));
+        static std::atomic<int64_t> current_chunk;
+        if (params->ith == 0) {
+            GGML_ASSERT((xtiles * RN - n) >= 0);
+            GGML_ASSERT((xtiles * RN - n) < RN);
+
+            // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
+            std::atomic_store_explicit(&current_chunk, (int64_t)params->nth, std::memory_order_relaxed);
+        }
+        ggml_barrier(params->threadpool);
+        int64_t ii = params->ith * RM * BM;
+
+        while (ii < m) {
+            for (int64_t bi = 0; bi < BM * RM; bi+=RM) {
+                int64_t jj = 0;
+                for (; jj<jj_RN; jj+=RN) {
+                    gemm_bloc<RM, RN>(ii + bi, jj);
                 }
+                if constexpr (RN > 1) {
+                    for (; jj<n; jj+=RN-1) {
+                        gemm_bloc<RM, RN-1>(ii + bi, jj);
+                    }
+                }
+                GGML_ASSERT(jj == n);
             }
+            ii = std::atomic_fetch_add_explicit(&current_chunk, (int64_t)1, std::memory_order_relaxed) * RM * BM;
         }
+        ggml_barrier(params->threadpool);
     }

+    const ggml_compute_params * params;
     const TA *const A;
     const TB *const B;
     TC *const C;
     const int64_t k;
     const int64_t lda;
     const int64_t ldb;
     const int64_t ldc;
-    const int ith;
-    const int nth;
 };

 //////////////////////////////////////////////////////////////////////////////////////////
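
This hunk is the heart of the change. Before, each thread was dealt a fixed slice of tiles up front (duty/start/end), so a thread that finished its slice early idled until the slowest one caught up. Now thread ith seeds itself with row chunk ith and then claims further chunks from a shared atomic counter, so faster threads absorb the tail. A standalone sketch of the same pattern, using std::thread in place of ggml's threadpool and barrier; the names run_dynamic and process_chunk are illustrative, not from the commit:

#include <atomic>
#include <cstdint>
#include <thread>
#include <vector>

// Stand-in for processing one RM*BM block of output rows.
static void process_chunk(int64_t chunk) { (void)chunk; /* ... */ }

static void run_dynamic(int64_t nchunks, int nth) {
    // Every thread implicitly owns chunk ith, so the first unclaimed
    // chunk is nth -- the same trick the commit's comment describes.
    std::atomic<int64_t> current_chunk{(int64_t)nth};
    std::vector<std::thread> pool;
    for (int ith = 0; ith < nth; ++ith) {
        pool.emplace_back([&current_chunk, nchunks, ith] {
            int64_t c = ith;                 // seed: my own thread index
            while (c < nchunks) {
                process_chunk(c);
                // fetch_add returns a unique previous value to each
                // caller, so every chunk is processed exactly once.
                c = current_chunk.fetch_add(1, std::memory_order_relaxed);
            }
        });
    }
    for (auto & t : pool) t.join();
}

Relaxed ordering suffices because the counter only hands out indices; the surrounding synchronization (ggml_barrier in the real code, join here) is what makes results visible. In the commit itself the counter is a function-local static that thread 0 re-arms to nth before each matmul, with ggml_barrier keeping the other threads from racing past the reset.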
@@ -1636,18 +1637,20 @@ class tinyBLAS_PPC {
  * @param Ctype is GGML data type of `C`
  * @return true if this function was able to service the matmul request
  */
-bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
-                     int64_t ldc, int ith, int nth, int Atype, int Btype, int Ctype) {
+bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
+                     const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
+                     int64_t ldc, int Atype, int Btype, int Ctype) {

     assert(m >= 0);
     assert(n >= 0);
     assert(k >= 0);
     assert(lda >= k);
     assert(ldb >= k);
     assert(ldc >= m);
-    assert(nth > 0);
-    assert(ith < nth);
+    assert(params->nth > 0);
+    assert(params->ith < params->nth);

+    // OK with fewer threads? 4 max on zen3 / 16 cores?
     // only enable sgemm for prompt processing
     if (n < 2)
         return false;
@@ -1661,27 +1664,24 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
         if (Btype != GGML_TYPE_F32)
             return false;
 #if defined(__AVX512F__)
-        tinyBLAS<16, __m512, __m512, float, float, float> tb{
+        tinyBLAS<16, __m512, __m512, float, float, float> tb{ params,
             k, (const float *)A, lda,
             (const float *)B, ldb,
-            (float *)C, ldc,
-            ith, nth};
+            (float *)C, ldc};
         return tb.matmul(m, n);
 #elif defined(__AVX__) || defined(__AVX2__)
-        tinyBLAS<8, __m256, __m256, float, float, float> tb{
+        tinyBLAS<8, __m256, __m256, float, float, float> tb{ params,
             k, (const float *)A, lda,
             (const float *)B, ldb,
-            (float *)C, ldc,
-            ith, nth};
+            (float *)C, ldc};
         return tb.matmul(m, n);
 #elif defined(__ARM_NEON)
         if (n < 4)
             return false;
-        tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{
+        tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ params,
             k, (const float *)A, lda,
             (const float *)B, ldb,
-            (float *)C, ldc,
-            ith, nth};
+            (float *)C, ldc};
         return tb.matmul(m, n);
 #elif defined(__MMA__)
         if (k % 8)
@@ -1690,7 +1690,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             k, (const float *)A, lda,
             (const float *)B, ldb,
             (float *)C, ldc,
-            ith, nth};
+            params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #else
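
One asymmetry worth noting: only the tinyBLAS template above switches to dynamic dispatch. tinyBLAS_PPC here, and the quantized tinyBLAS_Q0_AVX / tinyBLAS_Q0_ARM paths further down, keep their old constructors and simply read the thread IDs out of params. Their scheduling is still the static split that gemm() used before this commit, reconstructed below from the removed lines and wrapped in a hypothetical helper so it stands alone:

#include <cstdint>

// Reconstructed from the lines removed in gemm() above: the static split
// that tinyBLAS_PPC and the quantized classes still use.
static void run_static(int64_t xtiles, int64_t ytiles, int ith, int nth) {
    const int64_t tiles = xtiles * ytiles;
    const int64_t duty  = (tiles + nth - 1) / nth;  // ceil(tiles / nth)
    const int64_t start = duty * ith;               // this thread's slice
    int64_t end = start + duty;
    if (end > tiles)
        end = tiles;
    for (int64_t job = start; job < end; ++job) {
        const int64_t ii = job / xtiles;  // row tile
        const int64_t jj = job % xtiles;  // column tile
        (void)ii; (void)jj;               // process tile (ii, jj)
    }
}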
@@ -1701,29 +1701,26 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
     case GGML_TYPE_BF16: {
 #if defined(__AVX512BF16__)
         if (Btype == GGML_TYPE_BF16) {
-            tinyBLAS<32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, float> tb{ k,
+            tinyBLAS<32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
                     (const ggml_bf16_t *)A, lda,
                     (const ggml_bf16_t *)B, ldb,
-                    (float *)C, ldc,
-                    ith, nth};
+                    (float *)C, ldc};
             return tb.matmul(m, n);
         }
 #elif defined(__AVX512F__)
         if (Btype == GGML_TYPE_BF16) {
-            tinyBLAS<16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, float> tb{ k,
+            tinyBLAS<16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
                     (const ggml_bf16_t *)A, lda,
                     (const ggml_bf16_t *)B, ldb,
-                    (float *)C, ldc,
-                    ith, nth};
+                    (float *)C, ldc};
             return tb.matmul(m, n);
         }
 #elif defined(__AVX2__)
         if (Btype == GGML_TYPE_BF16) {
-            tinyBLAS<8, __m256, __m256, ggml_bf16_t, ggml_bf16_t, float> tb{ k,
+            tinyBLAS<8, __m256, __m256, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
                     (const ggml_bf16_t *)A, lda,
                     (const ggml_bf16_t *)B, ldb,
-                    (float *)C, ldc,
-                    ith, nth};
+                    (float *)C, ldc};
             return tb.matmul(m, n);
         }
 #endif
@@ -1732,40 +1729,36 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
     case GGML_TYPE_F16: {
 #if defined(__AVX512F__)
         if (Btype == GGML_TYPE_F16) {
-            tinyBLAS<16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, float> tb{ k,
+            tinyBLAS<16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, float> tb{ params, k,
                     (const ggml_fp16_t *)A, lda,
                     (const ggml_fp16_t *)B, ldb,
-                    (float *)C, ldc,
-                    ith, nth};
+                    (float *)C, ldc};
             return tb.matmul(m, n);
         }
 #elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
         if (Btype == GGML_TYPE_F16) {
-            tinyBLAS<8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, float> tb{ k,
+            tinyBLAS<8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, float> tb{ params, k,
                     (const ggml_fp16_t *)A, lda,
                     (const ggml_fp16_t *)B, ldb,
-                    (float *)C, ldc,
-                    ith, nth};
+                    (float *)C, ldc};
             return tb.matmul(m, n);
         }
 #elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
         if (n < 8)
             return false;
         if (Btype == GGML_TYPE_F16) {
-            tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{
+            tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
                 k, (const ggml_fp16_t *)A, lda,
                 (const ggml_fp16_t *)B, ldb,
-                (float *)C, ldc,
-                ith, nth};
+                (float *)C, ldc};
             return tb.matmul(m, n);
         }
 #elif defined(__ARM_NEON) && !defined(_MSC_VER)
         if (Btype == GGML_TYPE_F32) {
-            tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{
+            tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{ params,
                 k, (const ggml_fp16_t *)A, lda,
                 (const float *)B, ldb,
-                (float *)C, ldc,
-                ith, nth};
+                (float *)C, ldc};
             return tb.matmul(m, n);
         }
 #endif
@@ -1780,15 +1773,15 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             k, (const block_q8_0 *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
-            ith, nth};
+            params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #elif defined(__ARM_FEATURE_DOTPROD)
         tinyBLAS_Q0_ARM<block_q8_0> tb{
             k, (const block_q8_0 *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
-            ith, nth};
+            params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #else
@@ -1804,15 +1797,15 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             k, (const block_q4_0 *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
-            ith, nth};
+            params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #elif defined(__ARM_FEATURE_DOTPROD)
         tinyBLAS_Q0_ARM<block_q4_0> tb{
             k, (const block_q4_0 *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
-            ith, nth};
+            params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #else
@@ -1828,7 +1821,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             k, (const block_q5_0 *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
-            ith, nth};
+            params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #else
@@ -1844,7 +1837,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             k, (const block_iq4_nl *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
-            ith, nth};
+            params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #else
@@ -1856,6 +1849,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
         return false;
     }

+    (void)params;
     (void)m;
     (void)n;
     (void)k;

@@ -1865,8 +1859,6 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
     (void)ldb;
     (void)C;
     (void)ldc;
-    (void)ith;
-    (void)nth;
     (void)Atype;
     (void)Btype;
     (void)Ctype;

ggml/src/ggml-cpu/llamafile/sgemm.h

Lines changed: 2 additions & 2 deletions
@@ -5,8 +5,8 @@
 extern "C" {
 #endif

-bool llamafile_sgemm(int64_t, int64_t, int64_t, const void *, int64_t,
-                     const void *, int64_t, void *, int64_t, int, int,
+bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t, int64_t, int64_t,
+                     const void *, int64_t, const void *, int64_t, void *, int64_t,
                      int, int, int);

 #ifdef __cplusplus
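
For reference, the same declaration with the parameter names used at the definition in sgemm.cpp (the header keeps them anonymous):

bool llamafile_sgemm(const struct ggml_compute_params * params,
                     int64_t m, int64_t n, int64_t k,
                     const void * A, int64_t lda,
                     const void * B, int64_t ldb,
                     void * C, int64_t ldc,
                     int Atype, int Btype, int Ctype);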
