@@ -1320,6 +1320,67 @@ kernel void kernel_mat_mv(device const void * src0,
 }
 }
 
+template <typename block_q_type, int nr, int nsg, int nl, int n_shift, template <typename, typename, typename> class quant_dri>
+kernel void kernel_mat_mv_no_tg_mem(device const void  * src0,
+                                    device const float * src1,
+                                    device       float * dst,
+                                    constant   int64_t & ne00,
+                                    constant   int64_t & ne01,
+                                    constant   int64_t & ne02,
+                                    constant   int64_t & ne10,
+                                    constant   int64_t & ne12,
+                                    constant   int64_t & ne0,
+                                    constant   int64_t & ne1,
+                                    constant      uint & gqa,
+                                    threadgroup   uint * shared_memory[[threadgroup(0)]],
+                                    uint3 tgpig[[threadgroup_position_in_grid]],
+                                    uint  tiisg[[thread_index_in_simdgroup]],
+                                    uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    const int nb = ne00/(nl * 16);
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+    const int ix = tiisg / nl;
+    const int il = tiisg % nl;
+    const int first_row = (r0 * nsg) * nr + sgitg;
+    const uint offset0 = first_row * nb + im/gqa*(nb*ne0) + ix;
+    const uint offset1 = r1*ne10 + im*ne00*ne1 + ix * (nl * 16) + (il/(n_shift/8))*16*(n_shift/8) + (il%(n_shift/8)) * 8;
+
+    device const block_q_type * x = (device const block_q_type *) src0 + offset0;
+    device const float        * y = (device const float        *) src1 + offset1;
+
+    float4x4 yl; // src1 vector cache
+    float sumf[nr] = {0.f};
+
+    quant_dri<device const uint16_t *, device const block_q_type *, half4x4> dequan_worker;
+    dequan_worker.init(il);
+
+    // each thread in a SIMD group deals with 16 dequantized weights.
+    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH / nl) {
+        yl[0] = *((device const float4 *)y);
+        yl[1] = *((device const float4 *)y + 1);
+        yl[2] = *((device const float4 *)y + n_shift/4);
+        yl[3] = *((device const float4 *)y + n_shift/4 + 1);
+
+        dequan_worker.inner_product_pre(il, yl);
+        #pragma unroll(nr)
+        for (int row = 0; row < nr; row++) {
+            float sum_temp = 0.f;
+            dequan_worker.inner_product(x + 2 * nb * row, il, yl, sum_temp);
+            sumf[row] += sum_temp;
+        }
+        x += N_SIMDWIDTH / nl;
+        y += N_SIMDWIDTH * 16;
+    }
+
+    for (int row = 0; row < nr; ++row) {
+        const float tot = simd_sum(sumf[row]);
+        if (tiisg == 0 && first_row + nsg * row < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + nsg * row] = tot;
+        }
+    }
+}
+
 #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
 #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix A
 #define BLOCK_SIZE_K 32
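To make the index arithmetic above concrete, take the q4_0 instantiation from the second hunk (nl = 2, n_shift = 16) as a worked example: ix = tiisg/2 selects the block and il = tiisg%2 selects half of its 32 weights, and the last two terms of offset1 reduce to il*8. The lane with il = 0 therefore loads src1 floats 0..7 and 16..23 of its block into yl, while the lane with il = 1 loads floats 8..15 and 24..31, which lines up with weights i and i+16 sharing the low and high nibble of byte i in block_q4_0. With a 32-lane simdgroup the loop then advances N_SIMDWIDTH/nl = 16 blocks per iteration.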
@@ -1487,10 +1548,10 @@ typedef void (mat_mv_t)(device const void *, device const float *, device float
 
 #define N_DST 4
 #define N_SIMDGROUP 2
-template [[host_name("kernel_mul_mv_f16_f32")]]  kernel mat_mv_t kernel_mat_mv<half4x4,    N_DST, N_SIMDGROUP, 1, 8,  f16_driver>;
-template [[host_name("kernel_mul_mv_q4_0_f32")]] kernel mat_mv_t kernel_mat_mv<block_q4_0, N_DST, N_SIMDGROUP, 2, 16, q4_0_driver>;
-template [[host_name("kernel_mul_mv_q4_1_f32")]] kernel mat_mv_t kernel_mat_mv<block_q4_1, N_DST, N_SIMDGROUP, 2, 16, q4_1_driver>;
-template [[host_name("kernel_mul_mv_q8_0_f32")]] kernel mat_mv_t kernel_mat_mv<block_q8_0, N_DST, N_SIMDGROUP, 2, 8,  q8_0_driver>;
+template [[host_name("kernel_mul_mv_f16_f32")]]  kernel mat_mv_t kernel_mat_mv_no_tg_mem<half4x4,    N_DST, N_SIMDGROUP, 1, 8,  f16_driver>;
+template [[host_name("kernel_mul_mv_q4_0_f32")]] kernel mat_mv_t kernel_mat_mv_no_tg_mem<block_q4_0, N_DST, N_SIMDGROUP, 2, 16, q4_0_driver>;
+template [[host_name("kernel_mul_mv_q4_1_f32")]] kernel mat_mv_t kernel_mat_mv_no_tg_mem<block_q4_1, N_DST, N_SIMDGROUP, 2, 16, q4_1_driver>;
+template [[host_name("kernel_mul_mv_q8_0_f32")]] kernel mat_mv_t kernel_mat_mv_no_tg_mem<block_q8_0, N_DST, N_SIMDGROUP, 2, 8,  q8_0_driver>;
 template [[host_name("kernel_mul_mv_q2_K_f32")]] kernel mat_mv_t kernel_mat_mv<block_q2_K, N_DST, N_SIMDGROUP, QK_NL, 8,  q2_K_driver>;
 template [[host_name("kernel_mul_mv_q3_K_f32")]] kernel mat_mv_t kernel_mat_mv<block_q3_K, N_DST, N_SIMDGROUP, QK_NL, 8,  q3_K_driver>;
 template [[host_name("kernel_mul_mv_q4_K_f32")]] kernel mat_mv_t kernel_mat_mv<block_q4_K, N_DST, N_SIMDGROUP, QK_NL, 32, q4_K_driver>;
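For reference, the quant_dri template parameter is a small per-format "driver" type; the actual f16_driver, q4_0_driver, q4_1_driver and q8_0_driver definitions live elsewhere in the file. The sketch below only mirrors what kernel_mat_mv_no_tg_mem relies on (the three template arguments and the init / inner_product_pre / inner_product calls); the name, signatures and comments are illustrative assumptions, not the real implementation.

// Hedged sketch of the driver shape assumed by kernel_mat_mv_no_tg_mem above.
// Only the call sites in the kernel are taken as given; everything else is illustrative.
template <typename addr_t, typename block_ptr_t, typename type4x4>
struct example_quant_driver {                 // hypothetical name
    // il = tiisg % nl, this lane's slot within a block; a driver can
    // precompute per-lane shifts/masks from it here.
    void init(int il) { }

    // Runs once per loop iteration before the per-row products, e.g. to
    // pre-scale or rearrange the cached src1 values in yl.
    void inner_product_pre(int il, thread float4x4 & yl) { }

    // Dequantize this lane's 16 weights from the block at xb and accumulate
    // their dot product with yl into sum.
    void inner_product(block_ptr_t xb, int il, thread const float4x4 & yl, thread float & sum) {
        // format-specific dequantization + multiply-accumulate goes here
    }
};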