Proof of concept: GPU-accelerated token generation #1375

Closed

Works but slower than CPU
JohannesGaessler committed May 9, 2023
commit 229aa1f504369686a3eafdfb2830116e95dc107f
104 changes: 82 additions & 22 deletions ggml-cuda.cu
@@ -225,6 +225,33 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
    }
}

static __global__ void dequantize_mul_mat_q4_0(const void * vx, const float * y, float * dst, int ncols, int nrows) {
    const block_q4_0 * x = (const block_q4_0 *) vx;

    // One thread per output row: dst[row] = dot(row of x, y).
    const int row = blockIdx.x*blockDim.x + threadIdx.x;

    if (row >= nrows) {
        return;
    }
    dst[row] = 0;
    for (int i = 0; i < ncols; i += 2) {
        // Locate the q4_0 block that holds column i of this row.
        const float d = x[(row*ncols + i)/QK4_0].d;

        const uint8_t * pp = x[(row*ncols + i)/QK4_0].qs;

        // Two 4-bit quants are packed into each byte.
        const uint8_t vui = pp[((row*ncols + i)%QK4_0)/2];

        const int8_t vi0 = vui & 0xF;
        const int8_t vi1 = vui >> 4;

        // Dequantize: a 4-bit value v decodes to (v - 8)*d.
        const float v0 = (vi0 - 8)*d;
        const float v1 = (vi1 - 8)*d;

        dst[row] += v0 * y[i + 0];
        dst[row] += v1 * y[i + 1];
    }
}
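// For reference, the q4_0 block layout this kernel indexes (a sketch matching
// the float-scale definition used elsewhere in this file at the time of this
// commit; QK4_0 is 32):
//
//   typedef struct {
//       float   d;              // scale factor shared by the block
//       uint8_t qs[QK4_0 / 2];  // 32 quants, two 4-bit values per byte
//   } block_q4_0;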

static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
    const int nb = k / QK4_0;
    dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
@@ -255,6 +282,17 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
    dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
}

static void dequantize_mul_mat_q4_0_cuda(const void * vx, const float * y, float * dst, int ncols, int nrows, cudaStream_t stream) {
    static int block_size = -1;
    if (block_size == -1) {
        int min_grid_size;
        CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_mul_mat_q4_0, 0, 0));
        block_size = min(block_size, GGML_CUDA_MAX_BLOCK_SIZE);
    }
    const int grid_size = (nrows + block_size - 1) / block_size; // Round up.
    dequantize_mul_mat_q4_0<<<grid_size, block_size, 0, stream>>>(vx, y, dst, ncols, nrows);
}
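// Note on the launch configuration above: cudaOccupancyMaxPotentialBlockSize
// suggests a block size that maximizes theoretical occupancy for the given
// kernel, and the result is cached in a static since it does not change
// between calls. With one thread per output row, the grid then needs
// ceil(nrows / block_size) blocks, hence the rounded-up division.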

// TODO: optimize
static __global__ void convert_fp16_to_fp32(const void * vx, float * y) {
    const half * x = (const half *) vx;
@@ -597,7 +635,10 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
    const size_t q_sz = ggml_type_size(type) * x_ne / ggml_blck_size(type);

    size_t x_size, y_size, d_size, q_size;
    float * d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
    float * d_X;
    if (ne11 > 1) {
Collaborator comment:

Suggest abstracting a function to determine whether we should use cuBLAS for the computation. ne11 > 1 is one of the criteria; we can do more testing to see whether there are other cases.
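A minimal sketch of what such a helper could look like (the name is hypothetical and not part of this PR; in the caller, ne11 is src1->ne[1]):

static bool ggml_cuda_mul_mat_use_cublas(const struct ggml_tensor * src1) {
    // Hypothetical helper, sketched from the review suggestion above.
    // Sole criterion so far: with more than one column in src1 (prompt
    // processing), dequantize src0 to fp32 and use cublasSgemm; with a
    // single column (token generation), use the fused
    // dequantize_mul_mat_q4_0 kernel instead.
    return src1->ne[1] > 1;
}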

        d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
    }
    float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
    float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
    char * d_Q = (char *) ggml_cuda_pool_malloc(n_mm * q_sz, &q_size);
@@ -612,31 +653,49 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
            cudaStream_t cudaStream2 = g_cudaStreams2[i % GGML_CUDA_MAX_STREAMS];
            cudaEvent_t cudaEvent = g_cudaEvents[i % GGML_CUDA_MAX_EVENTS];

            float * c_X = d_X + i * x_ne;
            float * c_Y = d_Y + i * y_ne;
            float * c_D = d_D + i * d_ne;
            char * c_Q = d_Q + i * q_sz;

            // copy src0 and convert to fp32 on device
            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2));
            to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2);
            CUDA_CHECK(cudaGetLastError());
            CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
            if (ne11 == 1) {
                // copy src0 to device
                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2));
                CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));

            // copy src1 to device
            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
                // copy src1 to device
                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));

            // wait for conversion
            CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
                // wait for data
                CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));

            // compute
            CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
            CUBLAS_CHECK(
                    cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
                            ne01, ne11, ne10,
                            &alpha, c_X, ne00,
                                    c_Y, ne10,
                            &beta,  c_D, ne01));
                // compute
                dequantize_mul_mat_q4_0_cuda(c_Q, c_Y, c_D, ne00, ne01, cudaStream);
                CUDA_CHECK(cudaGetLastError());

            } else {
                float * c_X = d_X + i * x_ne;

                // copy src0 and convert to fp32 on device
                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2));
                to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2);
                CUDA_CHECK(cudaGetLastError());
                CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));

                // copy src1 to device
                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));

                // wait for conversion
                CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));

                // compute
                CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
                CUBLAS_CHECK(
                        cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
                                ne01, ne11, ne10,
                                &alpha, c_X, ne00,
                                        c_Y, ne10,
                                &beta,  c_D, ne01));
            }

            // copy dst to host
            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
@@ -645,7 +704,9 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
    }

    CUDA_CHECK(cudaDeviceSynchronize());
    ggml_cuda_pool_free(d_X, x_size);
    if (ne11 > 1) {
        ggml_cuda_pool_free(d_X, x_size);
    }
    ggml_cuda_pool_free(d_Y, y_size);
    ggml_cuda_pool_free(d_D, d_size);
    ggml_cuda_pool_free(d_Q, q_size);
@@ -660,8 +721,7 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te
    // TODO: find the optimal values for these
    if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
        src1->type == GGML_TYPE_F32 &&
        dst->type == GGML_TYPE_F32 &&
        (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
        dst->type == GGML_TYPE_F32) {

        return true;
    }