Proof of concept: GPU-accelerated token generation #1375
Conversation
How hard would it be to make a PoC that performs integer dot products on the GPU? I.e. quantize the
I don't think it would be too hard; there's some work that I'll need to do for my master's thesis until Thursday, but I'll look into it when I get the time.
What prompt eval times do you get for longer prompts? If it turns out that my custom kernel is faster than dequantization + cuBLAS matrix multiplication we may be able to speed up prompt processing as well (my code is very much not optimized for matrix-matrix multiplication though).
Very nice! I had experimented with keeping the weights in VRAM before, but I didn't expect that also implementing a quantized mat mul kernel would be fast enough to use it for generation. If we are open to more GPU changes in ggml, we could also implement the remaining operations in CUDA and avoid all the copies in the GPU layers altogether.
On my 3080 with
To clarify: the new kernel that I implemented is not used for prompt processing, only for matrix-vector multiplications. I was only thinking that if
Then I am not sure how to test that. Evaluating the prompt is still much faster than generating (13B --gpu_layers 40):
Okay, thanks for the performance numbers; I was only intrigued by @ggerganov's numbers where
I think there is a viable idea for offloading. Yes, it is unlikely that we can outperform cuBLAS for bigger matrices.
Worth pointing out that I am on PCIe 3.0, so for me the memcpys are probably a significantly larger overhead than for people on PCIe 4.0. @dfyz also noted that NVIDIA has a library (CUTLASS) for implementing high-performance matrix multiplications, though I don't know if it is flexible enough to use with quantized formats.
I should point out that a major bottleneck for the implementation was memory coalescing. In my first version I parallelized across multiple rows instead of multiple columns and that kernel was slower than the CPU. So I think that if there is a matrix multiplication library we absolutely need a way to consider the memory layout in ggml.
Anyway, regarding future GPU acceleration in ggml: I think this PR proves that we need specialized matrix-vector multiplication kernels regardless of how they are invoked.
ggml-cuda.cu
Outdated
// sum up partial sums and write back result
for (int s=block_size/2; s>0; s>>=1) {
    if (tid < s) {
        tmp[tid] += tmp[tid + s];
    }
    __syncthreads();
}
You don't need shared memory to sum results over a single warp; take a look at warpReduceSum() here. If you want to use more warps in the future, you can do intra-warp sums first, then aggregate the partial sums between warps via shared memory (blockReduceSum() in the same file).
I can't say for sure that this will make a significant difference here, but I have seen a lot of memory-bound kernels where getting rid of shared memory (or reducing the number of shared memory transactions) resulted in speedups.
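For reference, here is a minimal sketch of the shuffle-based warp reduction described above (an illustration that assumes a full 32-thread warp, not the code from the linked file):

```cuda
// Sum `val` across the 32 lanes of a warp without touching shared memory.
// After the loop every lane holds the complete sum.
__device__ float warp_reduce_sum(float val) {
    for (int offset = 16; offset > 0; offset >>= 1) {
        val += __shfl_xor_sync(0xffffffffu, val, offset);
    }
    return val;
}
```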
I read an official NVIDIA PDF that said the same thing about intra-warp synchronization, but when I removed the __syncthreads() instructions I got incorrect results even with a block size of 32. In any case, the addition of the partial sums at the end is negligible for overall performance.
> [...] when I removed the __syncthreads() instructions I got incorrect results even with a block size of 32
Is this code available somewhere? If you say that the addition overhead is negligible, then I guess it doesn't really matter in the big picture. I'm just curious if I can understand what went wrong.
Sorry, I misremembered. I get incorrect results if I remove __syncthreads(). I get correct results again if I then also declare tmp as volatile, but this gives me worse performance. And when I tried just removing the summation altogether as a test, the performance difference was negligible, so I just kept __syncthreads().
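For context: wrong results without __syncthreads() that are fixed by volatile (at a performance cost) are consistent with the compiler caching tmp in registers once the barrier is gone; on Volta and newer GPUs warp-synchronous code additionally needs explicit synchronization because the lanes of a warp no longer execute in lockstep. A hedged sketch of the same reduction with warp-level synchronization only, assuming block_size == 32 (a single warp); this is not the PR's code:

```cuda
// With a single active warp, the full block barrier can be replaced by
// __syncwarp(), which both orders the shared-memory accesses and keeps the
// lanes converged; removing synchronization entirely is what breaks.
for (int s = block_size/2; s > 0; s >>= 1) {
    if (tid < s) {
        tmp[tid] += tmp[tid + s];
    }
    __syncwarp();
}
```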
It looks like we're talking about different things. This is what I had in mind: no __syncthreads(), no __shared__, no volatile, no arrays.
On my A100, it does look like getting rid of shared memory doesn't improve anything on its own: the average time of running the kernel with ./main -b 512 -t 12 -n 256 -f prompts/dan.txt -m models/13B/ggml-model-q4_0.bin --no-mmap -s 123456 --gpu_layers 40 is 72 microseconds. However, this might be beneficial if you want to run more warps per block. For example:
- this variant without shared memory runs in 64 microseconds on the same A100
- this variant with shared memory runs in 68 microseconds on the same A100
Bottom line: this might not be very useful on its own, but reducing the usage of shared memory is always a good idea in the long run.
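For completeness, a sketch of the two-level pattern mentioned earlier (intra-warp shuffles first, then a single shared-memory exchange between warps); illustrative only, assuming blockDim.x is a multiple of 32 and at most 1024:

```cuda
// Block-wide sum: shuffle within each warp, then let the first warp combine
// the per-warp partial sums. The full sum is valid in warp 0 afterwards.
__device__ float block_reduce_sum(float val) {
    __shared__ float warp_sums[32];            // one slot per possible warp
    const int lane    = threadIdx.x % 32;
    const int warp_id = threadIdx.x / 32;

    for (int offset = 16; offset > 0; offset >>= 1) {   // intra-warp reduction
        val += __shfl_xor_sync(0xffffffffu, val, offset);
    }
    if (lane == 0) {
        warp_sums[warp_id] = val;              // one partial sum per warp
    }
    __syncthreads();

    const int num_warps = blockDim.x / 32;
    val = (threadIdx.x < num_warps) ? warp_sums[lane] : 0.0f;
    if (warp_id == 0) {                        // first warp reduces the partials
        for (int offset = 16; offset > 0; offset >>= 1) {
            val += __shfl_xor_sync(0xffffffffu, val, offset);
        }
    }
    return val;
}
```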
Thank you for the high-effort post. I'm relatively inexperienced when it comes to low-level GPU programming and definitely appreciate it.
I tested a version without shared memory on my hardware. It was 3.5% faster. I'm not going to optimize block sizes right now because I would like to do that at the end.
@@ -255,6 +297,23 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStre
    dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
}

static void dequantize_mul_mat_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
This should probably be renamed to dequantize_mul_mat_vec_q4_0_cuda() or something similar? I didn't read the PR description very carefully at first, and spent some time scratching my head and wondering if there is a missing dimension for y in the kernel.
I agree that something along the lines of mul_mat_vec would be a better name; I just forgot to change it.
CUTLASS is more of a collection of building blocks for creating your own custom GEMM than a library, and I believe it should be flexible enough. I think that e.g. this converter is pretty similar in spirit to what we need: it transforms a block of 8 nibbles into a block of 8 half-precision floats. The details might be very different for our use case (e.g., our nibble format is different, and we probably should convert to single-precision floats or even to
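As a rough illustration of the kind of conversion step involved, here is a tiny device-side sketch that unpacks one packed byte into two floats, assuming the q4_0 convention of the time (values stored as nibble minus 8, scaled by a per-block float d); this is not CUTLASS code:

```cuda
// Unpack one byte holding two 4-bit quants into two dequantized floats.
__device__ void unpack_q4_0_pair(uint8_t packed, float d, float & v0, float & v1) {
    const int q0 = (packed & 0x0F) - 8;  // low nibble
    const int q1 = (packed >> 4)   - 8;  // high nibble
    v0 = q0 * d;
    v1 = q1 * d;
}
```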
It would be awesome to get rid of cuBLAS/rocBLAS, which are super heavyweight for deployment. But creating GEMM kernels from scratch is hard; all of the BLAS implementations have spent a long time tuning and optimizing the code for particular hardware. The benefit of custom kernels is that we can combine them with the dequantization. Aside from CUTLASS, there are CLBlast (OpenCL), MIOpenGEMM (OpenCL) and some others I can't remember now. A lot of them are also not a 100% fit for llama.cpp, like they are computing
#ifdef GGML_USE_CUBLAS
    for (int i = 0; i < std::min(gpu_layers, int(hparams.n_layer)); ++i) {
        auto & layer = model.layers[i];
        ggml_cuda_transform_tensor(layer.wq);
        ggml_cuda_transform_tensor(layer.wk);
        ggml_cuda_transform_tensor(layer.wv);
        ggml_cuda_transform_tensor(layer.wo);
        ggml_cuda_transform_tensor(layer.w1);
        ggml_cuda_transform_tensor(layer.w2);
        ggml_cuda_transform_tensor(layer.w3);
    }
#endif
Would it be possible to have this be a generic API from ggml instead of something tied to CUDA?
Maybe? The ggml_cuda_transform_tensor function transfers the data to VRAM, sets the data pointer to the data in VRAM, and sets the new backend field to GGML_BACKEND_CUDA. Then, when a matrix multiplication is done, the program checks backend to determine whether it should be GPU-accelerated. In the CUDA code the data pointer is used directly instead of copying from RAM. For this approach to work 1:1 with e.g. OpenCL there needs to be a unified address space for both the CPU and the GPU; otherwise ggml_tensor will need a dedicated GPU memory pointer.
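A minimal sketch of the dispatch idea just described (the backend field and GGML_BACKEND_CUDA value are from this PR; the helper name is hypothetical):

```cpp
// Decide at matrix-multiplication time whether the weight tensor already
// lives in VRAM; in that case its data pointer can be used directly and no
// host-to-device copy of the weights is needed.
static bool tensor_is_on_gpu(const struct ggml_tensor * src0) {
    return src0->backend == GGML_BACKEND_CUDA;
}
```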
I see your point, CL would additionally need an offset, maybe.
How about this: the data property of ggml_tensor is a void pointer. For regular tensors and the transformed tensors in this PR, this pointer simply points at the data. You could instead point it at whatever OpenCL has for GPU memory pointers and then dereference that pointer in turn. That would cost you one extra memory access per weight, but I think that overhead is negligible.
I've been working on this same problem this week, albeit from a different angle. I've been experimenting with compiling llama.cpp for lower-end devices with integrated GPUs, such as NVIDIA Tegra. On those devices the memory is physically shared between the CPU and the GPU, allowing us to use a unified address space and leverage zero-copy when doing GPU processing.
I see a lot of similarities between your approach and things that I've been experimenting with. One issue that I'm running into is that the ggml_cuda_h2d_tensor_2d function does a bit of transformation on the data in addition to copying it over to the GPU, so separating that functionality has made things a bit more difficult for me to accelerate with the zero-copy case.
This is the sort of thing that I'm attempting to do, but it's not fully operational yet:
if (g_cudaZeroCopy) {
c_X = (float *) ((char *) src0->data + i02*nb2 + i03*nb3);
c_Y = (float *) ((char *) src1->data + i02*nb2 + i03*nb3);
c_D = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
} else {
// copy data to device
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X, src0, i03, i02, cudaStream));
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
}
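For the integrated-GPU case, the zero-copy pointers could in principle come from CUDA's mapped host memory API; a hedged sketch of the allocation side (not part of this patch):

```cpp
#include <cuda_runtime.h>

// Allocate pinned host memory that the GPU can access directly (zero-copy).
// On devices with physically unified memory (e.g. Tegra) this avoids both a
// separate device allocation and the host-to-device memcpy. Error checking
// is omitted for brevity; cudaSetDeviceFlags(cudaDeviceMapHost) may need to
// be called before the CUDA context is created.
static float * alloc_zero_copy(size_t n_floats, float ** device_ptr) {
    float * host_ptr = nullptr;
    cudaHostAlloc((void **) &host_ptr, n_floats * sizeof(float), cudaHostAllocMapped);
    cudaHostGetDevicePointer((void **) device_ptr, host_ptr, 0);
    return host_ptr;
}
```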
I wrote ggml_cuda_h2d_tensor_2d; it allowed operating on some tensors that were previously skipped for cuBLAS because they were not contiguous, and using the 2D CUDA memcpy gave a small perf increase. Now I'm thinking it would be better to copy the whole tensor, or even keep it on device and use something like GemmStridedBatched on it (because this tensor has a third dimension as well).
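For reference, the strided-batched cuBLAS call mentioned above looks roughly like this for f32; the transposition flags, leading dimensions and strides below are placeholders and would have to be adapted to ggml's actual layout:

```cpp
#include <cublas_v2.h>

// Multiply `batch` independent column-major (m x k) * (k x n) pairs in one
// call; A, B and C hold all batches back to back with fixed strides, which
// maps well onto a tensor whose third dimension indexes the batch.
static cublasStatus_t batched_mat_mul_f32(cublasHandle_t handle,
                                          const float * d_A, const float * d_B, float * d_C,
                                          int m, int n, int k, int batch) {
    const float alpha = 1.0f;
    const float beta  = 0.0f;
    return cublasSgemmStridedBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                                     m, n, k,
                                     &alpha,
                                     d_A, m, (long long) m * k,
                                     d_B, k, (long long) k * n,
                                     &beta,
                                     d_C, m, (long long) m * n,
                                     batch);
}
```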
If tensors are kept on device, will that significantly hurt performance when the CPU operates on the graph?
This is the KV cache. It is not read by the CPU, but currently the CPU writes to it when a batch of embeddings is processed, once per layer.
EDIT: they are not read by the CPU assuming that the matrix multiplications they are used in are never done on the CPU, like in this patch, for example.
// important: storing RoPE-ed version of K in the KV cache!
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
If ggml_cpy() could see that v and k are device pointers, it could copy to the GPU instead. k and v are submatrices of the KV cache, so this is where the device-pointer property comes in useful: even if it cannot be accessed by the CPU, pointer arithmetic still works. When a new tensor view is created, ->data points to the first element.
With cl_mem that is not possible and you also need to save an offset, which most APIs usually accept.
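One way to express that for backends without a flat address space would be a small handle type that carries the offset alongside the buffer (hypothetical sketch, not ggml code):

```cpp
// A tensor view on such a backend is a base buffer plus a byte offset rather
// than a raw pointer, since pointer arithmetic on the handle is not allowed.
struct gpu_buffer_view {
    void * buffer;   // e.g. a cl_mem handle
    size_t offset;   // byte offset of the view's first element
};
```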
> Now I'm thinking it would be better to copy the whole tensor or even keep it on device and use something like GemmStridedBatched on it (because this tensor also has the third dimension as well)
If we did this, then it would certainly make the zero-copy case much easier, because we would no longer be transforming and copying in the same motion.
Wait, why is this not a 100% fit? The former is just the general matrix multiplication API in BLAS, which by design covers all possible cases. You can convert it to the latter by setting
Well, yes, it is general and with the right parameters it works for us. But can we be sure that it's optimized for the
I forgot to mention this, but none of the BLAS libraries allow for strides that don't match the element count.
It's not something you have to optimize for, since neither
Side note:
I'm not sure what you mean here. The GEMM API in BLAS allows you to use any stride you want by providing
Speedups are indeed very impressive. @slaren, I remember you were manually enabling tensor cores for cuBLAS, which led to a 5% speedup. I wonder if tensor core integration could further improve generation speed for Johannes' PR.
Tensor cores allow you to multiply small sub-matrices stored in registers in a single instruction, so they only help if you are bottlenecked by compute (the 5% speedup happened for prompt processing, which is compute-bound). It is unlikely that tensor cores will improve the speed of matrix-vector multiplication during generation, which is memory-bound (see the Background section in the description of this PR).
@@ -597,7 +656,10 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
    const size_t q_sz = ggml_type_size(type) * x_ne / ggml_blck_size(type);

    size_t x_size, y_size, d_size, q_size;
    float * d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
    float * d_X;
    if (ne11 > 1) {
Suggest abstracting a function to determine whether we should use cuBLAS to compute. ne11 > 1 is one of the criteria; we can do more tests to see if there are other cases.
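Something along these lines (hypothetical helper; the name and criteria are placeholders, not code from this PR):

```cpp
// Centralize the heuristic for choosing between the cuBLAS path (dequantize
// + SGEMM) and the custom dequantize-mul-mat-vec kernel. ne11 is the number
// of columns of src1, i.e. the batch size of the multiplication.
static bool ggml_cuda_should_use_cublas(int64_t ne11) {
    // ne11 == 1 means a matrix-vector product (token generation), where the
    // custom kernel wins; larger batches (prompt processing) go to cuBLAS.
    // Further criteria could be added here after more benchmarking.
    return ne11 > 1;
}
```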
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, tensor, 0, 0, cudaStream2));
CUDA_CHECK(cudaDeviceSynchronize());

tensor->data = d_Q;
Will we have a memory leak here? Do we need to deallocate the original buffer?
There is at least no more of a memory leak than on master, where the buffers are actually used. I plan to make a version that frees the memory in RAM (or rather doesn't allocate as much RAM in the first place), but my priority was getting a working proof of concept out.
@@ -271,6 +271,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
        params.use_color = true;
    } else if (arg == "--mlock") {
        params.use_mlock = true;
    } else if (arg == "--gpu_layers") {
Suggest --gpu-layers to align with the other options.
I don't understand what you mean.
He's suggesting you replace the underscore with a hyphen, I think.
Horenbergerb is right, sorry I didn't make it explicit; the other parameters all use hyphens.
They are not all using it consistently, but it is a different issue.
Testing on AMD, I had to patch it a little bit:
for (int mask = block_size/2; mask > 0; mask >>= 1) {
    partial_sum += __shfl_xor(partial_sum, mask, block_size);
}
The block size is important, however:
But it is not any faster than the previous version.
Anyone working on this?
I'm giving it a try myself right now. The results I get are still incorrect, but a straightforward implementation already seems to have worse performance than the current version. I think the problem is that the rationale for better performance on CPUs doesn't work out on GPUs. On CPUs you can do a bunch of cheap integer multiplications and then only do very few expensive floating point calculations per block. But on a GPU you run a bunch of threads in parallel anyway; there is no performance difference between multiplying once and multiplying 32 times within the same warp. So I think the CPU integer tricks just add more overhead. Of course, I'll gladly let myself be proven wrong if someone manages to write a fast integer multiplication kernel.
@ggerganov I managed to make a version that produces correct results, but for some reason the performance is absolutely terrible: literally 16 s per token with 7B.
But they do use i8 for ML in other implementations, for example bitsandbytes. For one thing, it could reduce the GPU memory use.
I think I figured out the problem with my version: I'm not copying the quantized data to VRAM.
I fixed the performance for the integer version. Runtime is 20% higher but I didn't try to do any optimizations yet. I'll go to sleep now, feel free to play around with it in the meantime (same branch
The smallest matrix-vector multiplication in 7B is between a 4096x4096 matrix and a vector of size 4096. The quantization requested by ggerganov only affects the vector and will make virtually no difference for memory usage. For larger matrices the benefit will be even smaller.
I have tried (and pushed) a version where each GPU thread sums up 4 integer products and then does a single float multiplication. The performance is worse than a version where each thread sums up only 2 integer products and then does a single float multiplication. I think the fundamental problem with the q8_0 approach on GPUs is that it hurts memory coalescing. On GPUs you want the threads in a warp to operate on data that is as close together as possible, because then the warp can fetch the same data for all threads. But if you have each thread work on more data values because you want more integer multiplications per float multiplication, then each warp has to access more memory and performance goes down. Again, if someone comes up with a way to do this efficiently I will happily admit I'm wrong, but I think it just doesn't work out for GPUs.
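To make the coalescing argument concrete, here is a stripped-down sketch of a q4_0 matrix-vector kernel in the spirit of this PR (one warp per row, consecutive threads reading consecutive blocks of the row). It assumes the legacy q4_0 layout (a float scale plus byte j packing elements 2j and 2j+1) and a block size of 32; it is not the PR's actual kernel:

```cuda
#define QK4_0 32

typedef struct {
    float   d;              // scale
    uint8_t qs[QK4_0 / 2];  // 32 quants, two 4-bit values per byte
} block_q4_0;

// dst[row] = dot(x[row, :], y); launched as <<<nrows, 32>>>
__global__ void mul_mat_vec_q4_0(const block_q4_0 * x, const float * y,
                                 float * dst, const int ncols) {
    const int row     = blockIdx.x;
    const int tid     = threadIdx.x;      // 0..31
    const int nblocks = ncols / QK4_0;    // q4_0 blocks per row

    float sum = 0.0f;

    // Consecutive threads read consecutive (20-byte) blocks of the same row,
    // so the warp's loads stay reasonably coalesced.
    for (int ib = tid; ib < nblocks; ib += 32) {
        const block_q4_0 b = x[row * nblocks + ib];
        const float d = b.d;
        for (int j = 0; j < QK4_0 / 2; ++j) {
            const int q0 = (b.qs[j] & 0x0F) - 8;   // element 2*j
            const int q1 = (b.qs[j] >>   4) - 8;   // element 2*j + 1
            sum += d * (q0 * y[ib * QK4_0 + 2*j]
                      + q1 * y[ib * QK4_0 + 2*j + 1]);
        }
    }

    // reduce the 32 partial sums within the warp
    for (int offset = 16; offset > 0; offset >>= 1) {
        sum += __shfl_down_sync(0xffffffffu, sum, offset);
    }
    if (tid == 0) {
        dst[row] = sum;
    }
}
```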
Here is another attempt at making integer-based
On my
So a speed-up of about
Here are a couple of logs with the new and old implementations:
branch: origin:dequantize-matmul-3-gg
branch: JohannesGaessler:dequantize-matmul-3
I'm interested if someone has a video card with more than 16 GB, to measure the speed for 30B and 65B using this approach.
On a GTX 1070 the q4_0_q8_0 performance is very bad:
I would like to get one working version of GPU acceleration merged into master that has a good software design to enable easy modifications; variants that optimize performance for specific GPUs can then be developed afterwards. I fear that if we try to optimize performance too much from the get-go this will be stuck in development hell, because performance has poor portability across GPUs.
Did you mean 4080?
@@ -255,6 +294,23 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStre
    dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
}

static void dequantize_mul_mat_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    // static int block_size = -1;
I would suggest removing the commented-out code.
Already done for the version that I intend to get merged: #1412
Closing this since #1412 got merged.
TLDR: this PR implements CUDA GPU acceleration for token generation. Even on my relatively slow GTX 1070 this is significantly faster than using the CPU:
To use it, provide the main binary with the --gpu_layers argument and specify how many layers should be GPU-accelerated. Only q4_0 is implemented.
Edit: build instructions (Linux):
For building on Windows, read the llama.cpp README. These build instructions are outdated. They will install the development version in this PR that, while still compatible with the old quantization format, is slower than the version on the master branch and only supports q4_0.
Background
The multiplication of two square matrices of size $N$ is a very compute-intensive operation: it requires $O(N^3)$ multiplications on $O(N^2)$ data values. However, if one of the matrices is thin in one dimension then the matrix multiplication becomes much more I/O-bound, because it then requires only $O(N^2)$ multiplications on $O(N^2)$ data values. When llama.cpp is generating new tokens it spends ~90% of its runtime doing matrix multiplications, and those matrix multiplications are exclusively matrix-vector multiplications that are maximally I/O-bound. As a consequence, the speed at which new tokens can be generated is essentially just proportional to memory bandwidth:
Notably the memory bandwidth on consumer GPUs is much higher than the memory bandwidth on consumer CPUs. I therefore looked into running at least part of llama.cpp on the GPU to speed up token generation.
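To make the I/O argument concrete, here is a rough arithmetic-intensity estimate for a single f32 matrix-vector product (illustrative numbers, not measurements from this PR): multiplying an $N \times N$ f32 weight matrix with a vector takes roughly $2N^2$ FLOPs while reading about $4N^2$ bytes of weights, i.e. an arithmetic intensity of only $\sim 0.5$ FLOP/byte. Quantization raises this somewhat, but it stays far below the FLOPs-per-byte-of-bandwidth ratio of any modern CPU or GPU, so the operation is limited by how fast the weights can be streamed from memory.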
Implementation
The GPU acceleration of token generation has two issues on master:
This PR adds a backend field to ggml_tensor that specifies where the data is stored. The data then needs to be transferred to the GPU only once. The vectors are still transferred to the GPU every time, but on my hardware this adds relatively little overhead (see below). An additional benefit of this approach is that it can potentially be used to reduce the memory footprint on the host, because the quantized matrices only need to be stored on the GPU (not implemented). My system with 32 GB RAM can start thrashing with 33b q5_1, so even my current 8 GB of VRAM would be a huge quality-of-life improvement.
Only the repeating layers of LLaMa are accelerated. The fixed layers at the beginning and end of the neural network are still CPU-only for token generation.
Results
On my hardware I found:
There is a significant speedup for all model sizes, though the speedup is considerably larger for the smaller models where a larger percentage fits into VRAM. The plot at the beginning of the PR shows the scaling as a function of the percentage of the model in VRAM. This scaling is essentially the same for all models, which suggests that a large amount of VRAM is key. For larger models the maximum potential speedup seems to be higher.
Profiling with nvprof suggests that copying vectors between host and device is not a bottleneck on my hardware:
Of course, I would be very interested to see the results that people with faster GPUs get with this PR; I personally plan to buy an RTX 3090 now that I've established that I'll be able to make use of it.