Skip to content

Commit

Permalink
Use preprocessor for QK_K
Browse files Browse the repository at this point in the history
  • Loading branch information
0cc4m committed Jun 20, 2023
1 parent 069cbe5 commit 34a4917
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions ggml-opencl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ __kernel void dequantize_block_q2_K(__global const struct block_q2_K *x, __globa
const int is = 8 * n + l / 16;

const uint8_t q = x[i].qs[32 * n + l];
__global float *y = yy + i * 256 + 128 * n;
__global float *y = yy + i * QK_K + 128 * n;

const float dall = vload_half(0, &x[i].d);
const float dmin = vload_half(0, &x[i].dmin);
Expand Down Expand Up @@ -239,7 +239,7 @@ __kernel void dequantize_block_q3_K(__global const struct block_q3_K *x, __globa
float d_all = vload_half(0, &x[i].d);
float dl = d_all * (us - 32);

__global float *y = yy + i * 256 + 128 * n + 32 * j;
__global float *y = yy + i * QK_K + 128 * n + 32 * j;
const __global uint8_t *q = x[i].qs + 32 * n;
const __global uint8_t *hm = x[i].hmask;

Expand All @@ -256,7 +256,7 @@ __kernel void dequantize_block_q4_K(__global const struct block_q4_K *x, __globa
const int is = 2 * il;
const int n = 4;

__global float *y = yy + i * 256 + 64 * il + n * ir;
__global float *y = yy + i * QK_K + 64 * il + n * ir;

const float dall = vload_half(0, &x[i].d);
const float dmin = vload_half(0, &x[i].dmin);
Expand Down Expand Up @@ -285,7 +285,7 @@ __kernel void dequantize_block_q5_K(__global const struct block_q5_K *x, __globa
const int ir = tid % 16;
const int is = 2 * il;

__global float *y = yy + i * 256 + 64 * il + 2 * ir;
__global float *y = yy + i * QK_K + 64 * il + 2 * ir;

const float dall = vload_half(0, &x[i].d);
const float dmin = vload_half(0, &x[i].dmin);
Expand Down Expand Up @@ -317,7 +317,7 @@ __kernel void dequantize_block_q6_K(__global const struct block_q6_K *x, __globa
const int il = tid - 32 * ip;
const int is = 8 * ip + il / 16;

__global float *y = yy + i * 256 + 128 * ip + il;
__global float *y = yy + i * QK_K + 128 * ip + il;

const float d = vload_half(0, &x[i].d);

Expand Down Expand Up @@ -436,7 +436,7 @@ __kernel void dequantize_mul_mat_vec_q4_K_fast(__global struct block_q4_K * xx,
const uint16_t kmask3 = 0xc0c0;

const int row = get_group_id(0);
const int num_blocks_per_row = ncols / 256;
const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row;

const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION; // 0...15
Expand Down Expand Up @@ -466,7 +466,7 @@ __kernel void dequantize_mul_mat_vec_q4_K_fast(__global struct block_q4_K * xx,

const uint8_t * q1 = x[i].qs + q_offset;
const uint8_t * q2 = q1 + 64;
const float * y1 = y + i*256 + y_offset;
const float * y1 = y + i*QK_K + y_offset;
const float * y2 = y1 + 128;

const float dall = vload_half(0, &x[i].d);
Expand Down Expand Up @@ -637,18 +637,18 @@ __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float
const int row = get_group_id(0);
const int tid = get_local_id(0);

const int iter_stride = 256;
const int iter_stride = QK_K;
const int vals_per_iter = iter_stride / block_size;
const int num_blocks_per_row = ncols / 256;
const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row;

tmp[tid] = 0;

for (int i = 0; i < ncols; i += iter_stride) {
const int col = i + vals_per_iter*tid;
const int ib = ib0 + col/256; // x block index
const int iqs = col%256; // x quant index
const int iybs = col - col%256; // y block start index
const int ib = ib0 + col/QK_K; // x block index
const int iqs = col%QK_K; // x quant index
const int iybs = col - col%QK_K; // y block start index

// dequantize
float v;
Expand Down Expand Up @@ -813,7 +813,7 @@ static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, co

std::string compile_opts = "-cl-mad-enable -cl-unsafe-math-optimizations -cl-finite-math-only -cl-fast-relaxed-math "
"-DQK4_0=32 -DQR4_0=2 -DQK4_1=32 -DQR4_1=2 -DQK5_0=32 -DQR5_0=2 -DQK5_1=32 -DQR5_1=2 -DQK8_0=32 -DQR8_0=1 "
"-DK_QUANTS_PER_ITERATION=" + std::to_string(K_QUANTS_PER_ITERATION);
"-DQK_K=256 -DK_QUANTS_PER_ITERATION=" + std::to_string(K_QUANTS_PER_ITERATION);

err = clBuildProgram(p, 0, NULL, compile_opts.c_str(), NULL, NULL);
if(err < 0) {
Expand Down

0 comments on commit 34a4917

Please sign in to comment.