Skip to content

Commit 993ca95

Browse files
ikawrakow and Iwan Kawrakow authored
iq4_ks: faster dot product on Metal (#90)
TG-128 (LLaMA-3.1-8B) goes to 52.5 t/s, up from 48.4 t/s.
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
1 parent ff23008 commit 993ca95

File tree

1 file changed

+14
-10
lines changed

1 file changed

+14
-10
lines changed

ggml/src/ggml-metal.metal

Lines changed: 14 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -6079,6 +6079,7 @@ void kernel_mul_mv_iq4_ks_f32_impl(
60796079

60806080
float4 yl[4];
60816081
float2 sumf = 0.f;
6082+
float d[2];
60826083

60836084
device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
60846085

@@ -6087,22 +6088,25 @@ void kernel_mul_mv_iq4_ks_f32_impl(
60876088

60886089
float4 qf1, qf2;
60896090

6091+
device const float * dptr = (device const float *)cx;
6092+
d[0] = *dptr;
6093+
device const block_iq4_ks * x = (device const block_iq4_ks *)(dptr + 1) + ix;
6094+
dptr += row_size/4;
6095+
d[1] = *dptr;
6096+
60906097
for (int ibl = ix; ibl < nb; ibl += 2) {
60916098

60926099
device const float4 * y4 = (device const float4 *)yb;
60936100
yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
60946101

6095-
device const float * dptr = (device const float *)cx;
6102+
device const uint8_t * scales = x->scales;
60966103

60976104
for (int row = 0; row < 2; ++row) {
60986105

6099-
//device const float * dptr = (device const float *)(cx + row*row_size);
6100-
const float d = *dptr;
6101-
device const block_iq4_ks * x = (device const block_iq4_ks *)(dptr + 1);
6102-
device const block_iq4_ks & xb = x[ibl];
6103-
device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
6106+
threadgroup const float * block_values = shared_values + ((scales[ib] & 1) << 4);
6107+
const float ls = ((scales[ib] & 254) - 127);
61046108

6105-
threadgroup const float * block_values = shared_values + ((xb.scales[ib] & 1) << 4);
6109+
device const uint32_t * q4 = (device const uint32_t *)scales + QK_K/128 + 4*ib + 2*il;
61066110

61076111
float4 acc1 = {0.f}, acc2 = {0.f};
61086112

@@ -6122,14 +6126,14 @@ void kernel_mul_mv_iq4_ks_f32_impl(
61226126

61236127
acc1 += acc2;
61246128

6125-
const int ls = (xb.scales[ib] & 254) - 127;
6126-
sumf[row] += d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
6129+
sumf[row] += d[row] * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
61276130

6128-
dptr += row_size/4;
6131+
scales += row_size;
61296132

61306133
}
61316134

61326135
yb += 2 * QK_K;
6136+
x += 2;
61336137
}
61346138

61356139
sumf = simd_sum(sumf);

0 commit comments

Comments
 (0)