Skip to content

Commit

Permalink
Add q6_k fast matmul kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
0cc4m committed Jun 20, 2023
1 parent 34a4917 commit 8d816d1
Showing 1 changed file with 79 additions and 3 deletions.
82 changes: 79 additions & 3 deletions ggml-opencl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ void vec_dot_q4_K(__global const struct block_q4_K* x, const int ib, const int i
*result = sum;
}

__kernel void dequantize_mul_mat_vec_q4_K_fast(__global struct block_q4_K * xx, __local float* tmp, __global float* y, __global float* dst, const int ncols) {
__kernel void dequantize_mul_mat_vec_q4_K_fast(__global struct block_q4_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) {

//to rename it later, just to test now
const uint16_t kmask1 = 0x3f3f;
Expand Down Expand Up @@ -466,7 +466,7 @@ __kernel void dequantize_mul_mat_vec_q4_K_fast(__global struct block_q4_K * xx,

const uint8_t * q1 = x[i].qs + q_offset;
const uint8_t * q2 = q1 + 64;
const float * y1 = y + i*QK_K + y_offset;
const float * y1 = yy + i*QK_K + y_offset;
const float * y2 = y1 + 128;

const float dall = vload_half(0, &x[i].d);
Expand Down Expand Up @@ -562,6 +562,82 @@ void vec_dot_q6_K(__global const struct block_q6_K* x, const int ib, const int i

}

__kernel void dequantize_mul_mat_vec_q6_K_fast(__global struct block_q6_K * xx, __local float* tmp, __global const float * yy, __global float * dst, const int ncols) {

const int row = get_group_id(0);

const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row;

const struct block_q6_K * x = xx + ib0;

const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
const int ix = get_local_id(0)%K_QUANTS_PER_ITERATION; // 0 or 0, 1

const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8

const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const int in = tid - step*im; // 0...15 or 0...7

#if K_QUANTS_PER_ITERATION == 1
const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15
const int is = 0;
#else
const int l0 = 4 * in; // 0, 4, 8, ..., 28
const int is = in / 4;
#endif
const int ql_offset = 64*im + l0;
const int qh_offset = 32*im + l0;
const int s_offset = 8*im + is;
const int y_offset = 128*im + l0;

tmp[16 * ix + tid] = 0; // partial sum for thread in warp

for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {

const float * y = yy + i * QK_K + y_offset;
const uint8_t * ql = x[i].ql + ql_offset;
const uint8_t * qh = x[i].qh + qh_offset;
const int8_t * s = x[i].scales + s_offset;

const float d = vload_half(0, &x[i].d);

#if K_QUANTS_PER_ITERATION == 1
float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
+ y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
+ y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
+ y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
+ y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
+ y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
+ y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
+y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
tmp[16 * ix + tid] += sum;
#else
float sum = 0;
for (int l = 0; l < 4; ++l) {
sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
+ y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32)
+ y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32)
+ y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
}
tmp[16 * ix + tid] += sum;
#endif

}

// sum up partial sums and write back result
barrier(CLK_LOCAL_MEM_FENCE);
for (int s=16; s>0; s>>=1) {
if (tid < s) {
tmp[tid] += tmp[tid + s];
}
barrier(CLK_LOCAL_MEM_FENCE);
}
if (tid == 0) {
dst[row] = tmp[0];
}
}

);


Expand Down Expand Up @@ -1041,7 +1117,7 @@ void ggml_cl_init(void) {
CL_CHECK((dequantize_mul_mat_vec_q3_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q3_K", &err), err));
CL_CHECK((dequantize_mul_mat_vec_q4_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q4_K_fast", &err), err));
CL_CHECK((dequantize_mul_mat_vec_q5_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q5_K", &err), err));
CL_CHECK((dequantize_mul_mat_vec_q6_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q6_K", &err), err));
CL_CHECK((dequantize_mul_mat_vec_q6_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q6_K_fast", &err), err));

// mul kernel
CL_CHECK((mul_f32_cl = clCreateKernel(program, "mul_f32", &err), err));
Expand Down

0 comments on commit 8d816d1

Please sign in to comment.