Skip to content

Commit

Permalink
Porting q2_k kernel to OpenCL
Browse files Browse the repository at this point in the history
  • Loading branch information
LostRuins authored and 0cc4m committed Jun 13, 2023
1 parent 74a6d92 commit 9b41865
Showing 1 changed file with 188 additions and 0 deletions.
188 changes: 188 additions & 0 deletions ggml-opencl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,29 @@ struct __attribute__ ((packed)) block_q8_0
int8_t qs[QK8_0];
};

struct __attribute__ ((packed)) block_q2_K
{
uchar scales[16];
uchar qs[64];
half d;
half dmin;
};

struct __attribute__ ((packed)) block_q4_K
{
uchar scales[12];
uchar qs[128];
half d;
half dmin;
};

struct __attribute__ ((packed)) block_q3_K
{
uchar hmask[32];
uchar qs[64];
uchar scales[12];
half d;
};

__kernel void convert_fp16_to_fp32(__global half* x, __global float* y) {
const uint i = get_global_id(0);
Expand Down Expand Up @@ -131,8 +154,120 @@ void convert_f16(__global half* x, const int ib, const int iqs, float* v0, float
*v0 = vload_half(0, &x[ib + 0]);
*v1 = vload_half(0, &x[ib + 1]);
}

static inline void get_scale_min_k4(int j, const __global uchar *q, uchar *d, uchar *m) {
if (j < 4) {
*d = q[j] & 63;
*m = q[j + 4] & 63;
} else {
*d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
*m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4);
}
}

__kernel void dequantize_block_q2_K(__global const struct block_q2_K* x, __global float *yy) {
const int i = get_group_id(0);
const int tid = get_local_id(0);
const int n = tid / 32;
const int l = tid - 32 * n;
const int is = 8 * n + l / 16;

const uchar q = x[i].qs[32 * n + l];
__global float *y = yy + i * 256 + 128 * n;

const float dall = vload_half(0, &x[i].d);
const float dmin = vload_half(0, &x[i].dmin);

y[l + 0] = dall * (x[i].scales[is + 0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is + 0] >> 4);
y[l + 32] = dall * (x[i].scales[is + 2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is + 2] >> 4);
y[l + 64] = dall * (x[i].scales[is + 4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is + 4] >> 4);
y[l + 96] = dall * (x[i].scales[is + 6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is + 6] >> 4);
}

__kernel void dequantize_block_q3_K(__global const struct block_q3_K* x, __global float *yy) {
int r = get_local_id(0) / 4;
int i = get_group_id(0);
int tid = r / 2;
int is0 = r % 2;
int l0 = 16 * is0 + 4 * (get_local_id(0) % 4);
int n = tid / 4;
int j = tid - 4 * n;

uchar m = 1 << (4 * n + j);
int is = 8 * n + 2 * j + is0;
int shift = 2 * j;

uchar us = is < 4 ? (x[i].scales[is - 0] & 0xF) | (((x[i].scales[is + 8] >> 0) & 3) << 4) :
is < 8 ? (x[i].scales[is - 0] & 0xF) | (((x[i].scales[is + 4] >> 2) & 3) << 4) :
is < 12 ? (x[i].scales[is - 8] >> 4) | (((x[i].scales[is + 0] >> 4) & 3) << 4) :
(x[i].scales[is - 8] >> 4) | (((x[i].scales[is - 4] >> 6) & 3) << 4);
float d_all = vload_half(0, &x[i].d);
float dl = d_all * (us - 32);

__global float *y = yy + i * 256 + 128 * n + 32 * j;
const __global uchar *q = x[i].qs + 32 * n;
const __global uchar *hm = x[i].hmask;

for (int l = l0; l < l0 + 4; ++l)
y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l / 8] & m) ? 0 : 4));
}

