Skip to content

Commit

Permalink
ported q2k and q5k speedups
Browse files Browse the repository at this point in the history
  • Loading branch information
LostRuins committed Jun 20, 2023
1 parent 029bed6 commit 93247a1
Showing 1 changed file with 159 additions and 3 deletions.
162 changes: 159 additions & 3 deletions ggml-opencl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#define CL_DMMV_BLOCK_SIZE 32

#ifndef K_QUANTS_PER_ITERATION
#define K_QUANTS_PER_ITERATION 2
#define K_QUANTS_PER_ITERATION 1
#else
static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
#endif
Expand Down Expand Up @@ -357,6 +357,80 @@ void vec_dot_q2_K(__global const struct block_q2_K* x, const int ib, const int i
*result = sum;
}


__kernel void dequantize_mul_mat_vec_q2_K_fast(__global struct block_q2_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) {

const int row = get_group_id(0);

const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row;

const struct block_q2_K * x = xx + ib0;

const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION; // 0...31 or 0...15
const int ix = get_local_id(0)%K_QUANTS_PER_ITERATION; // 0 or 0,1

const int step = 16/K_QUANTS_PER_ITERATION;

const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const int in = tid - step*im; // 0...15 or 0...7

const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 or 0...14 in steps of 2
const int q_offset = 32*im + l0;
const int s_offset = 8*im;
const int y_offset = 128*im + l0;

tmp[16 * ix + tid] = 0;

uint32_t aux[4];
const uint8_t * d = (const uint8_t *)aux;
const uint8_t * m = (const uint8_t *)(aux + 2);

for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {

const float * y = yy + i * QK_K + y_offset;
const uint8_t * q = x[i].qs + q_offset;

const float dall = vload_half(0, &x[i].d);
const float dmin = vload_half(0, &x[i].dmin);

const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset);
aux[0] = a[0] & 0x0f0f0f0f;
aux[1] = a[1] & 0x0f0f0f0f;
aux[2] = (a[0] >> 4) & 0x0f0f0f0f;
aux[3] = (a[1] >> 4) & 0x0f0f0f0f;

float sum1 = 0, sum2 = 0;
for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3)
+ y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3)
+ y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3)
+ y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3)
+ y[l+16] * d[1] * ((q[l+16] >> 0) & 3)
+ y[l+48] * d[3] * ((q[l+16] >> 2) & 3)
+ y[l+80] * d[5] * ((q[l+16] >> 4) & 3)
+y[l+112] * d[7] * ((q[l+16] >> 6) & 3);
sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6]
+ y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7];

}
tmp[16 * ix + tid] += dall * sum1 - dmin * sum2;

}

// sum up partial sums and write back result
barrier(CLK_LOCAL_MEM_FENCE);
for (int s=16; s>0; s>>=1) {
if (tid < s) {
tmp[tid] += tmp[tid + s];
}
barrier(CLK_LOCAL_MEM_FENCE);
}
if (tid == 0) {
dst[row] = tmp[0];
}
}

