Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

k-quants #1684

Merged
merged 32 commits into from
Jun 5, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
8673a41
Starting to add k-quantization to ggml
Kawrakow May 27, 2023
b4f7134
Adding Q3_K and Q8_K (de)-quantization
Kawrakow May 27, 2023
c93cce3
Q3_K now working on CUDA and AVX2/scalar
Kawrakow May 28, 2023
a3c0673
Some improvement for Q3_K on CUDA
Kawrakow May 28, 2023
3d8b1de
Some more CUDA optimizations for Q3_K
Kawrakow May 29, 2023
a0b8e9f
Adding Q4_K - scalar, AVX2, CUDA
Kawrakow May 29, 2023
cf221af
Adding Q6_K - scalar, AVX2, CUDA
Kawrakow May 29, 2023
b835d0f
Adding Q5_K - scalar, AVX2, CUDA
Kawrakow May 29, 2023
5c5191a
Per convention, all QX_K quantizations use Q5_K for output.weight
Kawrakow May 29, 2023
d537b97
Adding quantization mixes
Kawrakow May 29, 2023
54f808d
Quantization mixes: didn't quite get what I wanted in the last commit
Kawrakow May 29, 2023
a2533a7
Q4_K dot product for ARM_NEON
Kawrakow May 30, 2023
5ca15ce
Q6_K dot product for ARM_NEON
Kawrakow May 30, 2023
a197eb5
Q5_K dot product for ARM_NEON
Kawrakow May 30, 2023
13264fa
Adding Q3_K dot for ARM_NEON
Kawrakow May 30, 2023
4faa040
A very slightly faster ARM_NEON Q3_K dot
Kawrakow May 31, 2023
b439efb
Adding Q2_K - just CUDA for now
Kawrakow May 31, 2023
8516fdf
Adding scalar and AVX2 Q2_K dot
Kawrakow May 31, 2023
6ec7057
Adding ARM_NEON Q2_K dot
Kawrakow May 31, 2023
7bcc376
A slightly faster ARM_NEON Q2_K dot
Kawrakow Jun 1, 2023
e51ce72
Fixed bug in Q2_K CUDA dot product kernel
Kawrakow Jun 1, 2023
c5959d5
Don't print zeros/NaNs when no count histogram has been collected
Kawrakow Jun 1, 2023
9a9c5a0
A 10% faster CUDA vector dot kernel for Q3_K
Kawrakow Jun 1, 2023
894210a
A slightly daster Q4_K AVX2 dot product
Kawrakow Jun 2, 2023
abd99a8
A slightly faster ARM_NEON A4_K dot product
Kawrakow Jun 3, 2023
8f5d42d
Minor
Kawrakow Jun 3, 2023
6ef1382
Fix quantization error test
Kawrakow Jun 3, 2023
0a71a4e
Fix docker build
Kawrakow Jun 3, 2023
431693c
Added forgotten ggml.o dependence on k_quants.h to the Makefile
Kawrakow Jun 4, 2023
32a5f3a
Had unintentionally committed the Makefile with -Ofast enabled
Kawrakow Jun 4, 2023
12d4344
ggml : rename k_quants -> ggml-quants-k, use lowercase in code
ggerganov Jun 5, 2023
af275fa
Merge branch 'master' into ik/k_quants
ggerganov Jun 5, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Adding Q2_K - just CUDA for now
Token prediction is pretty good - about 15.5 ms on a RTX 4080.
Perplexity is about the same as Q4_K.
  • Loading branch information
