llama : compute BERT graph with F16 K, V #5891

Open

ggerganov wants to merge 1 commit into master from gg/bert-f16

Conversation

ggerganov (Owner)

Before #5796, we were using an F16 KV cache. Cast K and V to F16 to match that performance.
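
To make the idea concrete, here is a minimal sketch of the kind of change this implies (my illustration, not the actual diff of this PR; the function and tensor names are placeholders): K (and likewise V) is cast to F16 with ggml_cast before the attention matmul, so the KQ product runs with an F16 operand, as it effectively did with the old F16 KV cache.

```cpp
// Minimal sketch, not the PR's actual diff: cast K to F16 before the attention
// matmul. Assumes ggml_cast is available and the backend supports F16 matmul.
#include "ggml.h"

static struct ggml_tensor * attn_scores_f16(
        struct ggml_context * ctx,
        struct ggml_tensor  * q,    // query tensor, F32
        struct ggml_tensor  * k) {  // key tensor,   F32
    // casting K to F16 makes the K*Q matmul run with an F16 operand,
    // matching the behavior of the old F16 KV cache
    struct ggml_tensor * k_f16 = ggml_cast(ctx, k, GGML_TYPE_F16);
    return ggml_mul_mat(ctx, k_f16, q); // KQ scores (output is F32)
}
```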

ggerganov mentioned this pull request on Mar 5, 2024

slaren (Collaborator) commented on Mar 5, 2024

I noticed that the first evaluation with BERT is a lot faster than the rest (about 2x), hence the huge variance. Is this expected? Maybe there is some cleanup necessary between runs that llama-bench is not doing?

```
$ ./llama-bench -m models/nomic-embed/ggml-model-f16.gguf -p 8192 -n 0 -o json -r 20
...
    "samples_ns": [ 55305409, 106761860, 103784178, 103219933, 102769295, 106150938, 103377148, 103227991, 103466474, 103412610, 103204537, 103091448, 102517562, 102943769, 102898136, 102444703, 103308155, 103430029, 103645031, 103997954 ],
    "samples_ts": [ 148123, 76731.5, 78933, 79364.5, 79712.5, 77173.1, 79243.8, 79358.3, 79175.4, 79216.6, 79376.4, 79463.4, 79908.3, 79577.4, 79612.7, 79965.1, 79296.7, 79203.3, 79039, 78770.8 ]
  }
```

ggerganov (Owner, Author)

Similar observation on my GPU:

```
  Device 0: NVIDIA GeForce RTX 2060 SUPER, compute capability 7.5, VMM: yes
[
  {
    "build_commit": "40ca2e03",
    "build_number": 2351,
    "model_type": "bert 33M F16",
    ...
    "samples_ns": [ 132941728, 147430247, 147204893, 147173373, 147210412, 147196176, 147225462, 147165247, 147198780, 147198841 ],
    "samples_ts": [ 61621, 55565.3, 55650.3, 55662.2, 55648.2, 55653.6, 55642.5, 55665.3, 55652.6, 55652.6 ]
  }
]
```

Not expected. Maybe some issue with the input preparation, though I'm not sure where it could be.

slaren (Collaborator) commented on Mar 5, 2024

The problem is that llama-bench does not set cparams.embeddings, and the model does not produce logits either. As a result, no data is copied back from the GPU to the CPU and ggml_backend_synchronize is never called. So all it is doing is queueing evaluations on the GPU without ever waiting for them to finish, since ggml_backend_graph_compute in the CUDA backend is asynchronous. Setting cparams.embeddings to true in llama-bench gives much more repeatable (and reasonable) results:

"samples_ts": [ 15884.9, 17033.2, 17236.4, 17281.7, 17280.4, 17326.7, 17305.6, 17287.7, 17252.3, 17251.9, 17300.1, 17247.9, 17218.3, 17208.6, 17266.2, 17175.7, 17189.4, 17236.7, 17301.3, 17290.8 ]

With that fix applied, and the fix to compare-llama-bench.py, I get these results vs master:

scripts/compare-commits.sh master gg/bert-f16  -m models/nomic-embed/ggml-model-f16.gguf -p 512,1024,2048,4096,8192 -b 8192 -n 0 -r 10           
| GPU | Model | Test | t/s master | t/s gg/bert-f16 | Speedup |
| --- | ----- | ---- | ---------: | --------------: | ------: |
| RTX 3090 Ti | nomic-bert 137M F16 | pp512 | 55166.66 | 72500.62 | 1.31 |
| RTX 3090 Ti | nomic-bert 137M F16 | pp1024 | 52499.25 | 61251.00 | 1.17 |
| RTX 3090 Ti | nomic-bert 137M F16 | pp2048 | 48672.47 | 48633.91 | 1.00 |
| RTX 3090 Ti | nomic-bert 137M F16 | pp4096 | 33049.97 | 30725.15 | 0.93 |
| RTX 3090 Ti | nomic-bert 137M F16 | pp8192 | 20921.86 | 17152.72 | 0.82 |

iamlemec (Collaborator) commented on Mar 6, 2024

Great! Running fast here. I'm now seeing about a 5% speed boost relative to pre-#5796 for GPU. About 20% for CPU.

ggerganov (Owner, Author)

It's surprising that for large batch sizes the F32 version is faster. On an RTX 2060 with the 33M model, F16 is faster only for BS = 512:

  • master
| model | size | params | backend | ngl | n_batch | test | t/s |
| ----- | ---: | -----: | ------- | --: | ------: | ---- | --: |
| bert 33M F16 | 63.46 MiB | 33.21 M | CUDA | 99 | 8192 | pp 512 | 45000.11 ± 3151.59 |
| bert 33M F16 | 63.46 MiB | 33.21 M | CUDA | 99 | 8192 | pp 1024 | 35720.98 ± 128.35 |
| bert 33M F16 | 63.46 MiB | 33.21 M | CUDA | 99 | 8192 | pp 2048 | 28716.76 ± 121.26 |
| bert 33M F16 | 63.46 MiB | 33.21 M | CUDA | 99 | 8192 | pp 4096 | 19417.52 ± 49.16 |
| bert 33M F16 | 63.46 MiB | 33.21 M | CUDA | 99 | 8192 | pp 8192 | 9123.49 ± 6.09 |

build: bd83694 (2350)

  • PR
| model | size | params | backend | ngl | n_batch | test | t/s |
| ----- | ---: | -----: | ------- | --: | ------: | ---- | --: |
| bert 33M F16 | 63.46 MiB | 33.21 M | CUDA | 99 | 8192 | pp 512 | 50991.57 ± 365.12 |
| bert 33M F16 | 63.46 MiB | 33.21 M | CUDA | 99 | 8192 | pp 1024 | 35353.14 ± 40.56 |
| bert 33M F16 | 63.46 MiB | 33.21 M | CUDA | 99 | 8192 | pp 2048 | 26109.02 ± 33.45 |
| bert 33M F16 | 63.46 MiB | 33.21 M | CUDA | 99 | 8192 | pp 4096 | 16230.26 ± 6.94 |
| bert 33M F16 | 63.46 MiB | 33.21 M | CUDA | 99 | 8192 | pp 8192 | 8297.65 ± 2.88 |

build: 40ca2e0 (2351)

| model | size | params | backend | ngl | n_batch | test | t/s |
| ----- | ---: | -----: | ------- | --: | ------: | ---- | --: |
| bert 33M F16 | 63.46 MiB | 33.21 M | CUDA | 99 | 8192 | pp 512 | 50444.46 ± 387.10 |
| bert 33M F16 | 63.46 MiB | 33.21 M | CUDA | 99 | 8192 | pp 1024 | 34794.17 ± 27.87 |
| bert 33M F16 | 63.46 MiB | 33.21 M | CUDA | 99 | 8192 | pp 2048 | 25528.46 ± 13.41 |
| bert 33M F16 | 63.46 MiB | 33.21 M | CUDA | 99 | 8192 | pp 4096 | 15831.46 ± 7.02 |
| bert 33M F16 | 63.46 MiB | 33.21 M | CUDA | 99 | 8192 | pp 8192 | 8030.40 ± 6.07 |

build: e0843af (2340)

Should we keep the master version?
I imagine that embeddings are usually computed on large amounts of tokens and/or many small texts batched together, so large batch sizes seem more important than small ones.

iamlemec (Collaborator) commented on Mar 6, 2024

Edit: Updated numbers with fix from @slaren. Now seeing similar results on an RTX A6000.

In terms of batch size focus, one use case for small batches is reducing response time for a single query to a vector DB; those will typically be small and single-sequence. For batched small-ish chunks, I think you may as well keep batch sizes small to avoid quadratic attention costs. The newer long-context models would be the main concern.
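
As a rough back-of-the-envelope (my own arithmetic, assuming per-layer self-attention cost scales with the square of the sequence length being attended over): a single 2048-token chunk yields a 2048² ≈ 4.2M-entry attention score matrix per head per layer, while four independent 512-token chunks yield 4 × 512² ≈ 1.0M entries for the same total token count, i.e. roughly 4× less attention work.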

  • master
| model | size | params | backend | ngl | n_batch | test | t/s |
| ----- | ---: | -----: | ------- | --: | ------: | ---- | --: |
| bert 33M F16 | 63.46 MiB | 33.21 M | CUDA | 99 | 8192 | pp 512 | 84162.93 ± 28049.11 |
| bert 33M F16 | 63.46 MiB | 33.21 M | CUDA | 99 | 8192 | pp 1024 | 84590.53 ± 627.15 |
| bert 33M F16 | 63.46 MiB | 33.21 M | CUDA | 99 | 8192 | pp 2048 | 65214.02 ± 917.15 |
| bert 33M F16 | 63.46 MiB | 33.21 M | CUDA | 99 | 8192 | pp 4096 | 39788.44 ± 154.90 |
| bert 33M F16 | 63.46 MiB | 33.21 M | CUDA | 99 | 8192 | pp 8192 | 21576.41 ± 18.43 |

build: e25fb4b (2354)

  • PR:
| model | size | params | backend | ngl | n_batch | test | t/s |
| ----- | ---: | -----: | ------- | --: | ------: | ---- | --: |
| bert 33M F16 | 63.46 MiB | 33.21 M | CUDA | 99 | 8192 | pp 512 | 116013.71 ± 9760.89 |
| bert 33M F16 | 63.46 MiB | 33.21 M | CUDA | 99 | 8192 | pp 1024 | 83931.48 ± 965.92 |
| bert 33M F16 | 63.46 MiB | 33.21 M | CUDA | 99 | 8192 | pp 2048 | 57481.57 ± 131.52 |
| bert 33M F16 | 63.46 MiB | 33.21 M | CUDA | 99 | 8192 | pp 4096 | 31928.17 ± 137.37 |
| bert 33M F16 | 63.46 MiB | 33.21 M | CUDA | 99 | 8192 | pp 8192 | 16893.21 ± 30.95 |

build: 40ca2e0 (2351)

| model | size | params | backend | ngl | n_batch | test | t/s |
| ----- | ---: | -----: | ------- | --: | ------: | ---- | --: |
| bert 33M F16 | 63.46 MiB | 33.21 M | CUDA | 99 | 8192 | pp 512 | 106790.32 ± 8440.97 |
| bert 33M F16 | 63.46 MiB | 33.21 M | CUDA | 99 | 8192 | pp 1024 | 79385.58 ± 347.36 |
| bert 33M F16 | 63.46 MiB | 33.21 M | CUDA | 99 | 8192 | pp 2048 | 53797.97 ± 853.64 |
| bert 33M F16 | 63.46 MiB | 33.21 M | CUDA | 99 | 8192 | pp 4096 | 29866.99 ± 79.23 |
| bert 33M F16 | 63.46 MiB | 33.21 M | CUDA | 99 | 8192 | pp 8192 | 15713.72 ± 35.11 |

build: e0843af (2340)

slaren (Collaborator) commented on Mar 6, 2024

@iamlemec You need to apply this patch to be able to use llama-bench with embedding models and CUDA, otherwise the results are garbage for the reason I explained before.

```diff
diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index aa79d002..67a549d6 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -502,6 +502,7 @@ struct cmd_params_instance {
         cparams.type_k = type_k;
         cparams.type_v = type_v;
         cparams.offload_kqv = !no_kv_offload;
+        cparams.embeddings = true;
 
         return cparams;
     }
```

ggerganov added the "demo (Demonstrate some concept or idea, not intended to be merged)" label on Apr 29, 2024
mofosyne added the "Review Complexity : High (Generally require in-depth knowledge of LLMs or GPUs)" label on May 10, 2024