Skip to content

Commit

Permalink
ported q3k speedup successfully
Browse files Browse the repository at this point in the history
  • Loading branch information
LostRuins committed Jun 20, 2023
1 parent 8d816d1 commit 029bed6
Showing 1 changed file with 75 additions and 1 deletion.
76 changes: 75 additions & 1 deletion ggml-opencl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,80 @@ void vec_dot_q3_K(__global const struct block_q3_K* x, const int ib, const int i

}

__kernel void dequantize_mul_mat_vec_q3_K_fast(__global struct block_q3_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) {
const uint16_t kmask1 = 0x0303;
const uint16_t kmask2 = 0x0f0f;

const int row = get_group_id(0);

const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row;

const struct block_q3_K * x = xx + ib0;

const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
const int ix = get_local_id(0)%K_QUANTS_PER_ITERATION; // 0 or 0,1

const int n = K_QUANTS_PER_ITERATION; // iterations in the inner loop
const int step = 16/K_QUANTS_PER_ITERATION;
const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const int in = tid - step*im; // 0....15 or 0...7

const uint8_t m = 1 << (4*im);

const int l0 = n*in; // 0...15 or 0...14 in steps of 2
const int q_offset = 32*im + l0;
const int y_offset = 128*im + l0;

uint16_t utmp[4];
const int8_t * s = (const int8_t *)utmp;

const uint16_t s_shift = 4*im;

tmp[16 * ix + tid] = 0;

for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {

const float * y = yy + i * QK_K + y_offset;
const uint8_t * q = x[i].qs + q_offset;
const uint8_t * h = x[i].hmask + l0;

const uint16_t * a = (const uint16_t *)x[i].scales;
utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4);
utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4);
utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4);
utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4);

const float d = vload_half(0, &x[i].d);

float sum = 0;
for (int l = 0; l < n; ++l) {
sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4))
+ y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4))
+ y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4))
+ y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4));
sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4))
+ y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4))
+ y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4))
+ y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4));
}
tmp[16 * ix + tid] += d * sum;

}

// sum up partial sums and write back result
barrier(CLK_LOCAL_MEM_FENCE);
for (int s=16; s>0; s>>=1) {
if (tid < s) {
tmp[tid] += tmp[tid + s];
}
barrier(CLK_LOCAL_MEM_FENCE);
}
if (tid == 0) {
dst[row] = tmp[0];
}
}

void vec_dot_q4_K(__global const struct block_q4_K* x, const int ib, const int iqs, const __global float *yy, float *result) {

const int j = iqs / 64; // j is in 0...3
Expand Down Expand Up @@ -1114,7 +1188,7 @@ void ggml_cl_init(void) {
CL_CHECK((dequantize_mul_mat_vec_q8_0_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q8_0", &err), err));
CL_CHECK((convert_mul_mat_vec_f16_cl = clCreateKernel(program, "convert_mul_mat_vec_f16", &err), err));
CL_CHECK((dequantize_mul_mat_vec_q2_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q2_K", &err), err));
CL_CHECK((dequantize_mul_mat_vec_q3_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q3_K", &err), err));
CL_CHECK((dequantize_mul_mat_vec_q3_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q3_K_fast", &err), err));
CL_CHECK((dequantize_mul_mat_vec_q4_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q4_K_fast", &err), err));
CL_CHECK((dequantize_mul_mat_vec_q5_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q5_K", &err), err));
CL_CHECK((dequantize_mul_mat_vec_q6_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q6_K_fast", &err), err));
Expand Down

0 comments on commit 029bed6

Please sign in to comment.