Kawrakow committed Jun 3, 2023
commit b439efb7129c5f2eca243116c158d2a056322273
1 change: 1 addition & 0 deletions examples/quantize/quantize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ static const std::map<std::string, llama_ftype> LLAMA_FTYPE_MAP = {
{"q5_0", LLAMA_FTYPE_MOSTLY_Q5_0},
{"q5_1", LLAMA_FTYPE_MOSTLY_Q5_1},
{"q8_0", LLAMA_FTYPE_MOSTLY_Q8_0},
{"q2_K", LLAMA_FTYPE_MOSTLY_Q2_K},
{"q3_K", LLAMA_FTYPE_MOSTLY_Q3_K_M},
{"q3_K_S", LLAMA_FTYPE_MOSTLY_Q3_K_S},
{"q3_K_M", LLAMA_FTYPE_MOSTLY_Q3_K_M},
Expand Down
76 changes: 76 additions & 0 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,14 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 blo

#define QK_K 256

typedef struct {
uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
uint8_t qs[QK_K/4]; // quants
half d; // super-block scale for quantized scales
half dmin; // super-block scale for quantized mins
} block_q2_K;
static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");

typedef struct {
uint8_t hmask[QK_K/8];
uint8_t qs[QK_K/4]; // nibbles / quants
Expand Down Expand Up @@ -225,6 +233,59 @@ static __device__ void dequantize_q8_0(const void * vx, const int ib, const int

//================================== k-quants

static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {

const int i = blockIdx.x;
const int tid = threadIdx.x;
const int n = tid/32;
const int l = tid - 32*n;
const int is = 8*n + l/16;

const block_q2_K * x = (const block_q2_K *) vx;

const uint8_t q = x[i].qs[32*n + l];
float * y = yy + i*QK_K + 128*n;

float dall = x[i].d;
float dmin = x[i].dmin;
y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);

}

static __device__ void vec_dot_q2_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) {

const block_q2_K * x = (const block_q2_K *) vx;

// if n is 0, we want to do the lower 128, else the upper 128,
// covering y[l+0], y[l+32], y[l+64], y[l+96] and
// y[l+16], y[l+48], y[l+80], y[l+112]
int n = iqs/128; // 0 or 1
int r = iqs - 128*n; // 0...120 in steps of 8
int l = r/8; // 0...15 in steps of 1

const float * y = yy + 128*n + l;
const uint8_t * q = x[ib].qs + 32*n + l;
const uint8_t * s = x[ib].scales + 8*n;

const float dall = x[ib].d;
const float dmin = x[ib].dmin;

float sum = y[ 0] * (dall * ((s[0] & 0xF) * ((q[ 0] >> 0) & 3)) - dmin * (s[0] >> 4))
+ y[ 32] * (dall * ((s[2] & 0xF) * ((q[ 0] >> 2) & 3)) - dmin * (s[2] >> 4))
+ y[ 64] * (dall * ((s[4] & 0xF) * ((q[ 0] >> 4) & 3)) - dmin * (s[4] >> 4))
+ y[ 96] * (dall * ((s[6] & 0xF) * ((q[ 0] >> 6) & 3)) - dmin * (s[6] >> 4))
+ y[ 16] * (dall * ((s[1] & 0xF) * ((q[16] >> 0) & 3)) - dmin * (s[1] >> 4))
+ y[ 48] * (dall * ((s[3] & 0xF) * ((q[16] >> 2) & 3)) - dmin * (s[1] >> 4))
+ y[ 80] * (dall * ((s[5] & 0xF) * ((q[16] >> 4) & 3)) - dmin * (s[5] >> 4))
+ y[112] * (dall * ((s[7] & 0xF) * ((q[16] >> 6) & 3)) - dmin * (s[7] >> 4));

result = sum;

}

static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {

int r = threadIdx.x/4;
Expand Down Expand Up @@ -625,6 +686,11 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu
dequantize_block<QK8_0, QR8_0, dequantize_q8_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}

static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
}

static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
Expand Down Expand Up @@ -685,6 +751,12 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, f
<<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
}

static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0);
const dim3 block_dims(32, 2, 1);
dequantize_mul_mat_vec_k<32, vec_dot_q2_K><<<nrows/2, block_dims, 0, stream>>>(vx, y, dst, ncols);
}

