Skip to content

Commit

Permalink
ggml : make i-quants work with super-blocks of 64 (CPU,Metal) (ggerga…
Browse files Browse the repository at this point in the history
…nov#5760)

* WIP: make i-quants work for QK_K = 64

* iq2_xs: attempt to fix AVX dot product for QK_K = 64

Tests pass, but I get gibberish.

* QK_K = 64 tests pass on ARM_NEON and Metal

Sadly, that does not mean it actually works.

* Make CUDA compile with QK_K = 64

Tests don't pass, plus we get misaligned access

* Q2_K: fixed bug in imatrix quantization for QK_K = 64

* iq1_s: turn off SIMD implementation for QK_K = 64 (it does not work)

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
  • Loading branch information
ikawrakow and Kawrakow authored Feb 28, 2024
1 parent cb49e0f commit 7c4263d
Show file tree
Hide file tree
Showing 5 changed files with 194 additions and 59 deletions.
27 changes: 20 additions & 7 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -544,14 +544,19 @@ static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong

#define QR3_XS 8
#define QI3_XS (QK_K / (4*QR3_XS))
#if QK_K == 64
#define IQ3S_N_SCALE 2
#else
#define IQ3S_N_SCALE QK_K/64
#endif
typedef struct {
half d;
uint8_t qs[QK_K/4];
uint8_t qh[QK_K/32];
uint8_t signs[QK_K/8];
uint8_t scales[QK_K/64];
uint8_t scales[IQ3S_N_SCALE];
} block_iq3_s;
static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 27*(QK_K/64), "wrong iq3_s block size/padding");
static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 13*(QK_K/32) + IQ3S_N_SCALE, "wrong iq3_s block size/padding");

#define QR1_S 8
#define QI1_S (QK_K / (4*QR1_S))
Expand All @@ -571,6 +576,11 @@ typedef struct {
} block_iq4_nl;
static_assert(sizeof(block_iq4_nl) == sizeof(ggml_fp16_t) + QK4_NL/2, "wrong iq4_nl block size/padding");

#if QK_K == 64
#define block_iq4_xs block_iq4_nl
#define QR4_XS QR4_NL
#define QI4_XS QI4_NL
#else
// QR4_XS = 8 is very slightly faster than QR4_XS = 4
#define QR4_XS 8
#define QI4_XS (QK_K / (4*QR4_XS))
Expand All @@ -581,7 +591,7 @@ typedef struct {
uint8_t qs[QK_K/2];
} block_iq4_xs;
static_assert(sizeof(block_iq4_xs) == sizeof(ggml_fp16_t) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");

#endif

#define WARP_SIZE 32
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
Expand Down Expand Up @@ -2439,9 +2449,9 @@ static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst

}

#if QK_K != 64
template<typename dst_t>
static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {

const int i = blockIdx.x;
const block_iq4_xs * x = (const block_iq4_xs *)vx;

Expand All @@ -2455,8 +2465,8 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst
y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
}

}
#endif

static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {

Expand Down Expand Up @@ -5382,8 +5392,7 @@ static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
return 0.f;
#endif
#else
assert(false);
return 0.f;
return vec_dot_iq4_xs_q8_1(vbq, bq8_1, iqs);
#endif
}

Expand Down Expand Up @@ -7444,7 +7453,11 @@ static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k,
template<typename dst_t>
static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
const int nb = (k + QK_K - 1) / QK_K;
#if QK_K == 64
dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
#else
dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
#endif
}

template <typename src_t, typename dst_t>
Expand Down
58 changes: 30 additions & 28 deletions ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -2560,12 +2560,16 @@ typedef struct {
uint8_t qs[QK4_NL/2];
} block_iq4_nl;

#if QK_K == 64
#define block_iq4_xs block_iq4_nl
#else
typedef struct {
half d;
uint16_t scales_h;
uint8_t scales_l[QK_K/64];
uint8_t qs[QK_K/2];
} block_iq4_xs;
#endif

//====================================== dot products =========================

Expand Down Expand Up @@ -4346,7 +4350,6 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
threadgroup_barrier(mem_flags::mem_threadgroup);
}

