Proof of concept: GPU-accelerated token generation #1375

Closed

Changes from 1 commit
Faster than CPU without 80% runtime memcpy
JohannesGaessler committed May 9, 2023
commit d052a0ed4ce5d75aa0c0db243ebc99569a88d8e0
63 changes: 42 additions & 21 deletions ggml-cuda.cu
@@ -225,30 +225,45 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
}
}

static __global__ void dequantize_mul_mat_q4_0(const void * vx, const float * y, float * dst, int ncols, int nrows) {
template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const void * vx, const float * y, float * dst, const int ncols) {
const block_q4_0 * x = (const block_q4_0 *) vx;

const int row = blockIdx.x*blockDim.x + threadIdx.x;
const int row = blockIdx.x;
const int tid = threadIdx.x;

if (row >= nrows) {
return;
}
dst[row] = 0;
for (int i = 0; i < ncols; i += 2) {
const float d = x[(row*ncols + i)/QK4_0].d;
__shared__ float tmp[block_size]; // separate sum for each thread
tmp[tid] = 0;

for (int i = 0; i < ncols/block_size; i += 2) {
const int col = i*block_size + 2*tid;

const uint8_t * pp = x[(row*ncols + i)/QK4_0].qs;
// dequantize
const float d = x[(row*ncols + col)/QK4_0].d;

const uint8_t vui = pp[((row*ncols + i)%QK4_0)/2];
const uint8_t * pp = x[(row*ncols + col)/QK4_0].qs;

const uint8_t vui = pp[((row*ncols + col)%QK4_0)/2];

const int8_t vi0 = vui & 0xF;
const int8_t vi1 = vui >> 4;

const float v0 = (vi0 - 8)*d;
const float v1 = (vi1 - 8)*d;

dst[row] += v0 * y[i + 0];
dst[row] += v1 * y[i + 1];
// matrix multiplication
tmp[tid] += v0 * y[col + 0];
tmp[tid] += v1 * y[col + 1];
}

// sum up partial sums and write back result
for (int s=block_size/2; s>0; s>>=1) {
if (tid < s) {
tmp[tid] += tmp[tid + s];
}
__syncthreads();
}
Collaborator:
You don't need shared memory to sum results over a single warp; take a look at warpReduceSum() here. If you want to use more warps in the future, you can do intra-warp sums first, then aggregate the partial sums between warps via shared memory (blockReduceSum() in the same file).

I can't say for sure that this will make a significant difference here, but I saw a lot of memory-bound kernels where getting rid of shared memory (or reducing the number of shared memory transactions) resulted in speedups.
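
For readers who have not seen the pattern before, here is a minimal sketch of a shuffle-based intra-warp reduction and a block-level reduction built on top of it, in the spirit of the warpReduceSum()/blockReduceSum() helpers mentioned above. This is an illustration of the general technique, not the linked implementation; the helper names and the fixed 32-lane warp size are assumptions.

// Sketch only; not the code linked in the comment above.
// warp_reduce_sum: each of the 32 lanes contributes a value; after the loop,
// lane 0 holds the sum of the whole warp. No shared memory or __syncthreads()
// is needed because __shfl_down_sync() itself synchronizes the named lanes.
static __device__ float warp_reduce_sum(float val) {
    for (int offset = 16; offset > 0; offset >>= 1) {
        val += __shfl_down_sync(0xffffffff, val, offset);
    }
    return val;
}

// block_reduce_sum: intra-warp sums first, then the per-warp partial sums are
// combined through a small shared-memory buffer (one float per warp).
static __device__ float block_reduce_sum(float val) {
    __shared__ float warp_sums[32];            // enough for up to 1024 threads
    const int lane    = threadIdx.x % 32;
    const int warp_id = threadIdx.x / 32;

    val = warp_reduce_sum(val);                // reduce within each warp
    if (lane == 0) {
        warp_sums[warp_id] = val;              // one partial sum per warp
    }
    __syncthreads();

    // the first warp reduces the per-warp partial sums
    const int n_warps = (blockDim.x + 31) / 32;
    val = (threadIdx.x < n_warps) ? warp_sums[lane] : 0.0f;
    if (warp_id == 0) {
        val = warp_reduce_sum(val);
    }
    return val;                                // result is valid in thread 0
}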

JohannesGaessler (Collaborator, Author), May 9, 2023:

I read an official NVIDIA PDF where they said the same thing regarding intra-warp synchronization, but when I removed the __syncthreads() instructions I got incorrect results even with a block size of 32. In any case, the addition of the partial sums at the end is negligible for overall performance.

Collaborator:

[...] when I removed the __syncthreads() instructions I got incorrect results even with a block size of 32

Is this code available somewhere? If you say that the addition overhead is negligible, then I guess it doesn't really matter in the big picture. I'm just curious if I can understand what went wrong.

JohannesGaessler (Collaborator, Author), May 9, 2023:

Sorry, I misremembered. I get incorrect results if I remove __syncthreads(). I get correct results again if I then also define tmp as volatile, but this gives me worse performance. And when I tried just removing the summation altogether as a test, the performance difference was negligible, so I just kept __syncthreads().

Collaborator:

It looks like we're talking about different things. This is what I had in mind: no __syncthreads(), no __shared__, no volatile, no arrays.

On my A100, it does look like getting rid of shared memory doesn't improve anything on its own: the average time of running the kernel with ./main -b 512 -t 12 -n 256 -f prompts/dan.txt -m models/13B/ggml-model-q4_0.bin --no-mmap -s 123456 --gpu_layers 40 is 72 microseconds. However, this might be beneficial if you want to run more warps per block. For example:

  • this variant without shared memory runs in 64 microseconds on the same A100
  • this variant with shared memory runs in 68 microseconds on the same A100

Bottom line: this might not be very useful on its own, but reducing usage of shared memory is always a good idea in the long run.
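
To make the variant under discussion concrete, the sketch below shows the same kernel with the __shared__ tmp[] array replaced by a per-thread register and a shuffle reduction, i.e. no __syncthreads(), no volatile, no arrays. This is a reconstruction for illustration, not the exact variant linked above; the kernel name is hypothetical, it assumes block_size == 32 (one warp per block), and it reuses block_q4_0 and QK4_0 as defined in ggml-cuda.cu.

// Illustrative reconstruction, not the linked variant. Assumes one warp per
// block (block_size == 32); block_q4_0 and QK4_0 are the existing
// definitions from ggml-cuda.cu.
template <int block_size> static __global__ void dequantize_mul_mat_q4_0_warp(const void * vx, const float * y, float * dst, const int ncols) {
    const block_q4_0 * x = (const block_q4_0 *) vx;

    const int row = blockIdx.x;
    const int tid = threadIdx.x;

    float tmp = 0.0f; // per-thread partial sum in a register instead of __shared__

    for (int i = 0; i < ncols/block_size; i += 2) {
        const int col = i*block_size + 2*tid;

        // dequantize
        const float d = x[(row*ncols + col)/QK4_0].d;
        const uint8_t * pp = x[(row*ncols + col)/QK4_0].qs;
        const uint8_t vui = pp[((row*ncols + col)%QK4_0)/2];
        const int8_t vi0 = vui & 0xF;
        const int8_t vi1 = vui >> 4;

        // matrix multiplication
        tmp += (vi0 - 8)*d * y[col + 0];
        tmp += (vi1 - 8)*d * y[col + 1];
    }

    // sum the 32 per-thread partial sums within the warp; the shuffle
    // intrinsic synchronizes the participating lanes, so no __syncthreads()
    // or shared memory is required
    for (int offset = block_size/2; offset > 0; offset >>= 1) {
        tmp += __shfl_down_sync(0xffffffff, tmp, offset);
    }

    if (tid == 0) {
        dst[row] = tmp;
    }
}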

JohannesGaessler (Collaborator, Author):

Thank you for the high-effort post. I'm relatively inexperienced when it comes to low-level GPU programming and definitely appreciate it.

JohannesGaessler (Collaborator, Author):

I tested a version without shared memory on my hardware. It was 3.5% faster. I'm not going to optimize block sizes right now because I would like to do that at the end.

if (tid == 0) {
dst[row] = tmp[0];
}
}

@@ -282,15 +297,21 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
}

static void dequantize_mul_mat_q4_0_cuda(const void * vx, const float * y, float * dst, int ncols, int nrows, cudaStream_t stream) {
static int block_size = -1;
if (block_size == -1) {
int min_grid_size;
CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_mul_mat_q4_0, 0, 0));
block_size = min(block_size, GGML_CUDA_MAX_BLOCK_SIZE);
}
const int grid_size = (nrows + block_size - 1) / block_size; // Round up.
dequantize_mul_mat_q4_0<<<grid_size, block_size, 0, stream>>>(vx, y, dst, ncols, nrows);
static void dequantize_mul_mat_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
Collaborator:

This should probably be renamed to dequantize_mul_mat_vec_q4_0_cuda() or something similar? I didn't read the PR description very carefully at first, and spent some time scratching my head and wondering if there is a missing dimension for y in the kernel.

JohannesGaessler (Collaborator, Author), May 9, 2023:

I agree that something along the lines of mul_mat_vec would be a better name; I just forgot to change it.

// static int block_size = -1;
Comment:

I would suggest removing the commented-out code.

JohannesGaessler (Collaborator, Author):

Already done for the version that I intend to get merged: #1412

// if (block_size == -1) {
// int min_grid_size, max_block_size = 1;
// CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &max_block_size, dequantize_mul_mat_q4_0<256>, 0, 0));
// max_block_size = min(max_block_size, GGML_CUDA_MAX_BLOCK_SIZE);
// block_size = 1;
// while (block_size*2 <= max_block_size && block_size*2 % ncols == 0) {
// block_size *= 2;
// }
// }
// dequantize_mul_mat_q4_0<<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
const int block_size = 32;
GGML_ASSERT(ncols % block_size == 0);
dequantize_mul_mat_q4_0<block_size><<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
}

// TODO: optimize