static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0);
const dim3 block_dims(32, 2, 1);
Expand Down Expand Up @@ -734,6 +806,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_row_q5_1_cuda;
case GGML_TYPE_Q8_0:
return dequantize_row_q8_0_cuda;
case GGML_TYPE_Q2_K:
return dequantize_row_q2_K_cuda;
case GGML_TYPE_Q3_K:
return dequantize_row_q3_K_cuda;
case GGML_TYPE_Q4_K:
Expand Down Expand Up @@ -761,6 +835,8 @@ static dequantize_mul_mat_vec_cuda_t ggml_get_dequantize_mul_mat_vec_cuda(ggml_t
return dequantize_mul_mat_vec_q5_1_cuda;
case GGML_TYPE_Q8_0:
return dequantize_mul_mat_vec_q8_0_cuda;
case GGML_TYPE_Q2_K:
return dequantize_mul_mat_vec_q2_K_cuda;
case GGML_TYPE_Q3_K:
return dequantize_mul_mat_vec_q3_K_cuda;
case GGML_TYPE_Q4_K:
Expand Down
35 changes: 31 additions & 4 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -1566,6 +1566,14 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
.vec_dot_q = NULL, // TODO
.vec_dot_type = GGML_TYPE_Q8_1,
},
[GGML_TYPE_Q2_K] = {
.dequantize_row_q = (dequantize_row_q_t) dequantize_row_q2_K,
.quantize_row_q = quantize_row_q2_K,
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q2_K_reference,
.quantize_row_q_dot = quantize_row_q8_K,
.vec_dot_q = NULL, //ggml_vec_dot_q2_K_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K,
},
[GGML_TYPE_Q3_K] = {
.dequantize_row_q = (dequantize_row_q_t) dequantize_row_q3_K,
.quantize_row_q = quantize_row_q3_K,
Expand Down Expand Up @@ -3477,6 +3485,7 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
[GGML_TYPE_Q5_1] = QK5_1,
[GGML_TYPE_Q8_0] = QK8_0,
[GGML_TYPE_Q8_1] = QK8_1,
[GGML_TYPE_Q2_K] = QK_K,
[GGML_TYPE_Q3_K] = QK_K,
[GGML_TYPE_Q4_K] = QK_K,
[GGML_TYPE_Q5_K] = QK_K,
Expand All @@ -3486,7 +3495,7 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
[GGML_TYPE_I16] = 1,
[GGML_TYPE_I32] = 1,
};
static_assert(GGML_TYPE_COUNT == 18, "GGML_BLCK_SIZE is outdated");
static_assert(GGML_TYPE_COUNT == 19, "GGML_BLCK_SIZE is outdated");

static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
[GGML_TYPE_F32] = sizeof(float),
Expand All @@ -3497,6 +3506,7 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
[GGML_TYPE_Q5_1] = sizeof(block_q5_1),
[GGML_TYPE_Q8_0] = sizeof(block_q8_0),
[GGML_TYPE_Q8_1] = sizeof(block_q8_1),
[GGML_TYPE_Q2_K] = sizeof(block_q2_K),
[GGML_TYPE_Q3_K] = sizeof(block_q3_K),
[GGML_TYPE_Q4_K] = sizeof(block_q4_K),
[GGML_TYPE_Q5_K] = sizeof(block_q5_K),
Expand All @@ -3506,7 +3516,7 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
[GGML_TYPE_I16] = sizeof(int16_t),
[GGML_TYPE_I32] = sizeof(int32_t),
};
static_assert(GGML_TYPE_COUNT == 18, "GGML_TYPE_SIZE is outdated");
static_assert(GGML_TYPE_COUNT == 19, "GGML_TYPE_SIZE is outdated");


static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
Expand All @@ -3518,6 +3528,7 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
[GGML_TYPE_Q5_1] = "q5_1",
[GGML_TYPE_Q8_0] = "q8_0",
[GGML_TYPE_Q8_1] = "q8_1",
[GGML_TYPE_Q2_K] = "q2_K",
[GGML_TYPE_Q3_K] = "q3_K",
[GGML_TYPE_Q4_K] = "q4_K",
[GGML_TYPE_Q5_K] = "q5_K",
Expand All @@ -3527,7 +3538,7 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
[GGML_TYPE_I16] = "i16",
[GGML_TYPE_I32] = "i32",
};
static_assert(GGML_TYPE_COUNT == 18, "GGML_TYPE_NAME is outdated");
static_assert(GGML_TYPE_COUNT == 19, "GGML_TYPE_NAME is outdated");

