
Commit 927b13c

metal: yet another MUL mat-vec template
1 parent aa4b7d2 commit 927b13c

2 files changed (+69, -5)


ggml-metal.m

Lines changed: 4 additions & 1 deletion
@@ -864,7 +864,10 @@ void ggml_metal_graph_compute(
                 [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
                 [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
                 [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
-                [encoder setThreadgroupMemoryLength:8 * buffer_size_aligned atIndex:0];
+                // only for k-quants we use threadgroup memory
+                if (ggml_blck_size(src0t) >= 64){
+                    [encoder setThreadgroupMemoryLength:8 * buffer_size_aligned atIndex:0];
+                }
                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(64, 1, 1)];
             } else {
                 switch (src0->type) {
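Note on the host-side change: the branch keys off the quantization block size reported by ggml_blck_size. The formats routed to the new no-threadgroup-memory kernel (f16, Q4_0, Q4_1, Q8_0) have block sizes of 1 or 32, while the k-quants use 256-element blocks, so a >= 64 threshold separates the two families. A minimal C sketch of that condition, using the standard ggml block sizes (the helper name is illustrative, not part of the commit):

#include <stdbool.h>

// Illustrative helper mirroring the dispatch condition above: only formats
// with large (k-quant, 256-element) blocks keep the threadgroup buffer;
// f16 (block size 1) and Q4_0/Q4_1/Q8_0 (block size 32) skip it.
static bool uses_threadgroup_memory(int blck_size) {
    return blck_size >= 64;
}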

ggml-metal.metal

Lines changed: 65 additions & 4 deletions
@@ -1320,6 +1320,67 @@ kernel void kernel_mat_mv(device const void * src0,
         }
     }
 }
+
+template<typename block_q_type, int nr, int nsg, int nl, int n_shift, template<typename, typename, typename> class quant_dri>
+kernel void kernel_mat_mv_no_tg_mem(device const void * src0,
+                                    device const float * src1,
+                                    device       float * dst,
+                                    constant   int64_t & ne00,
+                                    constant   int64_t & ne01,
+                                    constant   int64_t & ne02,
+                                    constant   int64_t & ne10,
+                                    constant   int64_t & ne12,
+                                    constant   int64_t & ne0,
+                                    constant   int64_t & ne1,
+                                    constant      uint & gqa,
+                                    threadgroup   uint * shared_memory [[threadgroup(0)]],
+                                    uint3 tgpig [[threadgroup_position_in_grid]],
+                                    uint  tiisg [[thread_index_in_simdgroup]],
+                                    uint  sgitg [[simdgroup_index_in_threadgroup]]) {
+    const int nb = ne00/(nl * 16);
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+    const int ix = tiisg / nl;
+    const int il = tiisg % nl;
+    const int first_row = (r0 * nsg) * nr + sgitg;
+    const uint offset0 = first_row * nb + im/gqa*(nb*ne0) + ix;
+    const uint offset1 = r1*ne10 + im*ne00*ne1 + ix * (nl * 16) + (il/(n_shift/8))*16*(n_shift/8) + (il%(n_shift/8)) * 8;
+
+    device const block_q_type * x = (device const block_q_type *) src0 + offset0;
+    device const float        * y = (device const float        *) src1 + offset1;
+
+    float4x4 yl; // src1 vector cache
+    float sumf[nr] = {0.f};
+
+    quant_dri<device const uint16_t *, device const block_q_type *, half4x4> dequan_worker;
+    dequan_worker.init(il);
+
+    // each thread in a SIMD group deals with 16 dequantized weights.
+    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH / nl) {
+        yl[0] = *((device const float4 *)y);
+        yl[1] = *((device const float4 *)y + 1);
+        yl[2] = *((device const float4 *)y + n_shift/4);
+        yl[3] = *((device const float4 *)y + n_shift/4 + 1);
+
+        dequan_worker.inner_product_pre(il, yl);
+#pragma unroll(nr)
+        for (int row = 0; row < nr; row++) {
+            float sum_temp = 0.f;
+            dequan_worker.inner_product(x + 2 * nb * row, il, yl, sum_temp);
+            sumf[row] += sum_temp;
+        }
+        x += N_SIMDWIDTH / nl;
+        y += N_SIMDWIDTH * 16;
+    }
+
+    for (int row = 0; row < nr; ++row) {
+        const float tot = simd_sum(sumf[row]);
+        if (tiisg == 0 && first_row + nsg * row < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + nsg * row] = tot;
+        }
+    }
+}
+
 #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
 #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
 #define BLOCK_SIZE_K 32
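For reference, kernel_mat_mv_no_tg_mem computes the same matrix-vector products as the existing kernel_mat_mv, just without the threadgroup reduction buffer: each SIMD-group thread dequantizes and multiplies 16 weights per iteration, and the per-row partial sums are folded with simd_sum. A scalar C sketch of the result being computed (purely illustrative; it assumes src0 has already been dequantized to float and ignores the ne12/gqa batching):

// dst[r1*ne0 + row] = sum_k dequantize(src0[row][k]) * src1[r1][k]
static void mat_vec_ref(const float * src0_dequant, // ne01 x ne00 weights, dequantized
                        const float * src1,         // ne11 x ne10 activations (ne10 == ne00)
                        float       * dst,          // ne11 x ne0  outputs     (ne0  == ne01)
                        int ne00, int ne01, int r1) {
    for (int row = 0; row < ne01; ++row) {
        float sum = 0.0f;
        for (int k = 0; k < ne00; ++k) {
            sum += src0_dequant[row*ne00 + k] * src1[r1*ne00 + k];
        }
        dst[r1*ne01 + row] = sum;
    }
}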
@@ -1487,10 +1548,10 @@ typedef void (mat_mv_t)(device const void *, device const float *, device float

 #define N_DST 4
 #define N_SIMDGROUP 2
-template [[host_name("kernel_mul_mv_f16_f32" )]] kernel mat_mv_t kernel_mat_mv<half4x4,    N_DST, N_SIMDGROUP, 1,     8,  f16_driver>;
-template [[host_name("kernel_mul_mv_q4_0_f32")]] kernel mat_mv_t kernel_mat_mv<block_q4_0, N_DST, N_SIMDGROUP, 2,     16, q4_0_driver>;
-template [[host_name("kernel_mul_mv_q4_1_f32")]] kernel mat_mv_t kernel_mat_mv<block_q4_1, N_DST, N_SIMDGROUP, 2,     16, q4_1_driver>;
-template [[host_name("kernel_mul_mv_q8_0_f32")]] kernel mat_mv_t kernel_mat_mv<block_q8_0, N_DST, N_SIMDGROUP, 2,     8,  q8_0_driver>;
+template [[host_name("kernel_mul_mv_f16_f32" )]] kernel mat_mv_t kernel_mat_mv_no_tg_mem<half4x4,    N_DST, N_SIMDGROUP, 1, 8,  f16_driver>;
+template [[host_name("kernel_mul_mv_q4_0_f32")]] kernel mat_mv_t kernel_mat_mv_no_tg_mem<block_q4_0, N_DST, N_SIMDGROUP, 2, 16, q4_0_driver>;
+template [[host_name("kernel_mul_mv_q4_1_f32")]] kernel mat_mv_t kernel_mat_mv_no_tg_mem<block_q4_1, N_DST, N_SIMDGROUP, 2, 16, q4_1_driver>;
+template [[host_name("kernel_mul_mv_q8_0_f32")]] kernel mat_mv_t kernel_mat_mv_no_tg_mem<block_q8_0, N_DST, N_SIMDGROUP, 2, 8,  q8_0_driver>;
 template [[host_name("kernel_mul_mv_q2_K_f32")]] kernel mat_mv_t kernel_mat_mv<block_q2_K, N_DST, N_SIMDGROUP, QK_NL, 8,  q2_K_driver>;
 template [[host_name("kernel_mul_mv_q3_K_f32")]] kernel mat_mv_t kernel_mat_mv<block_q3_K, N_DST, N_SIMDGROUP, QK_NL, 8,  q3_K_driver>;
 template [[host_name("kernel_mul_mv_q4_K_f32")]] kernel mat_mv_t kernel_mat_mv<block_q4_K, N_DST, N_SIMDGROUP, QK_NL, 32, q4_K_driver>;
