5353#include "ggml-cpu-impl.h"
5454#include "ggml-quants.h"
5555
56+ #include <atomic>
57+
5658#ifdef _MSC_VER
5759#define NOINLINE __declspec (noinline)
5860#else
@@ -298,23 +300,18 @@ static int64_t BLOCK_SIZE(size_t m) {
298300template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
299301class tinyBLAS {
300302 public:
301- tinyBLAS (int64_t k,
303+ tinyBLAS (const ggml_compute_params * params, int64_t k,
302304 const TA *A, int64_t lda,
303305 const TB *B, int64_t ldb,
304- TC *C, int64_t ldc,
305- int ith, int nth)
306- : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
306+ TC *C, int64_t ldc)
307+ : params(params), A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc) {
307308 }
308309
309310 bool matmul (int64_t m, int64_t n) {
310311 if (k % KN != 0 )
311312 return false ;
312313 // compute RN/RM for only tile with size RN&RN-1/RM&RM-1
313314#if VECTOR_REGISTERS == 32
314- if (m % 8 == 0 && n < 4 ) {
315- mnpack<8 , 3 , 1 >(m, n, n);
316- return true ;
317- }
318315 if (m % 16 == 0 ) {
319316 const int64_t SIZE_N = BLOCK_SIZE<6 >(n);
320317 mnpack<4 , 6 , 4 >(m, n, SIZE_N);
@@ -331,10 +328,6 @@ class tinyBLAS {
331328 return true ;
332329 }
333330#else // VECTOR_REGISTERS == 16
334- if (m % 8 == 0 && n == 1 ) {
335- gemm<8 , 1 , 1 >(m, n);
336- return true ;
337- }
338331 if (m % 8 == 0 ) {
339332 const int64_t SIZE_N = BLOCK_SIZE<3 >(n);
340333 mnpack<4 , 3 , 2 >(m, n, SIZE_N);
@@ -400,39 +393,47 @@ class tinyBLAS {
400393 template <int RM, int RN, int BM>
401394 NOINLINE void gemm (int64_t m, int64_t n) {
402395 GGML_ASSERT (m % (RM * BM) == 0 );
403- const int64_t ytiles = m / (RM * BM);
396+ // const int64_t ytiles = m / (RM * BM);
404397 const int64_t xtiles = (n + RN -1 ) / RN;
405- const int64_t jj_RN = (xtiles - (xtiles * RN - n));
406- GGML_ASSERT (jj_RN * RN + (xtiles - jj_RN) * (RN - 1 ) == n);
398+ const int64_t jj_RN = (xtiles - (xtiles * RN - n)) * RN;
407399
408- const int64_t tiles = xtiles * ytiles;
409- const int64_t duty = (tiles + nth - 1 ) / nth;
410- const int64_t start = duty * ith;
411- int64_t end = start + duty;
412- if (end > tiles)
413- end = tiles;
414- for (int64_t job = start; job < end; ++job) {
415- const int64_t ii = job / xtiles;
416- const int64_t jj = job % xtiles;
417- for (int64_t bi = 0 ; bi < BM; ++bi) {
418- if (jj < jj_RN) {
419- gemm_bloc<RM, RN>((ii * BM + bi) * RM, jj * RN);
420- } else if constexpr (RN > 1 ) {
421- gemm_bloc<RM, RN - 1 >((ii * BM + bi) * RM, jj_RN * RN + (jj - jj_RN) * (RN - 1 ));
400+ static std::atomic<int64_t > current_chunk;
401+ if (params->ith == 0 ) {
402+ GGML_ASSERT ((xtiles * RN - n) >= 0 );
403+ GGML_ASSERT ((xtiles * RN - n) < RN);
404+
405+ // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
406+ std::atomic_store_explicit (¤t_chunk, (int64_t )params->nth , std::memory_order_relaxed);
407+ }
408+ ggml_barrier (params->threadpool );
409+ int64_t ii = params->ith * RM * BM;
410+
411+ while (ii < m) {
412+ for (int64_t bi = 0 ; bi < BM * RM; bi+=RM) {
413+ int64_t jj = 0 ;
414+ for (; jj<jj_RN; jj+=RN) {
415+ gemm_bloc<RM, RN>(ii + bi, jj);
422416 }
417+ if constexpr (RN > 1 ) {
418+ for (; jj<n; jj+=RN-1 ) {
419+ gemm_bloc<RM, RN-1 >(ii + bi, jj);
420+ }
421+ }
422+ GGML_ASSERT (jj == n);
423423 }
424+ ii = std::atomic_fetch_add_explicit (¤t_chunk, (int64_t )1 , std::memory_order_relaxed) * RM * BM;
424425 }
426+ ggml_barrier (params->threadpool );
425427 }
426428
429+ const ggml_compute_params * params;
427430 const TA *const A;
428431 const TB *const B;
429432 TC *const C;
430433 const int64_t k;
431434 const int64_t lda;
432435 const int64_t ldb;
433436 const int64_t ldc;
434- const int ith;
435- const int nth;
436437};
437438
438439// ////////////////////////////////////////////////////////////////////////////////////////
@@ -1636,18 +1637,20 @@ class tinyBLAS_PPC {
16361637 * @param Ctype is GGML data type of `C`
16371638 * @return true if this function was able to service the matmul request
16381639 */
1639- bool llamafile_sgemm (int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
1640- int64_t ldc, int ith, int nth, int Atype, int Btype, int Ctype) {
1640+ bool llamafile_sgemm (const struct ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
1641+ const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
1642+ int64_t ldc, int Atype, int Btype, int Ctype) {
16411643
16421644 assert (m >= 0 );
16431645 assert (n >= 0 );
16441646 assert (k >= 0 );
16451647 assert (lda >= k);
16461648 assert (ldb >= k);
16471649 assert (ldc >= m);
1648- assert (nth > 0 );
1649- assert (ith < nth);
1650+ assert (params-> nth > 0 );
1651+ assert (params-> ith < params-> nth );
16501652
1653+ // TODO(review): leftover dev note — "OK with fewer threads? max 4 on a 16-core Zen 3"; verify and remove before merge.
16511654 // only enable sgemm for prompt processing
16521655 if (n < 2 )
16531656 return false ;
@@ -1661,27 +1664,24 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
16611664 if (Btype != GGML_TYPE_F32)
16621665 return false ;
16631666#if defined(__AVX512F__)
1664- tinyBLAS<16 , __m512, __m512, float , float , float > tb{
1667+ tinyBLAS<16 , __m512, __m512, float , float , float > tb{ params,
16651668 k, (const float *)A, lda,
16661669 (const float *)B, ldb,
1667- (float *)C, ldc,
1668- ith, nth};
1670+ (float *)C, ldc};
16691671 return tb.matmul (m, n);
16701672#elif defined(__AVX__) || defined(__AVX2__)
1671- tinyBLAS<8 , __m256, __m256, float , float , float > tb{
1673+ tinyBLAS<8 , __m256, __m256, float , float , float > tb{ params,
16721674 k, (const float *)A, lda,
16731675 (const float *)B, ldb,
1674- (float *)C, ldc,
1675- ith, nth};
1676+ (float *)C, ldc};
16761677 return tb.matmul (m, n);
16771678#elif defined(__ARM_NEON)
16781679 if (n < 4 )
16791680 return false ;
1680- tinyBLAS<4 , float32x4_t , float32x4_t , float , float , float > tb{
1681+ tinyBLAS<4 , float32x4_t , float32x4_t , float , float , float > tb{ params,
16811682 k, (const float *)A, lda,
16821683 (const float *)B, ldb,
1683- (float *)C, ldc,
1684- ith, nth};
1684+ (float *)C, ldc};
16851685 return tb.matmul (m, n);
16861686#elif defined(__MMA__)
16871687 if (k % 8 )
@@ -1690,7 +1690,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
16901690 k, (const float *)A, lda,
16911691 (const float *)B, ldb,
16921692 (float *)C, ldc,
1693- ith, nth};
1693+ params-> ith , params-> nth };
16941694 tb.matmul (m, n);
16951695 return true ;
16961696#else
@@ -1701,29 +1701,26 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
17011701 case GGML_TYPE_BF16: {
17021702#if defined(__AVX512BF16__)
17031703 if (Btype == GGML_TYPE_BF16) {
1704- tinyBLAS<32 , __m512, __m512bh, ggml_bf16_t , ggml_bf16_t , float > tb{ k,
1704+ tinyBLAS<32 , __m512, __m512bh, ggml_bf16_t , ggml_bf16_t , float > tb{ params, k,
17051705 (const ggml_bf16_t *)A, lda,
17061706 (const ggml_bf16_t *)B, ldb,
1707- (float *)C, ldc,
1708- ith, nth};
1707+ (float *)C, ldc};
17091708 return tb.matmul (m, n);
17101709 }
17111710#elif defined(__AVX512F__)
17121711 if (Btype == GGML_TYPE_BF16) {
1713- tinyBLAS<16 , __m512, __m512, ggml_bf16_t , ggml_bf16_t , float > tb{ k,
1712+ tinyBLAS<16 , __m512, __m512, ggml_bf16_t , ggml_bf16_t , float > tb{ params, k,
17141713 (const ggml_bf16_t *)A, lda,
17151714 (const ggml_bf16_t *)B, ldb,
1716- (float *)C, ldc,
1717- ith, nth};
1715+ (float *)C, ldc};
17181716 return tb.matmul (m, n);
17191717 }
17201718#elif defined(__AVX2__)
17211719 if (Btype == GGML_TYPE_BF16) {
1722- tinyBLAS<8 , __m256, __m256, ggml_bf16_t , ggml_bf16_t , float > tb{ k,
1720+ tinyBLAS<8 , __m256, __m256, ggml_bf16_t , ggml_bf16_t , float > tb{ params, k,
17231721 (const ggml_bf16_t *)A, lda,
17241722 (const ggml_bf16_t *)B, ldb,
1725- (float *)C, ldc,
1726- ith, nth};
1723+ (float *)C, ldc};
17271724 return tb.matmul (m, n);
17281725 }
17291726#endif
@@ -1732,40 +1729,36 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
17321729 case GGML_TYPE_F16: {
17331730#if defined(__AVX512F__)
17341731 if (Btype == GGML_TYPE_F16) {
1735- tinyBLAS<16 , __m512, __m512, ggml_fp16_t , ggml_fp16_t , float > tb{ k,
1732+ tinyBLAS<16 , __m512, __m512, ggml_fp16_t , ggml_fp16_t , float > tb{ params, k,
17361733 (const ggml_fp16_t *)A, lda,
17371734 (const ggml_fp16_t *)B, ldb,
1738- (float *)C, ldc,
1739- ith, nth};
1735+ (float *)C, ldc};
17401736 return tb.matmul (m, n);
17411737 }
17421738#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
17431739 if (Btype == GGML_TYPE_F16) {
1744- tinyBLAS<8 , __m256, __m256, ggml_fp16_t , ggml_fp16_t , float > tb{ k,
1740+ tinyBLAS<8 , __m256, __m256, ggml_fp16_t , ggml_fp16_t , float > tb{ params, k,
17451741 (const ggml_fp16_t *)A, lda,
17461742 (const ggml_fp16_t *)B, ldb,
1747- (float *)C, ldc,
1748- ith, nth};
1743+ (float *)C, ldc};
17491744 return tb.matmul (m, n);
17501745 }
17511746#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
17521747 if (n < 8 )
17531748 return false ;
17541749 if (Btype == GGML_TYPE_F16) {
1755- tinyBLAS<8 , float16x8_t , float16x8_t , ggml_fp16_t , ggml_fp16_t , float > tb{
1750+ tinyBLAS<8 , float16x8_t , float16x8_t , ggml_fp16_t , ggml_fp16_t , float > tb{ params,
17561751 k, (const ggml_fp16_t *)A, lda,
17571752 (const ggml_fp16_t *)B, ldb,
1758- (float *)C, ldc,
1759- ith, nth};
1753+ (float *)C, ldc};
17601754 return tb.matmul (m, n);
17611755 }
17621756#elif defined(__ARM_NEON) && !defined(_MSC_VER)
17631757 if (Btype == GGML_TYPE_F32) {
1764- tinyBLAS<4 , float32x4_t , float32x4_t , ggml_fp16_t , float , float > tb{
1758+ tinyBLAS<4 , float32x4_t , float32x4_t , ggml_fp16_t , float , float > tb{ params,
17651759 k, (const ggml_fp16_t *)A, lda,
17661760 (const float *)B, ldb,
1767- (float *)C, ldc,
1768- ith, nth};
1761+ (float *)C, ldc};
17691762 return tb.matmul (m, n);
17701763 }
17711764#endif
@@ -1780,15 +1773,15 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
17801773 k, (const block_q8_0 *)A, lda,
17811774 (const block_q8_0 *)B, ldb,
17821775 (float *)C, ldc,
1783- ith, nth};
1776+ params-> ith , params-> nth };
17841777 tb.matmul (m, n);
17851778 return true ;
17861779#elif defined(__ARM_FEATURE_DOTPROD)
17871780 tinyBLAS_Q0_ARM<block_q8_0> tb{
17881781 k, (const block_q8_0 *)A, lda,
17891782 (const block_q8_0 *)B, ldb,
17901783 (float *)C, ldc,
1791- ith, nth};
1784+ params-> ith , params-> nth };
17921785 tb.matmul (m, n);
17931786 return true ;
17941787#else
@@ -1804,15 +1797,15 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
18041797 k, (const block_q4_0 *)A, lda,
18051798 (const block_q8_0 *)B, ldb,
18061799 (float *)C, ldc,
1807- ith, nth};
1800+ params-> ith , params-> nth };
18081801 tb.matmul (m, n);
18091802 return true ;
18101803#elif defined(__ARM_FEATURE_DOTPROD)
18111804 tinyBLAS_Q0_ARM<block_q4_0> tb{
18121805 k, (const block_q4_0 *)A, lda,
18131806 (const block_q8_0 *)B, ldb,
18141807 (float *)C, ldc,
1815- ith, nth};
1808+ params-> ith , params-> nth };
18161809 tb.matmul (m, n);
18171810 return true ;
18181811#else
@@ -1828,7 +1821,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
18281821 k, (const block_q5_0 *)A, lda,
18291822 (const block_q8_0 *)B, ldb,
18301823 (float *)C, ldc,
1831- ith, nth};
1824+ params-> ith , params-> nth };
18321825 tb.matmul (m, n);
18331826 return true ;
18341827#else
@@ -1844,7 +1837,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
18441837 k, (const block_iq4_nl *)A, lda,
18451838 (const block_q8_0 *)B, ldb,
18461839 (float *)C, ldc,
1847- ith, nth};
1840+ params-> ith , params-> nth };
18481841 tb.matmul (m, n);
18491842 return true ;
18501843#else
@@ -1856,6 +1849,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
18561849 return false ;
18571850 }
18581851
1852+ (void )params;
18591853 (void )m;
18601854 (void )n;
18611855 (void )k;
@@ -1865,8 +1859,6 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
18651859 (void )ldb;
18661860 (void )C;
18671861 (void )ldc;
1868- (void )ith;
1869- (void )nth;
18701862 (void )Atype;
18711863 (void )Btype;
18721864 (void )Ctype;
0 commit comments