static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
[GGML_TYPE_F32] = false,
Expand All @@ -3538,6 +3549,7 @@ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
[GGML_TYPE_Q5_1] = true,
[GGML_TYPE_Q8_0] = true,
[GGML_TYPE_Q8_1] = true,
[GGML_TYPE_Q2_K] = true,
[GGML_TYPE_Q3_K] = true,
[GGML_TYPE_Q4_K] = true,
[GGML_TYPE_Q5_K] = true,
Expand All @@ -3547,7 +3559,7 @@ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
[GGML_TYPE_I16] = false,
[GGML_TYPE_I32] = false,
};
static_assert(GGML_TYPE_COUNT == 18, "GGML_IS_QUANTIZED is outdated");
static_assert(GGML_TYPE_COUNT == 19, "GGML_IS_QUANTIZED is outdated");

static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"NONE",
Expand Down Expand Up @@ -3854,6 +3866,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break;
case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break;
case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break;
case GGML_FTYPE_MOSTLY_Q2_K: wtype = GGML_TYPE_Q2_K; break;
case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break;
case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break;
case GGML_FTYPE_MOSTLY_Q5_K: wtype = GGML_TYPE_Q5_K; break;
Expand Down Expand Up @@ -7641,6 +7654,7 @@ static void ggml_compute_forward_add(
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
Expand Down Expand Up @@ -7948,6 +7962,7 @@ static void ggml_compute_forward_add1(
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
Expand Down Expand Up @@ -8074,6 +8089,7 @@ static void ggml_compute_forward_acc(
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
Expand Down Expand Up @@ -10171,6 +10187,7 @@ static void ggml_compute_forward_mul_mat(
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
Expand Down Expand Up @@ -10358,6 +10375,7 @@ static void ggml_compute_forward_set(
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
Expand Down Expand Up @@ -10527,6 +10545,7 @@ static void ggml_compute_forward_get_rows(
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
Expand Down Expand Up @@ -11077,6 +11096,7 @@ static void ggml_compute_forward_alibi(
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
Expand Down Expand Up @@ -11153,6 +11173,7 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
Expand Down Expand Up @@ -16161,6 +16182,12 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
block_q8_0 * block = (block_q8_0*)dst + start / QK8_0;
result = ggml_quantize_q8_0(src + start, block, n, n, hist);
} break;
case GGML_TYPE_Q2_K:
{
GGML_ASSERT(start % QK_K == 0);
block_q2_K * block = (block_q2_K*)dst + start / QK_K;
result = ggml_quantize_q2_K(src + start, block, n, n, hist);
} break;
case GGML_TYPE_Q3_K:
{
GGML_ASSERT(start % QK_K == 0);
Expand Down
20 changes: 11 additions & 9 deletions ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -242,11 +242,12 @@ extern "C" {
GGML_TYPE_Q8_0 = 8,
GGML_TYPE_Q8_1 = 9,
// k-quantizations
GGML_TYPE_Q3_K = 10,
GGML_TYPE_Q4_K = 11,
GGML_TYPE_Q5_K = 12,
GGML_TYPE_Q6_K = 13,
GGML_TYPE_Q8_K = 14,
GGML_TYPE_Q2_K = 10,
GGML_TYPE_Q3_K = 11,
GGML_TYPE_Q4_K = 12,
GGML_TYPE_Q5_K = 13,
GGML_TYPE_Q6_K = 14,
GGML_TYPE_Q8_K = 15,
GGML_TYPE_I8,
GGML_TYPE_I16,
GGML_TYPE_I32,
Expand All @@ -270,10 +271,11 @@ extern "C" {
GGML_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
GGML_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
GGML_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
GGML_FTYPE_MOSTLY_Q3_K = 10, // except 1d tensors
GGML_FTYPE_MOSTLY_Q4_K = 11, // except 1d tensors
GGML_FTYPE_MOSTLY_Q5_K = 12, // except 1d tensors
GGML_FTYPE_MOSTLY_Q6_K = 13, // except 1d tensors
GGML_FTYPE_MOSTLY_Q2_K = 10, // except 1d tensors
GGML_FTYPE_MOSTLY_Q3_K = 11, // except 1d tensors
GGML_FTYPE_MOSTLY_Q4_K = 12, // except 1d tensors
GGML_FTYPE_MOSTLY_Q5_K = 13, // except 1d tensors
GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors
};

// available tensor operations:
Expand Down
Loading