void vec_dot_q3_K(__global const struct block_q3_K* x, const int ib, const int iqs, const __global float *yy, float *result) {

const uint32_t kmask1 = 0x03030303;
Expand Down Expand Up @@ -610,6 +684,88 @@ void vec_dot_q5_K(__global const struct block_q5_K* x, const int ib, const int i

}

__kernel void dequantize_mul_mat_vec_q5_K_fast(__global struct block_q5_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) {

const uint16_t kmask1 = 0x3f3f;
const uint16_t kmask2 = 0x0f0f;
const uint16_t kmask3 = 0xc0c0;

const int row = get_group_id(0);
const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row;

const int tid = get_local_id(0)/2; // 0...15
const int ix = get_local_id(0)%2;

const int il = tid/4; // 0...3
const int ir = tid - 4*il;// 0...3
const int n = 2;

const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
const int in = il%2;

const int l0 = n*(2*ir + in);
const int q_offset = 32*im + l0;
const int y_offset = 64*im + l0;

const uint8_t hm1 = 1 << (2*im);
const uint8_t hm2 = hm1 << 4;

uint16_t aux[4];
const uint8_t * sc = (const uint8_t *)aux;

const struct block_q5_K * x = xx + ib0;

tmp[16 * ix + tid] = 0;

for (int i = ix; i < num_blocks_per_row; i += 2) {

const uint8_t * ql1 = x[i].qs + q_offset;
const uint8_t * ql2 = ql1 + 64;
const uint8_t * qh = x[i].qh + l0;
const float * y1 = yy + i*QK_K + y_offset;
const float * y2 = y1 + 128;

const float dall = vload_half(0, &x[i].d);
const float dmin = vload_half(0, &x[i].dmin);

const uint16_t * a = (const uint16_t *)x[i].scales;
aux[0] = a[im+0] & kmask1;
aux[1] = a[im+2] & kmask1;
aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);

float4 sum = (float4)(0.f);
float smin = 0;
for (int l = 0; l < n; ++l) {
sum.x += y1[l+ 0] * ((ql1[l+ 0] & 0xF) + (qh[l+ 0] & (hm1 << 0) ? 16 : 0))
+ y1[l+16] * ((ql1[l+16] & 0xF) + (qh[l+16] & (hm1 << 0) ? 16 : 0));
sum.y += y1[l+32] * ((ql1[l+ 0] >> 4) + (qh[l+ 0] & (hm1 << 1) ? 16 : 0))
+ y1[l+48] * ((ql1[l+16] >> 4) + (qh[l+16] & (hm1 << 1) ? 16 : 0));
sum.z += y2[l+ 0] * ((ql2[l+ 0] & 0xF) + (qh[l+ 0] & (hm2 << 0) ? 16 : 0))
+ y2[l+16] * ((ql2[l+16] & 0xF) + (qh[l+16] & (hm2 << 0) ? 16 : 0));
sum.w += y2[l+32] * ((ql2[l+ 0] >> 4) + (qh[l+ 0] & (hm2 << 1) ? 16 : 0))
+ y2[l+48] * ((ql2[l+16] >> 4) + (qh[l+16] & (hm2 << 1) ? 16 : 0));
smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3]
+ (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
}
tmp[16 * ix + tid] += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;

}

// sum up partial sums and write back result
barrier(CLK_LOCAL_MEM_FENCE);
for (int s=16; s>0; s>>=1) {
if (tid < s) {
tmp[tid] += tmp[tid + s];
}
barrier(CLK_LOCAL_MEM_FENCE);
}
if (tid == 0) {
dst[row] = tmp[0];
}
}

void vec_dot_q6_K(__global const struct block_q6_K* x, const int ib, const int iqs, const __global float *yy, float *result) {


Expand Down Expand Up @@ -1187,10 +1343,10 @@ void ggml_cl_init(void) {
CL_CHECK((dequantize_mul_mat_vec_q5_1_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q5_1", &err), err));
CL_CHECK((dequantize_mul_mat_vec_q8_0_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q8_0", &err), err));
CL_CHECK((convert_mul_mat_vec_f16_cl = clCreateKernel(program, "convert_mul_mat_vec_f16", &err), err));
CL_CHECK((dequantize_mul_mat_vec_q2_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q2_K", &err), err));
CL_CHECK((dequantize_mul_mat_vec_q2_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q2_K_fast", &err), err));
CL_CHECK((dequantize_mul_mat_vec_q3_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q3_K_fast", &err), err));
CL_CHECK((dequantize_mul_mat_vec_q4_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q4_K_fast", &err), err));
CL_CHECK((dequantize_mul_mat_vec_q5_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q5_K", &err), err));
CL_CHECK((dequantize_mul_mat_vec_q5_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q5_K_fast", &err), err));
CL_CHECK((dequantize_mul_mat_vec_q6_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q6_K_fast", &err), err));

// mul kernel
Expand Down

0 comments on commit 93247a1

Please sign in to comment.