__kernel void dequantize_block_q4_K(__global const struct block_q4_K* x, __global float *yy) {

const int i = get_group_id(0);
const int tid = get_local_id(0);
const int il = tid / 8;
const int ir = tid % 8;
const int is = 2 * il;
const int n = 4;

__global float *y = yy + i * 256 + 64 * il + n * ir;

const float dall = vload_half(0, &x[i].d);
const float dmin = vload_half(0, &x[i].dmin);

__global const uchar *q = x[i].qs + 32 * il + n * ir;

uchar sc, m;
get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
float d1 = dall * sc;
float m1 = dmin * m;
get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
float d2 = dall * sc;
float m2 = dmin * m;
for (int l = 0; l < n; ++l) {
y[l + 0] = d1 * (q[l] & 0xF) - m1;
y[l + 32] = d2 * (q[l] >> 4) - m2;
}
}

);

// __kernel void vec_dot_q2_K(__global const struct block_q2_K* x, const int ib, const int iqs, const __global float *yy, __global float *result) {

// int n = iqs / 128; // 0 or 1
// int r = iqs - 128 * n; // 0...120 in steps of 8
// int l = r / 8; // 0...15 in steps of 1

// __global const float *y = yy + 128 * n + l;
// __global const uchar *q = x[ib].qs + 32 * n + l;
// __global const uchar *s = x[ib].scales + 8 * n;

// const float dall = vload_half(0, &x[ib].d);
// const float dmin = vload_half(0, &x[ib].dmin);

// float sum = y[0] * (dall * ((s[0] & 0xF) * ((q[0] >> 0) & 3)) - dmin * (s[0] >> 4))
// + y[32] * (dall * ((s[2] & 0xF) * ((q[0] >> 2) & 3)) - dmin * (s[2] >> 4))
// + y[64] * (dall * ((s[4] & 0xF) * ((q[0] >> 4) & 3)) - dmin * (s[4] >> 4))
// + y[96] * (dall * ((s[6] & 0xF) * ((q[0] >> 6) & 3)) - dmin * (s[6] >> 4))
// + y[16] * (dall * ((s[1] & 0xF) * ((q[16] >> 0) & 3)) - dmin * (s[1] >> 4))
// + y[48] * (dall * ((s[3] & 0xF) * ((q[16] >> 2) & 3)) - dmin * (s[3] >> 4))
// + y[80] * (dall * ((s[5] & 0xF) * ((q[16] >> 4) & 3)) - dmin * (s[5] >> 4))
// + y[112] * (dall * ((s[7] & 0xF) * ((q[16] >> 6) & 3)) - dmin * (s[7] >> 4));

// *result = sum;
// }

std::string dequant_template = MULTILINE_QUOTE(
__kernel void KERNEL_NAME(__global X_TYPE* x, __global float* y) {
const int i = get_group_id(0)*get_local_size(0) + get_local_id(0)*2;
Expand Down Expand Up @@ -199,6 +334,45 @@ __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float
}
);

// std::string dequant_mul_mat_vec_k_template = MULTILINE_QUOTE(
// __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float* y, __global float* dst, const int ncols) {
// const int block_size = get_local_size(0);
// const int row = get_global_id(0) / block_size;
// const int tid = get_local_id(0);

// const int iter_stride = 256;
// const int vals_per_iter = iter_stride;
// const int num_blocks_per_row = ncols / 256;
// const int ib0 = row*num_blocks_per_row;

// tmp[tid] = 0;

// for (int i = 0; i < ncols; i += iter_stride) {
// const int col = i + vals_per_iter*tid;
// const int ib = ib0 + col/QK_K; // x block index
// const int iqs = col%QK_K; // x quant index
// const int iybs = col - col%QK_K; // y block start index

// // dequantize
// float v;
// dot_kernel(vx, ib, iqs, y + iybs, v);
// tmp += v;
// }

// // sum up partial sums and write back result
// barrier(CLK_LOCAL_MEM_FENCE);
// for (int s=block_size/2; s>0; s>>=1) {
// if (tid < s) {
// tmp[tid] += tmp[tid + s];
// }
// barrier(CLK_LOCAL_MEM_FENCE);
// }
// if (tid == 0) {
// dst[row] = tmp[0];
// }
// }
// );

