Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

K-Quant-Support for OpenCL #1836

Merged
merged 11 commits into from
Jun 16, 2023
Prev Previous commit
Next Next commit
Added OpenCL DMMV kernels
  • Loading branch information
LostRuins authored and 0cc4m committed Jun 13, 2023
commit 6e20827f933657fedeb63ad99ff069ce4ee814b8
165 changes: 129 additions & 36 deletions ggml-opencl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,35 +25,35 @@ typedef uchar uint8_t;
typedef int int32_t;
typedef uint uint32_t;

struct __attribute__((packed)) block_q4_0
struct __attribute__ ((packed)) block_q4_0
{
half d;
uint8_t qs[QK4_0 / 2];
};

struct __attribute__((packed)) block_q4_1
struct __attribute__ ((packed)) block_q4_1
{
half d;
half m;
uint8_t qs[QK4_1 / 2];
};

struct __attribute__((packed)) block_q5_0
struct __attribute__ ((packed)) block_q5_0
{
half d;
uint32_t qh;
uint8_t qs[QK5_0 / 2];
};

struct __attribute__((packed)) block_q5_1
struct __attribute__ ((packed)) block_q5_1
{
half d;
half m;
uint32_t qh;
uint8_t qs[QK5_1 / 2];
};

struct __attribute__((packed)) block_q8_0
struct __attribute__ ((packed)) block_q8_0
{
half d;
int8_t qs[QK8_0];
Expand Down Expand Up @@ -100,26 +100,24 @@ struct __attribute__((packed)) block_q6_K
half d;
};

__kernel void convert_fp16_to_fp32(__global half *x, __global float *y)
{
__kernel void convert_fp16_to_fp32(__global half* x, __global float* y) {
const uint i = get_global_id(0);

y[i] = vload_half(0, &x[i]);
}

void dequantize_q4_0(__global const struct block_q4_0 *x, const int ib, const int iqs, float *v0, float *v1)
{
void dequantize_q4_0(__global const struct block_q4_0* x, const int ib, const int iqs, float* v0, float* v1) {
const float d = vload_half(0, &x[ib].d);

const uint8_t vui = x[ib].qs[iqs];

const int8_t vi0 = vui & 0xF;
const int8_t vi1 = vui >> 4;

*v0 = (vi0 - 8) * d;
*v1 = (vi1 - 8) * d;
} void dequantize_q4_1(__global const struct block_q4_1 *x, const int ib, const int iqs, float *v0, float *v1)
{
*v0 = (vi0 - 8)*d;
*v1 = (vi1 - 8)*d;
}
void dequantize_q4_1(__global const struct block_q4_1* x, const int ib, const int iqs, float* v0, float* v1) {
const float d = vload_half(0, &x[ib].d);
const float m = vload_half(0, &x[ib].m);

Expand All @@ -128,48 +126,48 @@ void dequantize_q4_0(__global const struct block_q4_0 *x, const int ib, const in
const int8_t vi0 = vui & 0xF;
const int8_t vi1 = vui >> 4;

*v0 = vi0 * d + m;
*v1 = vi1 * d + m;
} void dequantize_q5_0(__global const struct block_q5_0 *x, const int ib, const int iqs, float *v0, float *v1)
{
*v0 = vi0*d + m;
*v1 = vi1*d + m;
}
void dequantize_q5_0(__global const struct block_q5_0* x, const int ib, const int iqs, float* v0, float* v1) {
const float d = vload_half(0, &x[ib].d);

uint32_t qh = x[ib].qh;

const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
const uint8_t xh_1 = ((qh >> (iqs + 12))) & 0x10;
const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10;

const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0) - 16;
const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1) - 16;
const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1) - 16;

*v0 = x0 * d;
*v1 = x1 * d;
} void dequantize_q5_1(__global const struct block_q5_1 *x, const int ib, const int iqs, float *v0, float *v1)
{
*v0 = x0*d;
*v1 = x1*d;
}
void dequantize_q5_1(__global const struct block_q5_1* x, const int ib, const int iqs, float* v0, float* v1) {
const float d = vload_half(0, &x[ib].d);
const float m = vload_half(0, &x[ib].m);

uint32_t qh = x[ib].qh;

const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
const uint8_t xh_1 = ((qh >> (iqs + 12))) & 0x10;
const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10;

const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0);
const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1);
const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1);

*v0 = x0 * d + m;
*v1 = x1 * d + m;
} void dequantize_q8_0(__global const struct block_q8_0 *x, const int ib, const int iqs, float *v0, float *v1)
{
*v0 = x0*d + m;
*v1 = x1*d + m;
}
void dequantize_q8_0(__global const struct block_q8_0* x, const int ib, const int iqs, float* v0, float* v1) {
const float d = vload_half(0, &x[ib].d);

const int8_t vi0 = x[ib].qs[iqs + 0];
const int8_t vi1 = x[ib].qs[iqs + 1];

*v0 = vi0 * d;
*v1 = vi1 * d;
} void convert_f16(__global half *x, const int ib, const int iqs, float *v0, float *v1)
{
*v0 = vi0*d;
*v1 = vi1*d;
}
void convert_f16(__global half* x, const int ib, const int iqs, float* v0, float* v1){
*v0 = vload_half(0, &x[ib + 0]);
*v1 = vload_half(0, &x[ib + 1]);
}
Expand Down Expand Up @@ -397,6 +395,95 @@ void vec_dot_q3_K(__global const struct block_q3_K* x, const int ib, const int i

}

