Skip to content

Commit bcf5bda

Browse files
authored
Vulkan MMQ Integer Dot Refactor and K-Quant support (#16536)
* vulkan: add mmq q2_k integer dot support * Refactor mmq caching * Reduce mmq register use * Load 4 quant blocks into shared memory in one step * Pack q2_k blocks into caches of 32 * Use 32-bit accumulators for integer dot matmul * Add q4_k mmq * Add q3_k mmq * Add q5_k mmq * Add q6_k mmq * Add mxfp4 mmq, enable MMQ MUL_MAT_ID * Fix mmv dm loads
1 parent 3eb2be1 commit bcf5bda

18 files changed

+928
-405
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 140 additions & 25 deletions
Large diffs are not rendered by default.

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
437437
#if defined(DATA_A_MXFP4)
438438
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
439439
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
440-
return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]);
440+
return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]) * 0.5;
441441
}
442442
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
443443
vec2 v0 = dequantize(ib, iqs, a_offset);
@@ -488,9 +488,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
488488

489489
const uvec2 qs = uvec2(data_a[a_offset + ib].qs[qsi], data_a[a_offset + ib].qs[qsi + 1]);
490490
const uint scales = data_a[a_offset + ib].scales[scalesi];
491-
const vec2 d = vec2(data_a[a_offset + ib].d);
491+
const vec2 dm = vec2(data_a[a_offset + ib].dm);
492492

493-
return d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
493+
return dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4);
494494
}
495495
vec2 get_dm(uint ib, uint a_offset) {
496496
return vec2(1, 0);
@@ -529,7 +529,7 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
529529
const uint is = 2 * n + b; // 0..7
530530
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
531531

532-
const vec2 loadd = vec2(data_a[a_offset + ib].d);
532+
const vec2 loadd = vec2(data_a[a_offset + ib].dm);
533533

534534
const uint scidx0 = (is < 4) ? is : (is + 4);
535535
const uint scidx1 = (is < 4) ? is : (is - 4);
@@ -567,7 +567,7 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
567567

568568
const uint8_t hm = uint8_t(1 << (iqs / 16));
569569

570-
const vec2 loadd = vec2(data_a[a_offset + ib].d);
570+
const vec2 loadd = vec2(data_a[a_offset + ib].dm);
571571

572572
const uint scidx0 = (is < 4) ? is : (is + 4);
573573
const uint scidx1 = (is < 4) ? is : (is - 4);

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2
120120
float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
121121
{
122122
decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl);
123-
const f16vec2 d = bl.block.d;
123+
const f16vec2 dm = bl.block.dm;
124124
const uint idx = coordInBlock[1];
125125

126126
const uint scalesi = (idx & 0xF0) >> 4; // 0..15
@@ -131,7 +131,7 @@ float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2
131131
qs = unpack8(qs)[idx & 1];
132132

133133
const uint scales = bl.block.scales[scalesi];
134-
float16_t ret = d.x * float16_t(scales & 0xF) * float16_t(qs) - d.y * float16_t(scales >> 4);
134+
float16_t ret = dm.x * float16_t(scales & 0xF) * float16_t(qs) - dm.y * float16_t(scales >> 4);
135135
return ret;
136136
}
137137

@@ -680,7 +680,7 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
680680
uint32_t qs = bl.block.qs[iqs];
681681
qs >>= shift;
682682
qs &= 0xF;
683-
float16_t ret = float16_t(kvalues_mxfp4[qs] * d);
683+
float16_t ret = float16_t(kvalues_mxfp4[qs] * d * 0.5);
684684
return ret;
685685
}
686686
#endif

ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ void main() {
2626
const float d = e8m0_to_fp32(data_a[ib].e);
2727

2828
[[unroll]] for (uint l = 0; l < 8; ++l) {
29-
data_b[b_idx + l + 0] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]);
30-
data_b[b_idx + l + 16] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]);
29+
data_b[b_idx + l + 0] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]));
30+
data_b[b_idx + l + 16] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]));
3131
}
3232
}

ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ void main() {
2424
const uint ql_idx = 32 * ip + il;
2525
const uint8_t qs = data_a[i].qs[32 * ip + il];
2626

27-
FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
28-
FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
27+
FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].dm.x);
28+
FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].dm.y);
2929
data_b[y_idx + 0] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+0] & 0xF) * ((qs >> 0) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+0] >> 4));
3030
data_b[y_idx + 32] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+2] & 0xF) * ((qs >> 2) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+2] >> 4));
3131
data_b[y_idx + 64] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+4] & 0xF) * ((qs >> 4) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+4] >> 4));

ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ void main() {
2020
const uint is = 2 * il;
2121
const uint n = 4;
2222

23-
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x);
24-
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y);
23+
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].dm.x);
24+
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].dm.y);
2525

2626
const uint y_idx = ib * QUANT_K + 64 * il + n * ir;
2727
const uint qs_idx = 32*il + n * ir;

ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ void main() {
1919
const uint ir = tid % 16;
2020
const uint is = 2 * il;
2121

22-
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x);
23-
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y);
22+
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].dm.x);
23+
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].dm.y);
2424

2525
const uint y_idx = ib * QUANT_K + 64 * il + 2 * ir;
2626
const uint qs_idx = 32*il + 2 * ir;

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
4141
const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
4242
const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
4343

44-
vec2 d = vec2(data_a[ib0 + i].d);
45-
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
46-
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
44+
const FLOAT_TYPE_VEC2 dm = vec2(data_a[ib0 + i].dm);
4745

4846
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
4947
vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
@@ -75,7 +73,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
7573
fma(FLOAT_TYPE(b96[l]), sccache2[csel][ix][6 + 8*v_im],
7674
fma(FLOAT_TYPE(b112[l]), sccache2[csel][ix][7 + 8*v_im], sum2))))))));
7775
}
78-
temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
76+
temp[j][n] = fma(dm.x, sum1, fma(-dm.y, sum2, temp[j][n]));
7977
}
8078
}
8179
}

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
1414

1515
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
1616
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
17-
vec2 d = vec2(data_a[ib0 + i].d);
18-
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
19-
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
17+
const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);
2018

2119
const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
2220
const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
@@ -81,7 +79,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
8179
fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7,
8280
fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7,
8381
fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7)))))))))))))));
84-
temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
82+
temp[j][n] = fma(dm.x, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dm.y, smin, temp[j][n]));
8583
}
8684
}
8785
}

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
1414

1515
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
1616
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
17-
vec2 d = vec2(data_a[ib0 + i].d);
18-
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
19-
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
17+
const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);
2018

2119
const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
2220
const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
@@ -113,7 +111,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
113111
fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
114112
fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6,
115113
(FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7)));
116-
temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
114+
temp[j][n] = fma(dm.x, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dm.y, smin, temp[j][n]));
117115
}
118116
}
119117
}

0 commit comments

Comments
 (0)