std::string mul_template = MULTILINE_QUOTE(
__kernel void KERNEL_NAME(__global TYPE* x, const int x_offset, __global TYPE* y, const int y_offset, __global TYPE* dst, const int dst_offset, const int ky) {
const int i = get_group_id(0)*get_local_size(0) + get_local_id(0);
Expand Down Expand Up @@ -300,6 +474,7 @@ static cl_program program;
static cl_kernel convert_row_f16_cl;
static cl_kernel dequantize_row_q4_0_cl, dequantize_row_q4_1_cl, dequantize_row_q5_0_cl, dequantize_row_q5_1_cl, dequantize_row_q8_0_cl;
static cl_kernel dequantize_mul_mat_vec_q4_0_cl, dequantize_mul_mat_vec_q4_1_cl, dequantize_mul_mat_vec_q5_0_cl, dequantize_mul_mat_vec_q5_1_cl, dequantize_mul_mat_vec_q8_0_cl, convert_mul_mat_vec_f16_cl;
static cl_kernel dequantize_block_q2_k_cl, dequantize_block_q3_k_cl, dequantize_block_q4_k_cl;
static cl_kernel mul_f32_cl;
static bool fp16_support;

Expand Down Expand Up @@ -529,6 +704,10 @@ void ggml_cl_init(void) {
CL_CHECK((dequantize_row_q5_0_cl = clCreateKernel(program, "dequantize_row_q5_0", &err), err));
CL_CHECK((dequantize_row_q5_1_cl = clCreateKernel(program, "dequantize_row_q5_1", &err), err));
CL_CHECK((dequantize_row_q8_0_cl = clCreateKernel(program, "dequantize_row_q8_0", &err), err));
CL_CHECK((dequantize_row_q8_0_cl = clCreateKernel(program, "dequantize_row_q8_0", &err), err));
CL_CHECK((dequantize_block_q2_k_cl = clCreateKernel(program, "dequantize_block_q2_K", &err), err));
CL_CHECK((dequantize_block_q3_k_cl = clCreateKernel(program, "dequantize_block_q3_K", &err), err));
CL_CHECK((dequantize_block_q4_k_cl = clCreateKernel(program, "dequantize_block_q4_K", &err), err));

// dequant mul mat kernel
CL_CHECK((dequantize_mul_mat_vec_q4_0_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q4_0", &err), err));
Expand All @@ -538,6 +717,7 @@ void ggml_cl_init(void) {
CL_CHECK((dequantize_mul_mat_vec_q8_0_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q8_0", &err), err));
CL_CHECK((convert_mul_mat_vec_f16_cl = clCreateKernel(program, "convert_mul_mat_vec_f16", &err), err));


// mul kernel
CL_CHECK((mul_f32_cl = clCreateKernel(program, "mul_f32", &err), err));
}
Expand All @@ -554,6 +734,12 @@ static cl_kernel* ggml_get_to_fp32_cl(ggml_type type) {
return &dequantize_row_q5_1_cl;
case GGML_TYPE_Q8_0:
return &dequantize_row_q8_0_cl;
case GGML_TYPE_Q2_K:
return &dequantize_block_q2_k_cl;
case GGML_TYPE_Q3_K:
return &dequantize_block_q3_k_cl;
case GGML_TYPE_Q4_K:
return &dequantize_block_q4_k_cl;
case GGML_TYPE_F16:
return &convert_row_f16_cl;
default:
Expand Down Expand Up @@ -1017,6 +1203,8 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
cl_kernel* dmmv = ggml_get_dequantize_mul_mat_vec_cl(type);
GGML_ASSERT(to_fp32_cl != nullptr);

printf("\ntype:%d q_sz:%d y_sz:%d ne00:%d ne01:%d ne10:%d ne11:%d nb2:%d nb3:%d",type,q_size,y_size,ne00,ne01,ne10,ne11);

size_t ev_idx = 0;
std::vector<cl_event> events;

Expand Down

0 comments on commit 9b41865

Please sign in to comment.