void vec_dot_q4_K(__global const struct block_q4_K* x, const int ib, const int iqs, const __global float *yy, float *result) {

const int j = iqs / 64; // j is in 0...3
const int ir = (iqs - 64*j)/2; // ir is in 0...28 in steps of 4
const int is = 2*j; // is is in 0...6 in steps of 2

__global const float * y = yy + 64*j + ir;
__global const uint8_t * q = x[ib].qs + 32*j + ir;

const float dall = vload_half(0, &x[ib].d);
const float dmin = vload_half(0, &x[ib].dmin);

uint8_t sc, m;
get_scale_min_k4(is + 0, x[ib].scales, &sc, &m);
const float d1 = dall * sc;
const float m1 = dmin * m;
get_scale_min_k4(is + 1, x[ib].scales, &sc, &m);
const float d2 = dall * sc;
const float m2 = dmin * m;

float sum = 0;
for (int k = 0; k < 4; ++k) {
sum += y[k + 0] * (d1 * (q[k] & 0xF) - m1);
sum += y[k + 32] * (d2 * (q[k] >> 4) - m2);
}

*result = sum;
}

void vec_dot_q5_K(__global const struct block_q5_K* x, const int ib, const int iqs, const __global float *yy, float *result) {

const int j = iqs / 64;
const int ir = (iqs - 64*j)/2;
const int is = 2*j;

__global const float * y = yy + 64*j + ir;
__global const uint8_t * ql = x[ib].qs + 32*j + ir;
__global const uint8_t * qh = x[ib].qh + ir;

const float dall = vload_half(0, &x[ib].d);
const float dmin = vload_half(0, &x[ib].dmin);

uint8_t sc, m;
get_scale_min_k4(is + 0, x[ib].scales, &sc, &m);
const float d1 = dall * sc;
const float m1 = dmin * m;
get_scale_min_k4(is + 1, x[ib].scales, &sc, &m);
const float d2 = dall * sc;
const float m2 = dmin * m;

uint8_t hm = 1 << is;
float sum = 0;
for (int k = 0; k < 4; ++k) {
sum += y[k + 0] * (d1 * ((ql[k] & 0xF) + (qh[k] & hm ? 16 : 0)) - m1);
}
hm <<= 1;
for (int k = 0; k < 4; ++k) {
sum += y[k + 32] * (d2 * ((ql[k] >> 4) + (qh[k] & hm ? 16 : 0)) - m2);
}
*result = sum;

}

void vec_dot_q6_K(__global const struct block_q6_K* x, const int ib, const int iqs, const __global float *yy, float *result) {


const int ip = iqs / 128; // 0 or 1
const int il = (iqs - 128*ip)/8; // 0...15
const int is = 8*ip;

__global const float * y = yy + 128*ip + il;

const float d = vload_half(0, &x[ib].d);

__global const uint8_t * ql = x[ib].ql + 64*ip + il;
const uint8_t * qh = x[ib].qh + 32*ip + il;
__global const int8_t * sc = x[ib].scales + is;

*result = y[ 0] * d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh[ 0] >> 0) & 3) << 4)) - 32)
+ y[ 32] * d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh[ 0] >> 2) & 3) << 4)) - 32)
+ y[ 64] * d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh[ 0] >> 4) & 3) << 4)) - 32)
+ y[ 96] * d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh[ 0] >> 6) & 3) << 4)) - 32)
+ y[ 16] * d * sc[1] * ((int8_t)((ql[16] & 0xF) | (((qh[16] >> 0) & 3) << 4)) - 32)
+ y[ 48] * d * sc[3] * ((int8_t)((ql[48] & 0xF) | (((qh[16] >> 2) & 3) << 4)) - 32)
+ y[ 80] * d * sc[5] * ((int8_t)((ql[16] >> 4) | (((qh[16] >> 4) & 3) << 4)) - 32)
+ y[112] * d * sc[7] * ((int8_t)((ql[48] >> 4) | (((qh[16] >> 6) & 3) << 4)) - 32);

}

);


Expand Down Expand Up @@ -566,9 +653,12 @@ std::array<std::string, 2> mul_str_values = {
"mul_f32", "float"
};

std::array<std::string, 6> dmmv_k_str_values = {
std::array<std::string, 15> dmmv_k_str_values = {
"dequantize_mul_mat_vec_q2_K", "struct block_q2_K", "vec_dot_q2_K",
"dequantize_mul_mat_vec_q3_K", "struct block_q3_K", "vec_dot_q3_K",
"dequantize_mul_mat_vec_q4_K", "struct block_q4_K", "vec_dot_q4_K",
"dequantize_mul_mat_vec_q5_K", "struct block_q5_K", "vec_dot_q5_K",
"dequantize_mul_mat_vec_q6_K", "struct block_q6_K", "vec_dot_q6_K",
};

std::string& replace(std::string& s, const std::string& from, const std::string& to) {
Expand Down Expand Up @@ -867,6 +957,9 @@ void ggml_cl_init(void) {
CL_CHECK((convert_mul_mat_vec_f16_cl = clCreateKernel(program, "convert_mul_mat_vec_f16", &err), err));
CL_CHECK((dequantize_mul_mat_vec_q2_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q2_K", &err), err));
CL_CHECK((dequantize_mul_mat_vec_q3_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q3_K", &err), err));
CL_CHECK((dequantize_mul_mat_vec_q4_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q4_K", &err), err));
CL_CHECK((dequantize_mul_mat_vec_q5_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q5_K", &err), err));
CL_CHECK((dequantize_mul_mat_vec_q6_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q6_K", &err), err));

// mul kernel
CL_CHECK((mul_f32_cl = clCreateKernel(program, "mul_f32", &err), err));
Expand Down