Proof of concept: GPU-accelerated token generation #1375

Closed


JohannesGaessler
Collaborator

@JohannesGaessler JohannesGaessler commented May 9, 2023

TLDR: this PR implements CUDA GPU acceleration for token generation. Even on my relatively slow GTX 1070 this is significantly faster than using the CPU:

[Figure gpu_speedup: token generation speedup as a function of the percentage of the model in VRAM]

To use it, provide the main binary with the --gpu_layers argument and specify how many layers should be GPU-accelerated. Only q4_0 is implemented.

Edit: build instructions (Linux):

git clone https://github.com/JohannesGaessler/llama.cpp llama.cpp-johannesgaessler
cd llama.cpp-johannesgaessler                               
git fetch
git switch dequantize-matmul-2
make LLAMA_CUBLAS=1

For building on Windows, read the llama.cpp README. These build instructions are outdated: they install the development version from this PR which, while still compatible with the old quantization format, is slower than the version on the master branch and only supports q4_0.

Background

The multiplication of two square matrices of size $N$ is a very compute-intensive operation: it requires $O(N^3)$ multiplications on $O(N^2)$ data values. However, if one of the matrices is thin in one dimension, then the matrix multiplication becomes much more I/O-bound because it now requires only $O(N^2)$ multiplications on $O(N^2)$ data values. When llama.cpp is generating new tokens it spends ~90% of its runtime doing matrix multiplications, and those matrix multiplications are exclusively matrix-vector multiplications that are maximally I/O-bound. As a consequence, the speed at which new tokens can be generated is essentially just proportional to memory bandwidth:

[Figure memory_scaling_1: generated tokens per second as a function of memory bandwidth]

Notably the memory bandwidth on consumer GPUs is much higher than the memory bandwidth on consumer CPUs. I therefore looked into running at least part of llama.cpp on the GPU to speed up token generation.
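As a rough back-of-the-envelope check (the bandwidth and model-size numbers here are approximate assumptions, not measurements from this PR): if token generation is purely memory-bound, its speed is roughly the rate at which the quantized weights can be streamed from memory,

$$
\text{tokens/s} \approx \frac{\text{memory bandwidth}}{\text{bytes read per token}} \approx \frac{\text{memory bandwidth}}{\text{size of the quantized weights}}.
$$

For a ~4 GB 7b q4_0 model, dual-channel DDR4-3200 at ~50 GB/s gives on the order of 10 tokens per second, while the ~256 GB/s of a GTX 1070 would in principle allow several times that, which is why moving the weight matrices into VRAM pays off.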

Implementation

The GPU acceleration of token generation has two issues on master:

  1. The dequantization of matrices on the GPU is too slow. I looked into faster dequantization kernels like [DRAFT] Speedup dequantize kernels #1221 but even then the combined dequantization and matrix multiplication on a GTX 1070 was 2x slower than on the CPU. I therefore implemented a new CUDA kernel that does dequantization and matrix vector multiplication simultaneously (a rough sketch of the idea is shown after this list).
  2. Transferring quantized matrices to the GPU adds significant latency (4x more than the runtime of the optimized dequantization and matrix multiplication). I therefore implemented storing the quantized matrices in VRAM rather than RAM on a tensor-by-tensor basis. I added a property backend to ggml_tensor that specifies where the data is stored. The data then needs to be transferred to the GPU only once. The vectors are still transferred to the GPU every time but on my hardware this adds relatively low overhead (see below). An additional benefit of this approach is that this can potentially be used to reduce the memory footprint on the host because the quantized matrices only need to be stored on the GPU (not implemented). My system with 32GB RAM can start thrashing with 33b q5_1 so even my current 8GB of VRAM would be a huge quality of life improvement.
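The following is a minimal sketch of the fused approach from point 1. It is simplified and partly hypothetical: it assumes a q4_0-like block of one float scale plus 16 bytes of packed nibbles per 32 values, a fixed block size of 32 threads per matrix row, and made-up names; it is not the exact kernel from this PR:

```cuda
#define QK 32  // values per quantization block (assumption for this sketch)

// simplified q4_0-style block: one scale plus 32 packed 4-bit values
struct block_q4_0_sketch {
    float   d;
    uint8_t qs[QK/2];
};

// dst[row] = dot(quantized row of x, y); launched with one block of 32 threads per row
static __global__ void dequantize_mul_mat_vec_sketch(
        const block_q4_0_sketch * x, const float * y, float * dst, const int ncols) {
    const int row = blockIdx.x;
    const int tid = threadIdx.x;

    __shared__ float tmp[32];
    tmp[tid] = 0.0f;

    // each thread strides over the columns of its row, so neighbouring threads
    // read neighbouring nibbles (coalesced accesses)
    for (int col = tid; col < ncols; col += 32) {
        const block_q4_0_sketch * b = &x[(row*ncols + col)/QK];
        const uint8_t q = b->qs[(col % QK)/2];
        const int     v = (col % 2 == 0) ? (q & 0xF) : (q >> 4);
        // dequantize and multiply in one step instead of writing floats back to memory
        tmp[tid] += (v - 8)*b->d*y[col];
    }

    // sum up partial sums and write back result
    for (int s = 16; s > 0; s >>= 1) {
        __syncthreads();
        if (tid < s) {
            tmp[tid] += tmp[tid + s];
        }
    }
    if (tid == 0) {
        dst[row] = tmp[0];
    }
}
```

The point of the fusion is that the dequantized floats never have to be written to and re-read from global memory; the reduction at the end mirrors the shared-memory pattern discussed in the review comments further down.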

Only the repeating layers of LLaMA are accelerated. The fixed layers at the beginning and end of the neural network are still CPU-only for token generation.

Results

On my hardware I found:

| Model | Num layers | Baseline speed [t/s] (3200 MHz RAM) | Max. accelerated layers (8 GB VRAM) | Max. speed [t/s] (GTX 1070) | Max. speedup (GTX 1070) |
| --- | --- | --- | --- | --- | --- |
| 7b q4_0 | 32 | 9.15 | 32 | 12.50 | 1.36 |
| 13b q4_0 | 40 | 4.86 | 34 | 6.42 | 1.32 |
| 33b q4_0 | 60 | 1.96 | 19 | 2.22 | 1.12 |

There is a significant speedup for all model sizes, though the speedup is considerably larger for the smaller models where a larger percentage fits into VRAM. The plot at the beginning of the PR shows the scaling as a function of the percentage of the model in VRAM; this scaling is essentially the same for all models, which suggests that the amount of VRAM is key. For larger models the maximum potential speedup therefore seems to be higher.

Profiling with nvprof suggests that copying vectors between host and device is not a bottleneck on my hardware:

            Type  Time(%)      Time     Calls       Avg       Min       Max  Name
 GPU activities:   89.16%  24.7485s     33915  729.72us  411.84us  2.2789ms  void dequantize_mul_mat_q4_0<int=32>(void const *, float const *, float*, int)
                    7.37%  2.04551s       133  15.380ms  8.5321ms  24.494ms  dequantize_block_q4_0(void const *, float*)
                    2.27%  629.46ms     34181  18.415us  2.9760us  5.8011ms  [CUDA memcpy HtoD]
                    0.74%  204.55ms       133  1.5380ms  818.69us  2.8980ms  void gemmSN_TN_kernel<float, int=128, int=16, int=2, int=4, int=2, int=2, bool=1, cublasGemvTensorStridedBatched<float const >, cublasGemvTensorStridedBatched<float const >, cublasGemvTensorStridedBatched<float>>(cublasGemmSmallNParams<float const , cublasGemvTensorStridedBatched<float const >, cublasGemvTensorStridedBatched<float const >, float>)
                    0.47%  129.91ms     34048  3.8150us  2.6880us  12.288us  [CUDA memcpy DtoH]
      API calls:   90.14%  27.5341s     34181  805.54us  72.691us  26.766ms  cudaDeviceSynchronize
                    4.49%  1.37286s         5  274.57ms  79.531ms  743.99ms  cudaMallocHost
                    2.27%  694.01ms     68229  10.171us  2.2800us  5.7958ms  cudaMemcpyAsync
                    1.57%  481.06ms         5  96.211ms  391.06us  275.79ms  cudaFreeHost
                    0.76%  232.09ms       953  243.53us  2.3100us  10.969ms  cuLibraryLoadData
                    0.41%  125.95ms     34181  3.6840us  2.8100us  32.480us  cudaLaunchKernel
                    0.11%  34.327ms       142  241.74us  2.6600us  4.2271ms  cudaMalloc
                    0.10%  29.636ms     34048     870ns     620ns  11.970us  cudaEventRecord
                    0.09%  26.234ms     34048     770ns     650ns  206.77us  cudaStreamWaitEvent
                    0.03%  8.4284ms         2  4.2142ms  3.8226ms  4.6058ms  cudaFree
                    0.02%  5.9012ms     34181     172ns     150ns  3.6300us  cudaGetLastError
                    0.00%  336.79us       336  1.0020us     130ns  43.360us  cuDeviceGetAttribute
                    0.00%  148.54us       766     193ns     130ns  1.1500us  cuGetProcAddress
                    0.00%  93.821us        16  5.8630us  1.4200us  57.391us  cudaStreamCreateWithFlags
                    0.00%  42.180us        82     514ns     400ns  1.9300us  cudaEventCreateWithFlags
                    0.00%  25.690us         3  8.5630us  7.3400us  10.850us  cuDeviceGetName
                    0.00%  5.7400us         3  1.9130us     460ns  4.7300us  cudaGetDevice
                    0.00%  5.2300us        14     373ns     270ns  1.2900us  cudaDeviceGetAttribute
                    0.00%  4.9400us         1  4.9400us  4.9400us  4.9400us  cuDeviceGetPCIBusId
                    0.00%  1.9400us         2     970ns     910ns  1.0300us  cuInit
                    0.00%  1.5500us         5     310ns     150ns     680ns  cuDeviceGetCount
                    0.00%     880ns         3     293ns     260ns     340ns  cuDeviceTotalMem
                    0.00%     840ns         4     210ns     130ns     420ns  cuDeviceGet
                    0.00%     590ns         3     196ns     170ns     240ns  cuModuleGetLoadingMode
                    0.00%     550ns         1     550ns     550ns     550ns  cudaGetSymbolAddress
                    0.00%     530ns         3     176ns     170ns     180ns  cuDeviceGetUuid
                    0.00%     310ns         2     155ns     150ns     160ns  cuDriverGetVersion

Of course, I would be very interested to see the results that people with faster GPUs get with this PR; I personally plan to buy an RTX 3090 now that I've asserted that I'll be able to make use of it.

@ggerganov
Owner

On GeForce RTX 4080 I hit about 37 tokens / sec with 13B and all layers on the GPU:

$  make -j && ./bin/main -m ../models/13B/ggml-model-q4_0.bin -p "I believe the meaning of life is" -c 2048 -n 512 --ignore-eos -s 4 -n 64 -t 8 --gpu_layers 40
[  3%] Built target BUILD_INFO
[  9%] Built target ggml
[ 15%] Built target llama
[ 21%] Built target test-sampling
[ 28%] Built target test-quantize-fns
[ 34%] Built target test-tokenizer-0
[ 40%] Built target quantize
[ 43%] Built target common
[ 50%] Built target test-quantize-perf
[ 56%] Built target quantize-stats
[ 62%] Built target main
[ 68%] Built target q8dot
[ 78%] Built target perplexity
[ 81%] Built target embedding
[ 90%] Built target vdot
[ 93%] Built target benchmark
[100%] Built target save-load-state
WARNING: when using cuBLAS generation results are NOT guaranteed to be reproducible.
main: build = 527 (3ed4588)
main: seed  = 4
llama.cpp: loading model from ../models/13B/ggml-model-q4_0.bin
llama_model_load_internal: format     = ggjt v1 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 2048
llama_model_load_internal: n_embd     = 5120
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 40
llama_model_load_internal: n_layer    = 40
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 13824
llama_model_load_internal: n_parts    = 1
llama_model_load_internal: model size = 13B
llama_model_load_internal: ggml ctx size =  90,75 KB
llama_model_load_internal: mem required  = 9807,48 MB (+ 1608,00 MB per state)
llama_init_from_file: kv self size  = 1600,00 MB

system_info: n_threads = 8 / 64 | AVX = 1 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | VSX = 0 | 
sampling: repeat_last_n = 64, repeat_penalty = 1,100000, presence_penalty = 0,000000, frequency_penalty = 0,000000, top_k = 40, tfs_z = 1,000000, top_p = 0,950000, typical_p = 1,000000, temp = 0,800000, mirostat = 0, mirostat_lr = 0,100000, mirostat_ent = 5,000000
generate: n_ctx = 2048, n_batch = 512, n_predict = 64, n_keep = 0


 I believe the meaning of life is simply to exist, but I also believe there is a power greater than myself that makes it worth existing.
I feel that this power is everywhere in every breath, every moment and every step we take.
When I was young, I didn’t understand why people fought so hard for their beliefs if they had
llama_print_timings:        load time =  2629,80 ms
llama_print_timings:      sample time =    26,83 ms /    64 runs   (    0,42 ms per run)
llama_print_timings: prompt eval time =   352,29 ms /     8 tokens (   44,04 ms per token)
llama_print_timings:        eval time =  1748,56 ms /    63 runs   (   27,75 ms per run)
llama_print_timings:       total time =  4415,35 ms

How hard would it be to make a PoC that performs integer dot products on the GPU?

I.e. quantize the src1 data, upload to the GPU, and perform integer multiplications of the quants as we do on the CPU.
Would be super useful to have a data point for this type of kernels.

@JohannesGaessler
Collaborator Author

How hard would it be to make a PoC that performs integer dot products on the GPU?

I don't think it would be too hard; there's some work for my master's thesis that will keep me busy until Thursday, but I'll look into it when I get the time.

@JohannesGaessler
Collaborator Author

llama_print_timings: prompt eval time =   352,29 ms /     8 tokens (   44,04 ms per token)
llama_print_timings:        eval time =  1748,56 ms /    63 runs   (   27,75 ms per run)

What prompt eval times do you get for longer prompts? If it turns out that my custom kernel is faster than dequantization + cuBLAS matrix multiplication we may be able to speed up prompt processing as well (my code is very much not optimized for matrix matrix multiplication though).

@slaren
Collaborator

slaren commented May 9, 2023

Very nice! I had experimented with keeping the weights in VRAM before, but I didn't expect that also implementing a quantized mat mul kernel would be fast enough to use it for generation. If we are open to more GPU changes in ggml, we could also implement the remaining operations in CUDA and avoid all the copies in the GPU layers altogether.

@slaren
Collaborator

slaren commented May 9, 2023

What prompt eval times do you get for longer prompts? If it turns out that my custom kernel is faster than dequantization + cuBLAS matrix multiplication we may be able to speed up prompt processing as well (my code is very much not optimized for matrix matrix multiplication though).

On my 3080 with --gpu_layers 0 it is roughly the same speed as before, so there is probably no reason to use cuBLAS anymore.

@JohannesGaessler
Collaborator Author

On my 3080 with --gpu_layers 0 it is roughly the same speed as before

To clarify: The new kernel that I implemented is not used for prompt processing, only for matrix vector multiplications. I was only thinking that if eval time is lower than prompt eval time then this would suggest that my kernel is faster than the current master dequantization + cuBLAS even though it would be very unoptimized for large matrices.

@slaren
Collaborator

slaren commented May 9, 2023

Then I am not sure how to test that. Evaluating the prompt is still much faster than generating (13B --gpu_layers 40):

llama_print_timings: prompt eval time =  5331.65 ms /   631 tokens (    8.45 ms per token)
llama_print_timings:        eval time = 37672.51 ms /   255 runs   (  147.74 ms per run)

@JohannesGaessler
Collaborator Author

Okay, thanks for the performance numbers; I was only intrigued by @ggerganov's numbers where eval time was lower than prompt eval time. Efficiently parallelizing the multiplication of large matrices is very hard so I think we should stick with cuBLAS.

@ggerganov
Owner

@slaren

If we are open to more GPU changes in ggml, we could also implement the remaining operations in CUDA and avoid all the copies in the GPU layers altogether.

I think there is a viable idea for offloading ggml compute graphs to the GPU via second-pass processing of the graphs after they are generated and exported in a certain format. However, it will take quite some time to get to a working prototype and see if that idea even makes sense. Probably we can do some CUDA integration in the meantime. As long as it is easy to detach it later, I think we can reconsider the "no-GPU in ggml" limitation.
Especially if we can demonstrate efficient 4-bit mat mul kernels.

@JohannesGaessler

Yes, it is unlikely that we can outperform cuBLAS for bigger matrices.
Here is a run with a 228 token prompt:

 $  make -j && ./bin/main -m ../models/13B/ggml-model-q4_0.bin -c 2048 -b 2048 --no-mmap --ignore-eos -s 4 -n 64 -t 8 -f ../prompts/reason-act.txt --gpu_layers 40
[  6%] Built target BUILD_INFO
[  9%] Built target ggml
[ 15%] Built target llama
[ 25%] Built target test-quantize-fns
[ 28%] Built target test-sampling
[ 34%] Built target quantize
[ 40%] Built target quantize-stats
[ 46%] Built target test-quantize-perf
[ 56%] Built target test-tokenizer-0
[ 56%] Built target common
[ 62%] Built target main
[ 75%] Built target embedding
[ 75%] Built target benchmark
[ 81%] Built target perplexity
[ 87%] Built target vdot
[ 93%] Built target save-load-state
[100%] Built target q8dot
WARNING: when using cuBLAS generation results are NOT guaranteed to be reproducible.
main: build = 527 (3ed4588)
main: seed  = 4
llama.cpp: loading model from ../models/13B/ggml-model-q4_0.bin
llama_model_load_internal: format     = ggjt v1 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 2048
llama_model_load_internal: n_embd     = 5120
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 40
llama_model_load_internal: n_layer    = 40
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 13824
llama_model_load_internal: n_parts    = 1
llama_model_load_internal: model size = 13B
llama_model_load_internal: ggml ctx size = 7945710,75 KB
llama_model_load_internal: mem required  = 9807,48 MB (+ 1608,00 MB per state)
....................................................................................................
llama_init_from_file: kv self size  = 1600,00 MB

system_info: n_threads = 8 / 64 | AVX = 1 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | VSX = 0 | 
sampling: repeat_last_n = 64, repeat_penalty = 1,100000, presence_penalty = 0,000000, frequency_penalty = 0,000000, top_k = 40, tfs_z = 1,000000, top_p = 0,950000, typical_p = 1,000000, temp = 0,800000, mirostat = 0, mirostat_lr = 0,100000, mirostat_ent = 5,000000
generate: n_ctx = 2048, n_batch = 512, n_predict = 64, n_keep = 0


 You run in a loop of Thought, Action, Observation.
At the end of the loop either Answer or restate your Thought and Action.
Use Thought to describe your thoughts about the question you have been asked.
Use Action to run one of these actions available to you:
- calculate[python math expression]
Observation will be the result of running those actions


Question: What is 4 * 7 / 3?
Thought: Do I need to use an action? Yes, I use calculate to do math
Action: calculate[4 * 7 / 3]
Observation: 9.3333333333
Thought: Do I need to use an action? No, have the result
Answer: The calculate tool says it is 9.3333333333
Question: What is capital of france?
Thought: Do I need to use an action? No, I know the answer
Answer: Paris is the capital of France
Question: Why do we need math?
Thought: We do not need math. It is important but not a must have
Answer: We can live without it if all things fall apart
Question: What is 2 + 3 * 5?
Thought: Do I need to use an action? Yes
llama_print_timings:        load time =  5128,21 ms
llama_print_timings:      sample time =    27,44 ms /    64 runs   (    0,43 ms per run)
llama_print_timings: prompt eval time =   875,02 ms /   228 tokens (    3,84 ms per token)
llama_print_timings:        eval time =  2276,40 ms /    63 runs   (   36,13 ms per run)
llama_print_timings:       total time =  7442,29 ms

@slaren
Collaborator

slaren commented May 9, 2023

Worth pointing out that I am on PCIe 3.0, so for me the memcpys are probably a significantly larger overhead than for people on PCIe 4.0. @dfyz also noted that NVIDIA has a library (cutlass) for implementing high performance matrix multiplications, though I don't know if it is flexible enough to use with quantized formats.

@JohannesGaessler
Collaborator Author

I should point out that a major bottleneck for the implementation was memory coalescing. In my first version I parallelized across multiple rows instead of multiple columns and that kernel was slower than the CPU. So I think that if we use a matrix multiplication library, we absolutely need a way to take the memory layout into account in ggml.
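To make the coalescing point concrete, here is a hedged sketch (plain float matrix, hypothetical kernel, not code from this PR) of the slow layout, where one thread handles one whole row:

```cuda
// One thread per row: in step i, thread t reads x[t*ncols + i], so the 32
// threads of a warp touch addresses that are ncols floats apart. The loads
// cannot be coalesced into wide transactions, which is the effect described above.
__global__ void mat_vec_row_per_thread(const float * x, const float * y,
                                       float * dst, const int nrows, const int ncols) {
    const int row = blockIdx.x*blockDim.x + threadIdx.x;
    if (row >= nrows) {
        return;
    }
    float sum = 0.0f;
    for (int i = 0; i < ncols; ++i) {
        sum += x[row*ncols + i]*y[i];
    }
    dst[row] = sum;
}
```

The fast layout instead assigns a group of threads to each row and lets thread t handle columns t, t + 32, t + 64, ..., so neighbouring threads read neighbouring addresses (as in the sketch in the PR description above).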

@JohannesGaessler
Collaborator Author

Anyways, regarding future GPU acceleration in ggml: I think this PR proves that we need specialized matrix vector multiplication kernels regardless of how they are invoked.

ggml-cuda.cu Outdated
Comment on lines 258 to 264
// sum up partial sums and write back result
for (int s=block_size/2; s>0; s>>=1) {
    if (tid < s) {
        tmp[tid] += tmp[tid + s];
    }
    __syncthreads();
}
Collaborator

You don't need shared memory to sum results over a single warp; take a look at warpReduceSum() here. If you want to use more warps in the future, you can do intra-warp sums first, then aggregate the partial sums between warps via shared memory (blockReduceSum() in the same file).

I can't say for sure that this will make a significant difference here, but I saw a lot of memory-bound kernels where getting rid of shared memory (or reducing the number of shared memory transactions) resulted in speedups.
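For reference, the shuffle-based reduction being suggested looks roughly like this (a generic sketch of the well-known pattern, not the code from the linked file):

```cuda
// sum a per-thread value across the 32 lanes of a warp without shared memory
static __device__ __forceinline__ float warp_reduce_sum(float x) {
    for (int offset = 16; offset > 0; offset >>= 1) {
        x += __shfl_down_sync(0xffffffff, x, offset);
    }
    return x; // lane 0 now holds the sum over all 32 lanes
}
```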

Collaborator Author

@JohannesGaessler JohannesGaessler May 9, 2023

I read an official NVIDIA PDF where they said the same thing regarding intra-warp synchronization but when I removed the __syncthreads() instructions I got incorrect results even with a block size of 32. In any case, the addition of the partial sums at the end is negligible for overall performance.

Collaborator

[...] when I removed the __syncthreads() instructions I got incorrect results even with a block size of 32

Is this code available somewhere? If you say that the addition overhead is negligible, then I guess it doesn't really matter in the big picture. I'm just curious if I can understand what went wrong.

Collaborator Author

@JohannesGaessler JohannesGaessler May 9, 2023

Sorry, I misremembered. I get incorrect results if I remove __syncthreads(). I get correct results again if I then also define tmp as volatile, but this gives me worse performance. And when I tried just removing the summation altogether as a test, the performance difference was negligible, so I just kept __syncthreads().

Collaborator

It looks like we're talking about different things. This is what I had in mind: no __syncthreads(), no __shared__, no volatile, no arrays.

On my A100, it does look like getting rid of shared memory doesn't improve anything on its own: the average time of running the kernel with ./main -b 512 -t 12 -n 256 -f prompts/dan.txt -m models/13B/ggml-model-q4_0.bin --no-mmap -s 123456 --gpu_layers 40 is 72 microseconds. However, this might be beneficial if you want to run more warps per block. For example:

  • this variant without shared memory runs in 64 microseconds on the same A100
  • this variant with shared memory runs in 68 microseconds on the same A100

Bottom line: this might not be very useful on its own, but reducing usage of shared memory is always a good idea in the long run.

Collaborator Author

Thank you for the high-effort post. I'm relatively inexperienced when it comes to low-level GPU programming and definitely appreciate it.

Collaborator Author

I tested a version without shared memory on my hardware. It was about 3.5% faster. I'm not going to optimize block sizes right now because I would like to do that at the end.

@@ -255,6 +297,23 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStre
dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
}

static void dequantize_mul_mat_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
Collaborator

This should probably be renamed to dequantize_mul_mat_vec_q4_0_cuda() or something similar? I didn't read the PR description very carefully at first, and spent some time scratching my head and wondering if there is a missing dimension for y in the kernel.

Collaborator Author

@JohannesGaessler JohannesGaessler May 9, 2023

I agree that something along the lines of mul_mat_vec would be a better name; I just forgot to change it.

@dfyz
Collaborator

dfyz commented May 9, 2023

@dfyz also noted that NVIDIA has a library (cutlass) for implementing high performance matrix multiplications, though I don't know if it is flexible enough to use it with quantized formats.

CUTLASS is more of a collection of building blocks for creating your own custom GEMM than a library, and I believe it should be flexible enough. I think that e.g. this converter is pretty similar in spirit to what we need: it transforms a block of 8 nibbles into a block of 8 half-precision floats. The details might be very different for our use case (e.g., our nibble format is different, and we probably should convert to single-precision floats or even to uint4b_t suitable for tensor core multiplication), but it illustrates that CUTLASS is extremely customizable when it comes to data formats.

@SlyEcho
Collaborator

SlyEcho commented May 9, 2023

It would be awesome to get rid of cuBLAS/rocBLAS, which are super heavyweight for deployment. But creating GEMM kernels from scratch is hard: all of the BLAS implementations have spent a long time tuning and optimizing their code for particular hardware. The benefit of custom kernels is that we can combine them with the dequantization.

Aside from cutlass, there are CLBlast on CUDA, MIOpenGEMM (OpenCL) and some others I can't remember now.

A lot of them are also not a 100% fit for llama.cpp, e.g. they compute $\mathbf{C} \gets \alpha\mathbf{A}^{(\top)}\mathbf{B}^{(\top)}+\beta\mathbf{C}$ when we need $\mathbf{C} \gets \mathbf{A}\mathbf{B}^{\top}$. f16 support is not guaranteed, especially on the CPU side, not to mention mixed precision or quantized formats.

Comment on lines +1019 to +1030
#ifdef GGML_USE_CUBLAS
    for (int i = 0; i < std::min(gpu_layers, int(hparams.n_layer)); ++i) {
        auto & layer = model.layers[i];
        ggml_cuda_transform_tensor(layer.wq);
        ggml_cuda_transform_tensor(layer.wk);
        ggml_cuda_transform_tensor(layer.wv);
        ggml_cuda_transform_tensor(layer.wo);
        ggml_cuda_transform_tensor(layer.w1);
        ggml_cuda_transform_tensor(layer.w2);
        ggml_cuda_transform_tensor(layer.w3);
    }
#endif
Collaborator

Would it be possible to have this be a generic API from ggml instead of something tied to CUDA?

Collaborator Author

Maybe? The ggml_cuda_transform_tensor function transfers the data to VRAM, sets the data pointer to the data in VRAM, and sets the new backend field to GGML_BACKEND_CUDA. Then, when a matrix multiplication is done the program checks for backend to determine whether it should be GPU accelerated. In the CUDA code the data pointer is directly used instead of copying from RAM. For this approach to work 1:1 with e.g. OpenCL there needs to be a unified address space for both the CPU and the GPU, otherwise ggml_tensor will need a dedicated GPU memory pointer.
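Schematically, the mechanism described above looks something like this (a sketch with simplified, partly hypothetical names; only the backend field and the idea of GGML_BACKEND_CUDA come from the PR):

```cuda
#include <cuda_runtime.h>
#include <stddef.h>

enum backend_sketch { BACKEND_CPU, BACKEND_CUDA };

struct tensor_sketch {
    enum backend_sketch backend; // where does data point to?
    void * data;                 // host pointer or device pointer
    size_t nbytes;
};

// at load time: move the quantized weights to VRAM once and remember it
void transform_tensor_sketch(struct tensor_sketch * t) {
    void * dev = NULL;
    cudaMalloc(&dev, t->nbytes);
    cudaMemcpy(dev, t->data, t->nbytes, cudaMemcpyHostToDevice);
    t->data    = dev;            // from now on this is a device pointer
    t->backend = BACKEND_CUDA;
}

// at compute time: choose the code path based on where the weights live
void mul_mat_sketch(const struct tensor_sketch * w, const struct tensor_sketch * x) {
    if (w->backend == BACKEND_CUDA) {
        // use w->data directly as a device pointer; only the vector x is copied over
    } else {
        // regular CPU path
    }
}
```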

Collaborator

I see your point, CL would additionally need an offset, maybe.

Collaborator Author

How about this: the data property of ggml_tensor is a void pointer. For regular tensors and the transformed tensors in this PR this pointer simply points at the data. You could instead point it at whatever OpenCL has for GPU memory pointers and then dereference that pointer in turn. That would cost you one memory access per weight but I think that overhead is negligible.

Collaborator

@HanClinto HanClinto May 10, 2023

I've been working on this same problem this week, albeit from a different angle. I've been experimenting with compiling llama.cpp for lower-end devices with integrated GPUs, such as Nvidia Tegra. On those devices, the memory is physically shared between the CPU and the GPU, allowing us to use a unified address space and leverage zero-copy when doing GPU processing.

I see a lot of similarities between your approach and things that I've been experimenting with. One issue that I'm running into is that the ggml_cuda_h2d_tensor_2d function does a bit of transformation on the data in addition to copying it over to the GPU, so separating that functionality has made things a bit more difficult for me to accelerate with the zero-copy case.

This is the sort of thing that I'm attempting to do, but it's not fully operational yet:

            if (g_cudaZeroCopy) {
                c_X = (float *) ((char *) src0->data + i02*nb2 + i03*nb3);
                c_Y = (float *) ((char *) src1->data + i02*nb2 + i03*nb3);
                c_D = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
            } else {
                // copy data to device
                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X, src0, i03, i02, cudaStream));
                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
            }
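For the shared-memory/zero-copy case, the mapping itself can be done with standard CUDA runtime calls; a minimal standalone sketch (the function name is made up and this is not code from the branch):

```cuda
#include <cuda_runtime.h>
#include <stddef.h>

// Map an existing host buffer into the GPU address space. On devices with
// physically shared memory (e.g. Tegra) kernels can then read the buffer
// directly, without a separate cudaMemcpy.
static float * map_for_zero_copy(float * host_ptr, size_t nbytes) {
    float * dev_ptr = NULL;
    cudaHostRegister(host_ptr, nbytes, cudaHostRegisterMapped);
    cudaHostGetDevicePointer((void **) &dev_ptr, host_ptr, 0);
    return dev_ptr; // pass this to kernels instead of a copied buffer
}
```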

Collaborator

I wrote ggml_cuda_h2d_tensor_2d; it allowed operating on some tensors that were previously skipped for cuBLAS because they were not contiguous, and by using the 2D CUDA memcpy there was a small perf increase. Now I'm thinking it would be better to copy the whole tensor or even keep it on device and use something like GemmStridedBatched on it (because this tensor has a third dimension as well).

Collaborator Author

If tensors are kept on device, will that significantly hurt performance when the CPU operates on the graph?

Collaborator

@SlyEcho SlyEcho May 10, 2023

This is the KV cache; it is not read by the CPU, but currently the CPU writes to it when a batch of embeddings is processed, once per layer.

EDIT: they are not read by the CPU assuming that the matrix multiplications that they are used in are never done on the CPU, like in this patch, for example.

Collaborator

// important: storing RoPE-ed version of K in the KV cache!
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));

If ggml_cpy() could see that v and k are device pointers, it could copy to the GPU instead. k and v are submatrices of the KV cache, so this is where the device-pointer property comes in useful: even if it cannot be accessed by the CPU, pointer arithmetic still works. When a new tensor view is created, its ->data points to the first element.

With cl_mem that is not possible and you also need to save an offset, which most APIs usually accept.

Collaborator

Now I'm thinking it would be better to copy the whole tensor or even keep it on device and use something like GemmStridedBatched on it (because this tensor also has the third dimension as well)

If we did this, then it would certainly make the zero-copy case much easier, because we would no longer be transforming and copying in the same motion.

@JohannesGaessler JohannesGaessler changed the title from "Proof of oncept: GPU-accelerated token generation" to "Proof of concept: GPU-accelerated token generation" on May 9, 2023
@dfyz
Collaborator

dfyz commented May 9, 2023

A lot of them are also not 100% fit for llama.cpp also, like they are computing $\mathbf{C} \gets \alpha\mathbf{A}^{(\top)}\mathbf{B}^{(\top)}+\beta\mathbf{C}$ when we need $\mathbf{C} \gets \mathbf{A}\mathbf{B}^{\top}$

Wait, why is this not a 100% fit? The former is just the general matrix multiplication API in BLAS, which by design covers all possible cases. You can convert it to the latter by setting $\alpha=1, \beta=0$, and specifying that $\mathbf{B}$ is transposed.

@SlyEcho
Collaborator

SlyEcho commented May 9, 2023

Wait, why is this not a 100% fit? The former is just the general matrix multiplication API in BLAS

Well, yes it is general and with the right parameters it works for us. But can we be sure that it's optimized for the $\alpha=1\quad\beta=0$ case?

I forgot to mention this but none of the BLAS libraries allow for any strides that don't match element count.

@dfyz
Collaborator

dfyz commented May 9, 2023

But can we be sure that it's optimized for the $\alpha=1\quad\beta=0$ case?

It's not something you have to optimize for, since neither $\alpha$ nor $\beta$ are on the critical path of GEMM. See for example, the main loop of the "indirect GEMM kernel" in CLBLast. Both scalars are only used when storing the computed results of the thread block. This is a tiny fraction of the overall GEMM time.

Side note: $\beta=0$ is typically special-cased, but not for performance reasons. When $\beta=0$, you need special care to handle NaN output elements – see, for example, this CLBlast commit.

I forgot to mention this but none of the BLAS libraries allow for any strides that don't match element count.

I'm not sure what you mean here. The GEMM API in BLAS allows you to use any stride you want by providing ld{a,b,c} parameters. For example, in CLBlast they are called a_ld, b_ld, and c_ld.

@Dampfinchen

Speedups are indeed very impressive.

@slaren I remember you were manually enabling tensor cores for cuBLAS, which led to a 5% speedup. I wonder if tensor core integration could further improve generation speed for Johannes' PR.

@dfyz
Collaborator

dfyz commented May 10, 2023

I remember you were manually enabling tensor cores for cuBLAS, which led to a 5% speedup. I wonder if tensor core integration could further improve generation speed for Johannes' PR.

Tensor cores allow you to multiply small sub-matrices stored in registers in a single instruction, so they only help if you are bottlenecked by compute (the 5% speedup happened for prompt processing, which is compute-bound). It is unlikely that tensor cores will improve the speed of matrix-vector multiplication during generation, which is memory-bound (see the Background section in the description of this PR).
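A back-of-the-envelope way to see this (rough numbers, not from the thread): a matrix-vector product does about two floating-point operations (one multiply, one add) per weight loaded from memory, so its arithmetic intensity is only

$$
\frac{\text{FLOPs}}{\text{bytes}} \approx \frac{2N^2}{b\,N^2} = \frac{2}{b},
$$

where $b$ is the number of bytes per (quantized) weight. That is a few FLOP per byte, far below the compute-to-bandwidth ratio at which tensor cores would start to matter.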

examples/common.cpp (outdated review thread, resolved)
@@ -597,7 +656,10 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
const size_t q_sz = ggml_type_size(type) * x_ne / ggml_blck_size(type);

size_t x_size, y_size, d_size, q_size;
float * d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
float * d_X;
if (ne11 > 1) {
Collaborator

Suggest abstracting a function to determine whether we should use cuBLAS to compute. ne11 > 1 is one of the criteria; we can do more tests to see if there are other cases.

CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, tensor, 0, 0, cudaStream2));
CUDA_CHECK(cudaDeviceSynchronize());

tensor->data = d_Q;
Collaborator

Will we have a memory leak here? Do we need to deallocate the original buffer?

Collaborator Author

There is at least not more of a memory leak than on master, where the buffers are actually used. I plan to make a version that frees the memory in RAM (or rather doesn't allocate as much RAM in the first place) but my priority was getting a working proof of concept out.

@@ -271,6 +271,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.use_color = true;
} else if (arg == "--mlock") {
params.use_mlock = true;
} else if (arg == "--gpu_layers") {
Collaborator

Suggest gpu-layers to align with the other options.

Collaborator Author

I don't understand what you mean.

@Horenbergerb

He's suggesting you replace the underscore with a hyphen, I think.

Collaborator

Horenbergerb is right, sorry I didn't make it explicit; the other parameters are all using hyphens.

Collaborator

They are not all using it consistently, but it is a different issue.

@SlyEcho
Collaborator

SlyEcho commented May 11, 2023

Testing on AMD, I had to patch a little bit because __shfl_xor_sync() is not available 😞

    for (int mask=block_size/2; mask > 0; mask >>= 1) {
        partial_sum += __shfl_xor(partial_sum, mask, block_size);
    }

The block size is important, however:

./bin/main -m ../models/llama-7b-q4_0.bin -p "I believe the meaning of life is" -c 2048 -n 512 --ignore-eos -s 4 -n 64 -t 4 --gpu_layers 32

# block size 32:
llama_print_timings:        eval time =  3743.04 ms /    63 runs   (   59.41 ms per run)

# block size 64:
llama_print_timings:        eval time =  2867.05 ms /    63 runs   (   45.51 ms per run)

It is not any faster than the previous version, however.
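For reference, the patch can be a thin compatibility shim along these lines (the guard macro and mapping are assumptions, not code from this PR):

```cuda
// HIP/ROCm has no __shfl_xor_sync(); map it onto the older __shfl_xor()
// intrinsic, which takes no mask argument.
#if defined(__HIP_PLATFORM_AMD__)
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor((var), (laneMask), (width))
#endif
```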

@ggerganov
Owner

How hard would it be to make a PoC that performs integer dot products on the GPU?

I.e. quantize the src1 data, upload to the GPU, and perform integer multiplications of the quants as we do on the CPU. Would be super useful to have a data point for this type of kernels.

Anyone working on this?
I'm thinking about giving it a try - really curious if this will bring extra speed.

@JohannesGaessler
Collaborator Author

JohannesGaessler commented May 11, 2023

Anyone working on this?
I'm thinking about giving it a try - really curious if this will bring extra speed.

I'm giving it a try myself right now. The results I get are still incorrect, but a straightforward implementation seems to have worse performance than the current version. I think the problem is that the rationale for better performance on CPUs doesn't work out on GPUs. On CPUs you can do a bunch of cheap integer multiplications and then only do very few expensive floating point calculations per block. But on a GPU you run a bunch of threads in parallel anyways; there is no performance difference between multiplying once and multiplying 32 times within the same warp. So I think the CPU integer tricks just add more overhead. Of course I'll gladly let myself be proven wrong if someone manages to somehow write a fast integer multiplication kernel.

@JohannesGaessler
Collaborator Author

@ggerganov I managed to get a version that produces correct results, but for some reason the performance is absolutely terrible: literally 16 s per token with 7b.

@SlyEcho
Collaborator

SlyEcho commented May 11, 2023

But they do use i8 for ML in other implementations, for example bitsandbytes.

For one thing, it could reduce the GPU memory use.

@JohannesGaessler
Collaborator Author

I think I figured out the problem with my version: I'm not copying the quantized data to VRAM.

@ggerganov ggerganov added the performance Speed related topics label May 11, 2023
@JohannesGaessler
Collaborator Author

I fixed the performance for the integer version. Runtime is 20% higher but I didn't try to do any optimizations yet. I'll go to sleep now, feel free to play around with it in the meantime (same branch dequantize-matmul-3 as before).

@JohannesGaessler
Collaborator Author

JohannesGaessler commented May 11, 2023

For one thing, it could reduce the GPU memory use.

The smallest matrix vector multiplication in 7b is between a 4096x4096 matrix and a vector of size 4096. The quantization version requested by ggerganov only affects the vector and will make virtually no difference for memory usage. For larger matrices the benefit will be even smaller.

@JohannesGaessler
Collaborator Author

I have tried (and pushed) a version where each GPU thread sums up 4 integer products and then does a single float multiplication. The performance is worse than a version where each thread sums up only 2 integer products and then does a single float multiplication. I think the fundamental problem with the q8_0 approach on GPUs is that it hurts memory coalescing. On GPUs you want the threads in a warp to operate on data that is as close together as possible because then the warp can fetch the same data for all threads. But if you have each thread work on more data values because you want more integer multiplications per float multiplication, then each warp has to access more memory and performance goes down.

Again, if someone comes up with a way to do this efficiently I will happily admit I'm wrong but I think it just doesn't work out for GPUs.
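For concreteness, the per-thread work being compared looks roughly like this (a hedged sketch with made-up block layouts: one float scale per 32 values on each side, 4-bit weights, 8-bit activations; not the pushed code). The more integer products a thread accumulates before its float multiplication, the more bytes that single thread has to touch, which is what hurts coalescing:

```cuda
// hypothetical block layouts for this sketch only
struct block_q4_sketch { float d; uint8_t qs[16]; }; // 32 x 4-bit weights
struct block_q8_sketch { float d; int8_t  qs[32]; }; // 32 x 8-bit activations

// one thread handles `per_thread` consecutive weight/activation pairs of a block:
// it sums the integer products first and only then touches floating point
static __device__ float partial_dot(const block_q4_sketch * w, const block_q8_sketch * a,
                                    const int first, const int per_thread) {
    int isum = 0;
    for (int j = first; j < first + per_thread; ++j) {
        const uint8_t q  = w->qs[j/2];
        const int     wi = (j % 2 == 0) ? (q & 0xF) - 8 : (q >> 4) - 8;
        isum += wi*a->qs[j];           // cheap integer multiply-add
    }
    return isum*(w->d*a->d);           // the only float work: scale by both block scales
}
```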

@ggerganov
Owner

ggerganov commented May 12, 2023

@JohannesGaessler

Here is another attempt at making integer-based Q4 - Q8 dot product: a3e6d62

On my ~~4090~~ 4080, I get the following timings:

| branch | 7B (ms / token) | 7B (token / sec) | 13B (ms / token) | 13B (token / sec) |
| --- | --- | --- | --- | --- |
| new | 16.0 | 63 | 25.0 | 40 |
| old | 22.0 | 45 | 36.0 | 28 |

So a speed-up of about x1.4 compared to the dequantize-matmul-3 branch.

Here a couple of logs with the new and old implementations:

branch: origin:dequantize-matmul-3-gg
$  ./bin/main -m ../models/7B/ggml-model-q4_0.bin -p "I believe the meaning of life is" -c 2048 -n 512 --ignore-eos -s 6 -n 64 -t 8 --gpu_layers 40
WARNING: when using cuBLAS generation results are NOT guaranteed to be reproducible.
main: build = 532 (8809d45)
main: seed  = 6
llama.cpp: loading model from ../models/7B/ggml-model-q4_0.bin
llama_model_load_internal: format     = ggjt v1 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 2048
llama_model_load_internal: n_embd     = 4096
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 32
llama_model_load_internal: n_layer    = 32
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 11008
llama_model_load_internal: n_parts    = 1
llama_model_load_internal: model size = 7B
llama_model_load_internal: ggml ctx size =  72,75 KB
llama_model_load_internal: mem required  = 5809,34 MB (+ 1026,00 MB per state)
llama_init_from_file: kv self size  = 1024,00 MB

system_info: n_threads = 8 / 64 | AVX = 1 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | VSX = 0 | 
sampling: repeat_last_n = 64, repeat_penalty = 1,100000, presence_penalty = 0,000000, frequency_penalty = 0,000000, top_k = 40, tfs_z = 1,000000, top_p = 0,950000, typical_p = 1,000000, temp = 0,800000, mirostat = 0, mirostat_lr = 0,100000, mirostat_ent = 5,000000
generate: n_ctx = 2048, n_batch = 512, n_predict = 64, n_keep = 0


 I believe the meaning of life is to find your purpose.
The reason you're here on this earth, whether it be as a doctor or an artist or whatever, is to fulfill your potential. If you're just hanging around in a boring job because you need the money and you want to retire at 65 with
llama_print_timings:        load time =  1886,78 ms
llama_print_timings:      sample time =    27,10 ms /    64 runs   (    0,42 ms per run)
llama_print_timings: prompt eval time =   178,20 ms /     8 tokens (   22,28 ms per token)
llama_print_timings:        eval time =  1001,04 ms /    63 runs   (   15,89 ms per run)
llama_print_timings:       total time =  2925,16 ms
 ggerganov  tdcu-5975  SSH  ~/development/github/llama.cpp/build-cublas 
 17:01:51  dequantize-matmul-3-gg  ✎  $  ./bin/main -m ../models/13B/ggml-model-q4_0.bin -p "I believe the meaning of life is" -c 2048 -n 512 --ignore-eos -s 6 -n 64 -t 8 --gpu_layers 40
WARNING: when using cuBLAS generation results are NOT guaranteed to be reproducible.
main: build = 532 (8809d45)
main: seed  = 6
llama.cpp: loading model from ../models/13B/ggml-model-q4_0.bin
llama_model_load_internal: format     = ggjt v1 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 2048
llama_model_load_internal: n_embd     = 5120
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 40
llama_model_load_internal: n_layer    = 40
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 13824
llama_model_load_internal: n_parts    = 1
llama_model_load_internal: model size = 13B
llama_model_load_internal: ggml ctx size =  90,75 KB
llama_model_load_internal: mem required  = 9807,48 MB (+ 1608,00 MB per state)
llama_init_from_file: kv self size  = 1600,00 MB

system_info: n_threads = 8 / 64 | AVX = 1 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | VSX = 0 | 
sampling: repeat_last_n = 64, repeat_penalty = 1,100000, presence_penalty = 0,000000, frequency_penalty = 0,000000, top_k = 40, tfs_z = 1,000000, top_p = 0,950000, typical_p = 1,000000, temp = 0,800000, mirostat = 0, mirostat_lr = 0,100000, mirostat_ent = 5,000000
generate: n_ctx = 2048, n_batch = 512, n_predict = 64, n_keep = 0


 I believe the meaning of life is happiness.
Happiness at work, in love, in your family and for everyone around you. The world will be a better place if we all are happy and do not judge people because of their different opinions.
I also believe in hard work.
This blog is about how I am trying to find happiness
llama_print_timings:        load time =  3086,30 ms
llama_print_timings:      sample time =    27,13 ms /    64 runs   (    0,42 ms per run)
llama_print_timings: prompt eval time =   352,41 ms /     8 tokens (   44,05 ms per token)
llama_print_timings:        eval time =  1573,32 ms /    63 runs   (   24,97 ms per run)
llama_print_timings:       total time =  4697,03 ms
branch: JohannesGaessler:dequantize-matmul-3
$  make -j && ./bin/main -m ../models/7B/ggml-model-q4_0.bin -p "I believe the meaning of life is" -c 2048 -n 512 --ignore-eos -s 6 -n 64 -t 8 --gpu_layers 40
[  3%] Built target BUILD_INFO
[  9%] Built target ggml
[ 15%] Built target llama
[ 28%] Built target test-tokenizer-0
[ 28%] Built target test-quantize-fns
[ 34%] Built target common
[ 37%] Built target test-sampling
[ 43%] Built target test-quantize-perf
[ 53%] Built target quantize-stats
[ 56%] Built target quantize
[ 65%] Built target main
[ 68%] Built target perplexity
[ 75%] Built target embedding
[ 96%] Built target q8dot
[ 96%] Built target vdot
[100%] Built target save-load-state
[100%] Built target benchmark
WARNING: when using cuBLAS generation results are NOT guaranteed to be reproducible.
main: build = 531 (e7b9d97)
main: seed  = 6
llama.cpp: loading model from ../models/7B/ggml-model-q4_0.bin
llama_model_load_internal: format     = ggjt v1 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 2048
llama_model_load_internal: n_embd     = 4096
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 32
llama_model_load_internal: n_layer    = 32
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 11008
llama_model_load_internal: n_parts    = 1
llama_model_load_internal: model size = 7B
llama_model_load_internal: ggml ctx size =  72,75 KB
llama_model_load_internal: mem required  = 5809,34 MB (+ 1026,00 MB per state)
llama_init_from_file: kv self size  = 1024,00 MB

system_info: n_threads = 8 / 64 | AVX = 1 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | VSX = 0 | 
sampling: repeat_last_n = 64, repeat_penalty = 1,100000, presence_penalty = 0,000000, frequency_penalty = 0,000000, top_k = 40, tfs_z = 1,000000, top_p = 0,950000, typical_p = 1,000000, temp = 0,800000, mirostat = 0, mirostat_lr = 0,100000, mirostat_ent = 5,000000
generate: n_ctx = 2048, n_batch = 512, n_predict = 64, n_keep = 0


 I believe the meaning of life is to find your purpose.
The world needs you and your unique talents. But, in order to find it, you need to understand yourself first. It’s a lifelong process but, in order to start living on purpose, you will need to get clear on who you are and what makes you tick
llama_print_timings:        load time =  1863,65 ms
llama_print_timings:      sample time =    26,88 ms /    64 runs   (    0,42 ms per run)
llama_print_timings: prompt eval time =   178,71 ms /     8 tokens (   22,34 ms per token)
llama_print_timings:        eval time =  1395,76 ms /    63 runs   (   22,15 ms per run)
llama_print_timings:       total time =  3296,49 ms
 ggerganov  tdcu-5975  SSH  ~/development/github/llama.cpp/build-cublas 
 17:03:34  ⚓ master-41654ef-7-ge7b9d97  $  make -j && ./bin/main -m ../models/13B/ggml-model-q4_0.bin -p "I believe the meaning of life is" -c 2048 -n 512 --ignore-eos -s 6 -n 64 -t 8 --gpu_layers 40
[  3%] Built target BUILD_INFO
[  9%] Built target ggml
[ 15%] Built target llama
[ 21%] Built target test-quantize-fns
[ 28%] Built target test-sampling
[ 37%] Built target quantize-stats
[ 40%] Built target test-tokenizer-0
[ 43%] Built target common
[ 50%] Built target quantize
[ 56%] Built target test-quantize-perf
[ 62%] Built target perplexity
[ 78%] Built target save-load-state
[ 78%] Built target vdot
[ 81%] Built target main
[ 87%] Built target benchmark
[ 93%] Built target q8dot
[100%] Built target embedding
WARNING: when using cuBLAS generation results are NOT guaranteed to be reproducible.
main: build = 531 (e7b9d97)
main: seed  = 6
llama.cpp: loading model from ../models/13B/ggml-model-q4_0.bin
llama_model_load_internal: format     = ggjt v1 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 2048
llama_model_load_internal: n_embd     = 5120
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 40
llama_model_load_internal: n_layer    = 40
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 13824
llama_model_load_internal: n_parts    = 1
llama_model_load_internal: model size = 13B
llama_model_load_internal: ggml ctx size =  90,75 KB
llama_model_load_internal: mem required  = 9807,48 MB (+ 1608,00 MB per state)
llama_init_from_file: kv self size  = 1600,00 MB

system_info: n_threads = 8 / 64 | AVX = 1 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | VSX = 0 | 
sampling: repeat_last_n = 64, repeat_penalty = 1,100000, presence_penalty = 0,000000, frequency_penalty = 0,000000, top_k = 40, tfs_z = 1,000000, top_p = 0,950000, typical_p = 1,000000, temp = 0,800000, mirostat = 0, mirostat_lr = 0,100000, mirostat_ent = 5,000000
generate: n_ctx = 2048, n_batch = 512, n_predict = 64, n_keep = 0


 I believe the meaning of life is happiness.
Happiness. It's what we all strive for, isn't it? We think, if only this happened, or that changed, then I would be happy. The problem with thinking in such a manner is that happiness is an emotion, which is fleeting and temporary at best
llama_print_timings:        load time =  2467,16 ms
llama_print_timings:      sample time =    26,68 ms /    64 runs   (    0,42 ms per run)
llama_print_timings: prompt eval time =   353,54 ms /     8 tokens (   44,19 ms per token)
llama_print_timings:        eval time =  2304,62 ms /    63 runs   (   36,58 ms per run)
llama_print_timings:       total time =  4808,65 ms

I'm interested if someone has a video card with more than 16GB, to measure the speed for 30B and 65B using this approach.

@JohannesGaessler
Collaborator Author

JohannesGaessler commented May 12, 2023

On a GTX 1070 the q4_0_q8_0 performance is very bad:

| branch | 7b t/s |
| --- | --- |
| master (CPU) | 9.03 |
| dequantize-matmul-2 | 12.91 |
| dequantize-matmul-3 | 10.29 |
| dequantize-matmul-3-gg | 4.38 |
| dequantize-matmul-4 | 14.57 |

I would like to get one working version of GPU acceleration merged onto master that has good software design to enable easy modifications; variants to optimize performance for specific GPUs can then be developed afterwards. I fear that if we try to optimize performance too much from the get-go this will be stuck in development hell because performance has poor portability across GPUs.

@JohannesGaessler
Collaborator Author

On my 4090, I get the following timings:

I'm interested if someone has a video card with more than 16GB, to measure the speed for 30B and 65B using this approach.

Did you mean 4080?

@@ -255,6 +294,23 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStre
dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
}

static void dequantize_mul_mat_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
// static int block_size = -1;

I would suggest removing the commented-out code.

Collaborator Author

Already done for the version that I intend to get merged: #1412

@JohannesGaessler
Collaborator Author

JohannesGaessler commented May 13, 2023

Closing this since #1412 got merged.
