llama : compute BERT graph with F16 K, V #5891
base: master
Conversation
I noticed that the first evaluation with BERT is a lot faster than the rest (about 2x), hence the huge variance. Is this expected? Maybe there is some cleanup necessary between runs that
Similar observation on my GPU:
Not expected. Maybe some issue with the input preparation, though I'm not sure where it could be.
The problem is that
With that fix applied, and the fix to
Great! Running fast here. I'm now seeing about a 5% speed boost relative to pre-#5796 for GPU, and about 20% for CPU.
It's surprising that for large batch sizes the F32 version is faster. On an RTX 2060 with a 33M model, F16 is faster only for BS = 512:
build: bd83694 (2350)
build: 40ca2e0 (2351)
build: e0843af (2340)

Should we keep the
Edit: Updated numbers with fix from @slaren. Now seeing similar results on an RTX A6000. In terms of batch size focus, one use case for small batches is when you're trying to reduce response time for a single query to a vector db. Those will typically be small and single sequence. For batched small-ish chunks, I think you may as well keep batch sizes small to avoid quadratic attention costs. The newer long context stuff would be the concern.
build: e25fb4b (2354)
build: 40ca2e0 (2351)
build: e0843af (2340)
@iamlemec You need to apply this patch to be able to use llama-bench with embedding models:

```diff
diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index aa79d002..67a549d6 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -502,6 +502,7 @@ struct cmd_params_instance {
cparams.type_k = type_k;
cparams.type_v = type_v;
cparams.offload_kqv = !no_kv_offload;
+ cparams.embeddings = true;
return cparams;
     }
```
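For context, here is a minimal sketch (not part of this PR) of how the same flag is set when creating a context through the public llama.h API. The model filename, batch size, and the final embeddings call are assumptions on my part, not taken from the discussion:

```cpp
// Minimal sketch: enable embeddings mode via the post-#5796 llama.h API.
// The model path and batch size are hypothetical; error checking is omitted.
#include "llama.h"

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_load_model_from_file("bert-33M-f16.gguf", mparams); // hypothetical file

    llama_context_params cparams = llama_context_default_params();
    cparams.embeddings = true;   // same flag the llama-bench patch above sets
    cparams.n_batch    = 512;    // batch size discussed in the benchmarks above

    llama_context * ctx = llama_new_context_with_model(model, cparams);

    // ... tokenize the input, call llama_decode(), then read the pooled
    // embeddings, e.g. with llama_get_embeddings_seq(ctx, 0) ...

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}
```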
Before #5796, we were using an F16 KV cache. Cast K and V to F16 to match the performance before that change.
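A rough sketch of the idea, assuming ggml's `ggml_cast` operator and illustrative tensor names (the actual change lives in the BERT graph construction in llama.cpp; scaling and masking are omitted here):

```cpp
// Sketch only, not the actual llama.cpp graph code: cast K and V to F16 before
// the attention matmuls so that KQ and KQV run against F16 operands, as the
// F16 KV cache provided before #5796. Tensor names are assumptions.
struct ggml_tensor * k_f16 = ggml_cast(ctx0, k, GGML_TYPE_F16);
struct ggml_tensor * v_f16 = ggml_cast(ctx0, v, GGML_TYPE_F16);

struct ggml_tensor * kq      = ggml_mul_mat(ctx0, k_f16, q);        // attention scores
struct ggml_tensor * kq_soft = ggml_soft_max(ctx0, kq);             // scale/mask omitted for brevity
struct ggml_tensor * kqv     = ggml_mul_mat(ctx0, v_f16, kq_soft);  // weighted sum of values
```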