Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

k-quants #1684

Merged
merged 32 commits into from
Jun 5, 2023
Merged
Changes from 1 commit
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
8673a41
Starting to add k-quantization to ggml
Kawrakow May 27, 2023
b4f7134
Adding Q3_K and Q8_K (de)-quantization
Kawrakow May 27, 2023
c93cce3
Q3_K now working on CUDA and AVX2/scalar
Kawrakow May 28, 2023
a3c0673
Some improvement for Q3_K on CUDA
Kawrakow May 28, 2023
3d8b1de
Some more CUDA optimizations for Q3_K
Kawrakow May 29, 2023
a0b8e9f
Adding Q4_K - scalar, AVX2, CUDA
Kawrakow May 29, 2023
cf221af
Adding Q6_K - scalar, AVX2, CUDA
Kawrakow May 29, 2023
b835d0f
Adding Q5_K - scalar, AVX2, CUDA
Kawrakow May 29, 2023
5c5191a
Per convention, all QX_K quantizations use Q5_K for output.weight
Kawrakow May 29, 2023
d537b97
Adding quantization mixes
Kawrakow May 29, 2023
54f808d
Quantization mixes: didn't quite get what I wanted in the last commit
Kawrakow May 29, 2023
a2533a7
Q4_K dot product for ARM_NEON
Kawrakow May 30, 2023
5ca15ce
Q6_K dot product for ARM_NEON
Kawrakow May 30, 2023
a197eb5
Q5_K dot product for ARM_NEON
Kawrakow May 30, 2023
13264fa
Adding Q3_K dot for ARM_NEON
Kawrakow May 30, 2023
4faa040
A very slightly faster ARM_NEON Q3_K dot
Kawrakow May 31, 2023
b439efb
Adding Q2_K - just CUDA for now
Kawrakow May 31, 2023
8516fdf
Adding scalar and AVX2 Q2_K dot
Kawrakow May 31, 2023
6ec7057
Adding ARM_NEON Q2_K dot
Kawrakow May 31, 2023
7bcc376
A slightly faster ARM_NEON Q2_K dot
Kawrakow Jun 1, 2023
e51ce72
Fixed bug in Q2_K CUDA dot product kernel
Kawrakow Jun 1, 2023
c5959d5
Don't print zeros/NaNs when no count histogram has been collected
Kawrakow Jun 1, 2023
9a9c5a0
A 10% faster CUDA vector dot kernel for Q3_K
Kawrakow Jun 1, 2023
894210a
A slightly faster Q4_K AVX2 dot product
Kawrakow Jun 2, 2023
abd99a8
A slightly faster ARM_NEON Q4_K dot product
Kawrakow Jun 3, 2023
8f5d42d
Minor
Kawrakow Jun 3, 2023
6ef1382
Fix quantization error test
Kawrakow Jun 3, 2023
0a71a4e
Fix docker build
Kawrakow Jun 3, 2023
431693c
Added forgotten ggml.o dependence on k_quants.h to the Makefile
Kawrakow Jun 4, 2023
32a5f3a
Had unintentionally committed the Makefile with -Ofast enabled
Kawrakow Jun 4, 2023
12d4344
ggml : rename k_quants -> ggml-quants-k, use lowercase in code
ggerganov Jun 5, 2023
af275fa
Merge branch 'master' into ik/k_quants
ggerganov Jun 5, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
A very slightly faster ARM_NEON Q3_K dot
  • Loading branch information
Kawrakow committed Jun 3, 2023
commit 4faa040c20e2f92d2c7e44cf24146400200b89fa
38 changes: 18 additions & 20 deletions k_quants.c
Original file line number Diff line number Diff line change
Expand Up @@ -904,6 +904,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
const uint8x16_t m1 = vshlq_n_u8(m0, 1);
const uint8x16_t m2 = vshlq_n_u8(m0, 2);
const uint8x16_t m3 = vshlq_n_u8(m0, 3);
const int8_t m32 = 32;

int8x16x4_t q3bytes;

Expand All @@ -930,7 +931,8 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);

const int8_t * scale = (const int8_t *)utmp;
int8_t * scale = (int8_t *)utmp;
for (int j = 0; j < 16; ++j) scale[j] -= m32;

for (int j = 0; j < QK_K/128; ++j) {

Expand All @@ -949,10 +951,10 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), q3h.val[3]);

#if defined(__ARM_FEATURE_DOTPROD)
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * (scale[0] - 32);
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * (scale[1] - 32);
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * (scale[2] - 32);
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * (scale[3] - 32);
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];
#else
int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_1.val[0])),
vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_1.val[0])));
Expand All @@ -962,10 +964,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_1.val[2])));
int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_1.val[3])),
vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_1.val[3])));
isum += vaddvq_s16(p0) * (scale[0] - 32) +
vaddvq_s16(p1) * (scale[1] - 32) +
vaddvq_s16(p2) * (scale[2] - 32) +
vaddvq_s16(p3) * (scale[3] - 32);
isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
#endif
scale += 4;

Expand All @@ -974,19 +973,16 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1);
q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1);

qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4);
qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4);

q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), q3h.val[0]);
q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), q3h.val[1]);
q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), q3h.val[2]);
q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), q3h.val[3]);

#if defined(__ARM_FEATURE_DOTPROD)
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * (scale[0] - 32);
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * (scale[1] - 32);
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * (scale[2] - 32);
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * (scale[3] - 32);
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];
#else
p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_2.val[0])),
vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_2.val[0])));
Expand All @@ -996,13 +992,15 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_2.val[2])));
p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_2.val[3])),
vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_2.val[3])));
isum += vaddvq_s16(p0) * (scale[0] - 32) +
vaddvq_s16(p1) * (scale[1] - 32) +
vaddvq_s16(p2) * (scale[2] - 32) +
vaddvq_s16(p3) * (scale[3] - 32);
isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
#endif
scale += 4;

if (j == 0) {
qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4);
qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4);
}

}
sum += d * isum;

Expand Down