@@ -291,10 +291,13 @@ static const __m128i iq4nlt = _mm_loadu_si128((const __m128i *) kvalues_iq4nl);
 // FLOATING POINT MATRIX MULTIPLICATION
 
 template <int M>
-static int64_t BLOCK_SIZE(size_t m) {
+static inline int64_t BLOCK_SIZE(size_t m) {
     const int64_t NB_BLOC_M = (m + M - 1) / M;
-    int64_t res = (m % NB_BLOC_M == 0) ? m / NB_BLOC_M : (m / NB_BLOC_M) + 1;
-    return res;
+    return (m % NB_BLOC_M == 0) ? m / NB_BLOC_M : (m / NB_BLOC_M) + 1;
+}
+
+static constexpr inline int64_t BLOC_POS(int64_t ib, int64_t ibN, int64_t bloc_size) {
+    return ib < ibN ? ib * bloc_size : ibN * bloc_size + (ib - ibN) * (bloc_size - 1);
 }
 
 template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
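
A quick note on what the two new helpers compute, since the `gemm` rewrite below leans on them. This is a standalone sketch, not part of the patch (the `main` driver and the n = 14 numbers are illustrative only): `BLOCK_SIZE<M>(m)` first picks the block count `ceil(m / M)`, then returns the near-equal block width `ceil(m / blocks)`; `BLOC_POS(ib, ibN, s)` gives the start offset of block `ib` when the first `ibN` blocks have width `s` and the rest have width `s - 1`.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>

// The two helpers, copied from the patch.
template <int M>
static inline int64_t BLOCK_SIZE(size_t m) {
    const int64_t NB_BLOC_M = (m + M - 1) / M;
    return (m % NB_BLOC_M == 0) ? m / NB_BLOC_M : (m / NB_BLOC_M) + 1;
}

static constexpr inline int64_t BLOC_POS(int64_t ib, int64_t ibN, int64_t bloc_size) {
    return ib < ibN ? ib * bloc_size : ibN * bloc_size + (ib - ibN) * (bloc_size - 1);
}

int main() {
    // n = 14 with M = 6: ceil(14/6) = 3 blocks of width ceil(14/3) = 5 at most,
    // i.e. a 5 + 5 + 4 split. The first nb_full = 2 blocks are full width.
    const int64_t n       = 14;
    const int64_t sz      = BLOCK_SIZE<6>(n);   // 5
    const int64_t nb      = (n + sz - 1) / sz;  // 3
    const int64_t nb_full = nb - (nb * sz - n); // 2
    assert(sz == 5 && nb == 3 && nb_full == 2);
    for (int64_t ib = 0; ib <= nb; ++ib) {
        // Block starts: 0, 5, 10, then 14 (the total) as the final sentinel.
        printf("block %lld starts at %lld\n",
               (long long)ib, (long long)BLOC_POS(ib, nb_full, sz));
    }
}
```

The 5 + 5 + 4 split is why the kernels below only ever need two tile widths, `RN` and `RN - 1`.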
@@ -310,32 +313,37 @@ class tinyBLAS {
     bool matmul(int64_t m, int64_t n) {
         if (k % KN != 0)
             return false;
-        // compute RN/RM for only tile with size RN&RN-1/RM&RM-1
+        // compute RM so we only need tiles of size RM and RM-1
 #if VECTOR_REGISTERS == 32
-        if (m % 16 == 0) {
+        if (m % 16 == 0 && (m/16 >= params->nth)) {
             const int64_t SIZE_N = BLOCK_SIZE<6>(n);
-            mnpack<4, 6, 4>(m, n, SIZE_N);
+            mnpack<4, 6, 4>(m, n, SIZE_N, 12);
             return true;
         }
-        if (m % 8 == 0) {
+        if (m % 8 == 0) {
             const int64_t SIZE_N = BLOCK_SIZE<6>(n);
-            mnpack<4, 6, 2>(m, n, SIZE_N);
+            mnpack<4, 6, 2>(m, n, SIZE_N, 12);
             return true;
         }
         if (m % 4 == 0) {
             const int64_t SIZE_N = BLOCK_SIZE<6>(n);
-            mnpack<4, 6, 1>(m, n, SIZE_N);
+            mnpack<4, 6, 1>(m, n, SIZE_N, 12);
             return true;
         }
 #else // VECTOR_REGISTERS == 16
-        if (m % 8 == 0) {
+        if (m % 16 == 0 && (m/16 >= params->nth)) {
+            const int64_t SIZE_N = BLOCK_SIZE<3>(n);
+            mnpack<4, 3, 4>(m, n, SIZE_N, 24);
+            return true;
+        }
+        if (m % 8 == 0) {
             const int64_t SIZE_N = BLOCK_SIZE<3>(n);
-            mnpack<4, 3, 2>(m, n, SIZE_N);
+            mnpack<4, 3, 2>(m, n, SIZE_N, 24);
             return true;
         }
         if (m % 4 == 0) {
             const int64_t SIZE_N = BLOCK_SIZE<3>(n);
-            mnpack<4, 3, 1>(m, n, SIZE_N);
+            mnpack<4, 3, 1>(m, n, SIZE_N, 24);
             return true;
         }
 #endif
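
One behavioral note on the dispatch above: the new `m/16 >= params->nth` guard takes the widest row blocking (`BM = 4`, 16 rows per block) only when there are at least as many 16-row blocks as threads, so no thread is left without work; otherwise it falls through to the 8- or 4-row variants. A minimal sketch of that cascade (the `rows_per_block` helper is hypothetical; `VECTOR_REGISTERS == 32` branch only):

```cpp
#include <cstdint>
#include <cstdio>

// Mirrors the m % 16 / m % 8 / m % 4 cascade in matmul above and returns
// the number of rows per block the dispatch would pick.
static int rows_per_block(int64_t m, int64_t nth) {
    if (m % 16 == 0 && m / 16 >= nth) return 16; // mnpack<4, 6, 4>
    if (m % 8 == 0)                   return 8;  // mnpack<4, 6, 2>
    if (m % 4 == 0)                   return 4;  // mnpack<4, 6, 1>
    return 0;                                    // matmul returns false: fall back
}

int main() {
    printf("%d\n", rows_per_block(512, 16)); // 16: 512/16 = 32 blocks >= 16 threads
    printf("%d\n", rows_per_block(512, 64)); // 8: only 32 blocks of 16 rows for 64 threads
    printf("%d\n", rows_per_block(500, 16)); // 4: 500 = 4 * 125, not divisible by 8
}
```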
@@ -344,12 +352,12 @@ class tinyBLAS {
 
   private:
     template <int RM, int RN, int BM>
-    inline void mnpack(int64_t m, int64_t n, int64_t SIZE_N) {
+    inline void mnpack(int64_t m, int64_t n, int64_t SIZE_N, int64_t BN) {
         if (SIZE_N == RN) {
-            return gemm<RM, RN, BM>(m, n);
+            return gemm<RM, RN, BM>(m, n, BN);
         }
         if constexpr (RN > 1) {
-            return mnpack<RM, RN-1, BM>(m, n, SIZE_N);
+            return mnpack<RM, RN-1, BM>(m, n, SIZE_N, BN);
         } else {
             GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N);
             GGML_ASSERT(false); // we have missed something.
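
`mnpack` itself only gains the pass-through `BN` parameter; the dispatch pattern is unchanged: a template recursion that walks `RN` down until it equals the runtime `SIZE_N`, so the compiler instantiates one `gemm` per candidate tile width and the runtime value selects among them. A stripped-down sketch of the pattern, with hypothetical `dispatch`/`kernel` names:

```cpp
#include <cstdio>

template <int N>
void kernel() { printf("kernel<%d>\n", N); }

// Turn the runtime value size_n into a compile-time constant by walking
// N down until it matches, exactly as mnpack does with SIZE_N and RN.
template <int N>
void dispatch(int size_n) {
    if (size_n == N)
        return kernel<N>();             // matching instantiation found
    if constexpr (N > 1)
        return dispatch<N - 1>(size_n); // recursion unrolled at compile time
    else
        printf("size %d not supported\n", size_n);
}

int main() {
    dispatch<6>(5); // prints "kernel<5>"
    dispatch<6>(9); // walks 6..1 without a match: "size 9 not supported"
}
```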
@@ -391,39 +399,58 @@ class tinyBLAS {
     }
 
     template <int RM, int RN, int BM>
-    NOINLINE void gemm(int64_t m, int64_t n) {
+    NOINLINE void gemm(int64_t m, int64_t n, int64_t BN) {
+        static std::atomic<int64_t> current_chunk;
+
         GGML_ASSERT(m % (RM * BM) == 0);
-        // const int64_t ytiles = m / (RM * BM);
+        const int64_t ytiles = m / (RM * BM);
         const int64_t xtiles = (n + RN - 1) / RN;
-        const int64_t jj_RN = (xtiles - (xtiles * RN - n)) * RN;
+        const int64_t jj_RN = (xtiles - (xtiles * RN - n));
 
-        static std::atomic<int64_t> current_chunk;
-        if (params->ith == 0) {
-            GGML_ASSERT((xtiles * RN - n) >= 0);
-            GGML_ASSERT((xtiles * RN - n) < RN);
+        // split xtiles into NB_BN super-blocks of roughly BN tiles each
+        const int64_t NB_BN   = xtiles < BN ? 1 : (xtiles + BN / 2) / BN;
+        const int64_t SIZE_BN = xtiles % NB_BN == 0 ? xtiles / NB_BN : xtiles / NB_BN + 1;
+        const int64_t jj_BN   = (NB_BN - (NB_BN * SIZE_BN - xtiles));
+        const int64_t nb_job  = ytiles * NB_BN;
 
+        if (params->ith == 0) {
+            GGML_ASSERT(jj_BN * SIZE_BN + (NB_BN - jj_BN) * (SIZE_BN - 1) == xtiles);
             // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
             std::atomic_store_explicit(&current_chunk, (int64_t)params->nth, std::memory_order_relaxed);
         }
+
         ggml_barrier(params->threadpool);
-        int64_t ii = params->ith * RM * BM;
 
-        while (ii < m) {
-            for (int64_t bi = 0; bi < BM * RM; bi += RM) {
-                int64_t jj = 0;
-                for (; jj < jj_RN; jj += RN) {
+        int64_t job = params->ith;
+        while (job < nb_job) {
+            const int64_t ii  = (job % ytiles) * RM * BM;
+            const int64_t jb  = job / ytiles;
+            const int64_t jr0 = BLOC_POS(jb,   jj_BN, SIZE_BN);
+            const int64_t jrN = BLOC_POS(jb+1, jj_BN, SIZE_BN);
+
+            const int64_t jj0 = BLOC_POS(jr0, jj_RN, RN);
+            const int64_t jj2 = BLOC_POS(jrN, jj_RN, RN);
+            const int64_t jj1 = jj2 < jj_RN * RN ? jj2 : jj_RN * RN;
+
+            for (int64_t bi = 0; bi < BM * RM; bi += RM) {
+                int64_t jj = jj0;
+                for (; jj < jj1; jj += RN) {
                     gemm_bloc<RM, RN>(ii + bi, jj);
                 }
                 if constexpr (RN > 1) {
-                    for (; jj < n; jj += RN - 1) {
+                    for (; jj < jj2; jj += RN - 1) {
                         gemm_bloc<RM, RN-1>(ii + bi, jj);
                     }
                 }
-                GGML_ASSERT(jj == n);
+                GGML_ASSERT(jj == jj2);
             }
-            ii = std::atomic_fetch_add_explicit(&current_chunk, (int64_t)1, std::memory_order_relaxed) * RM * BM;
+
+            // next job.
+            job = std::atomic_fetch_add_explicit(&current_chunk, (int64_t)1, std::memory_order_relaxed);
         }
+
         ggml_barrier(params->threadpool);
+        return;
     }
 
     const ggml_compute_params * params;
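
To make the new scheduling concrete, here is a standalone numeric walk-through using the same formulas as `gemm` above; the concrete sizes RM = 4, RN = 6, BM = 4, BN = 12, m = 128, n = 500 are my own example, not from the patch. Each of the `nb_job = ytiles * NB_BN` jobs covers one `RM * BM`-row block crossed with one super-block of roughly `BN` column tiles, and idle threads pull the next job index from the shared atomic counter:

```cpp
#include <cstdint>
#include <cstdio>

// Same block-start helper as in the patch.
static constexpr int64_t BLOC_POS(int64_t ib, int64_t ibN, int64_t bloc_size) {
    return ib < ibN ? ib * bloc_size : ibN * bloc_size + (ib - ibN) * (bloc_size - 1);
}

int main() {
    const int64_t RM = 4, RN = 6, BM = 4, BN = 12, m = 128, n = 500;

    const int64_t ytiles = m / (RM * BM);              // 8 row blocks of 16 rows
    const int64_t xtiles = (n + RN - 1) / RN;          // 84 column tiles
    const int64_t jj_RN  = xtiles - (xtiles * RN - n); // 80 tiles of width 6, then 4 of width 5

    // Split the 84 tiles into super-blocks of roughly BN = 12 tiles each.
    const int64_t NB_BN   = xtiles < BN ? 1 : (xtiles + BN / 2) / BN;                  // 7
    const int64_t SIZE_BN = xtiles % NB_BN == 0 ? xtiles / NB_BN : xtiles / NB_BN + 1; // 12
    const int64_t jj_BN   = NB_BN - (NB_BN * SIZE_BN - xtiles);                        // 7 full-size
    const int64_t nb_job  = ytiles * NB_BN;                                            // 56 jobs

    // Map one job index to its row block and column range, as gemm does.
    const int64_t job = 13;
    const int64_t ii  = (job % ytiles) * RM * BM;             // rows [80, 96)
    const int64_t jb  = job / ytiles;                         // super-block 1
    const int64_t jr0 = BLOC_POS(jb,     jj_BN, SIZE_BN);     // first tile: 12
    const int64_t jrN = BLOC_POS(jb + 1, jj_BN, SIZE_BN);     // one-past-last tile: 24
    const int64_t jj0 = BLOC_POS(jr0, jj_RN, RN);             // first column: 72
    const int64_t jj2 = BLOC_POS(jrN, jj_RN, RN);             // one-past-last column: 144
    const int64_t jj1 = jj2 < jj_RN * RN ? jj2 : jj_RN * RN;  // width-RN tiles end here

    printf("%lld jobs; job %lld -> rows [%lld, %lld), cols [%lld, %lld), full width up to %lld\n",
           (long long)nb_job, (long long)job,
           (long long)ii, (long long)(ii + RM * BM),
           (long long)jj0, (long long)jj2, (long long)jj1);
}
```

The payoff versus the old loop: a chunk used to be a whole row stripe spanning all of n, so with few row blocks some threads idled; now there are `ytiles * NB_BN` smaller jobs to balance across threads, which is also why `matmul` only takes the `BM = 4` path when `m/16 >= nth`.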
@@ -1650,7 +1677,6 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
     assert(params->nth > 0);
     assert(params->ith < params->nth);
 
-    // OK with fewer threads? 4 max on zen3 / 16 cores?
     // only enable sgemm for prompt processing
     if (n < 2)
         return false;