#if QK_K == 256
const int ix = tiisg;

device const float * y4 = y + 32 * ix;
Expand Down Expand Up @@ -4387,12 +4390,6 @@ void kernel_mul_mv_iq2_xxs_f32_impl(

y4 += 32 * 32;
}
#else
(void) x;
(void) y;
(void) yl;
(void) nb32;
#endif

for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
Expand Down Expand Up @@ -4482,7 +4479,6 @@ void kernel_mul_mv_iq2_xs_f32_impl(
threadgroup_barrier(mem_flags::mem_threadgroup);
}

#if QK_K == 256
const int ix = tiisg;

device const float * y4 = y + 32 * ix;
Expand Down Expand Up @@ -4533,12 +4529,6 @@ void kernel_mul_mv_iq2_xs_f32_impl(

y4 += 32 * 32;
}
#else
(void) x;
(void) y;
(void) yl;
(void) nb32;
#endif

for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
Expand Down Expand Up @@ -4628,7 +4618,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
threadgroup_barrier(mem_flags::mem_threadgroup);
}

#if QK_K == 256
const int ix = tiisg;

device const float * y4 = y + 32 * ix;
Expand Down Expand Up @@ -4672,12 +4661,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl(

y4 += 32 * 32;
}
#else
(void) x;
(void) y;
(void) yl;
(void) nb32;
#endif

for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
Expand Down Expand Up @@ -5016,7 +4999,6 @@ void kernel_mul_mv_iq1_s_f32_impl(

const int nb32 = nb * (QK_K / 32);

#if QK_K == 256
const int ix = tiisg/2;
const int il = tiisg%2;

Expand Down Expand Up @@ -5055,12 +5037,6 @@ void kernel_mul_mv_iq1_s_f32_impl(

y4 += 16 * 32;
}
#else
(void) x;
(void) y;
(void) yl;
(void) nb32;
#endif

for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
Expand Down Expand Up @@ -5167,6 +5143,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
}
}

#if QK_K != 64
void kernel_mul_mv_iq4_xs_f32_impl(
device const void * src0,
device const float * src1,
Expand Down Expand Up @@ -5260,6 +5237,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
}
}
}
#endif

[[host_name("kernel_mul_mv_iq1_s_f32")]]
kernel void kernel_mul_mv_iq1_s_f32(
Expand Down Expand Up @@ -5344,7 +5322,11 @@ kernel void kernel_mul_mv_iq4_xs_f32(
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {

#if QK_K == 64
kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
#else
kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
#endif
}

//============================= templates and their specializations =============================
Expand Down Expand Up @@ -5770,6 +5752,9 @@ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4

template <typename type4x4>
void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
#if QK_K == 64
dequantize_iq4_nl(xb, il, reg);
#else
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const int ib32 = il/2;
il = il%2;
Expand All @@ -5786,6 +5771,7 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4
reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
}
#endif
}

template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
Expand Down Expand Up @@ -6334,7 +6320,11 @@ template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_r
template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows<block_iq2_s, QK_NL, dequantize_iq2_s>;
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
#if QK_K == 64
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, 2, dequantize_iq4_xs>;
#else
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
#endif

//
// matrix-matrix multiplication
Expand Down Expand Up @@ -6378,7 +6368,11 @@ template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_m
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_s, QK_NL, dequantize_iq2_s>;
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
#if QK_K == 64
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_xs>;
#else
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
#endif

//
// indirect matrix-matrix multiplication
Expand Down Expand Up @@ -6434,7 +6428,11 @@ template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel
template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>;
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
#if QK_K == 64
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, 2, dequantize_iq4_xs>;
#else
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
#endif

//
// matrix-vector multiplication
Expand Down Expand Up @@ -7707,7 +7705,11 @@ kernel void kernel_mul_mv_id_iq4_xs_f32(

const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];

#if QK_K == 64
kernel_mul_mv_iq4_nl_f32_impl(
#else
kernel_mul_mv_iq4_xs_f32_impl(
#endif
src0[id],
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
Expand Down
Loading

0 comments on commit 7c4263d

Please sign in to comment.