Description
I was interested in what impact the BF16 format (#6412) has on my CPU (an AMD Ryzen Embedded V3000 V3C48, which uses Zen 3 cores, with 4800 MT/s ECC RAM). Surprisingly, prompt processing reaches only about half the performance of the F16 and F32 formats, while token generation is slightly faster.
Here is `llama-bench` with `sgemm`:
model | size | params | backend | threads | test | t/s |
---|---|---|---|---|---|---|
mistral 7B BF16 | 13.49 GiB | 7.24 B | CPU | 6 | pp 512 | 11.18 ± 0.05 |
mistral 7B BF16 | 13.49 GiB | 7.24 B | CPU | 6 | tg 128 | 3.26 ± 0.03 |
mistral 7B F16 | 13.49 GiB | 7.24 B | CPU | 6 | pp 512 | 23.29 ± 0.07 |
mistral 7B F16 | 13.49 GiB | 7.24 B | CPU | 6 | tg 128 | 3.02 ± 0.03 |
mistral 7B all F32 | 26.98 GiB | 7.24 B | CPU | 6 | pp 512 | 19.94 ± 0.04 |
mistral 7B all F32 | 26.98 GiB | 7.24 B | CPU | 6 | tg 128 | 1.52 ± 0.01 |
And here without `sgemm` (built with `LLAMA_NO_LLAMAFILE=1 make ...`):
model | size | params | backend | threads | test | t/s |
---|---|---|---|---|---|---|
mistral 7B BF16 | 13.49 GiB | 7.24 B | CPU | 6 | pp 512 | 10.80 ± 0.04 |
mistral 7B BF16 | 13.49 GiB | 7.24 B | CPU | 6 | tg 128 | 3.24 ± 0.02 |
mistral 7B F16 | 13.49 GiB | 7.24 B | CPU | 6 | pp 512 | 16.32 ± 0.04 |
mistral 7B F16 | 13.49 GiB | 7.24 B | CPU | 6 | tg 128 | 3.26 ± 0.03 |
mistral 7B all F32 | 26.98 GiB | 7.24 B | CPU | 6 | pp 512 | 10.60 ± 0.05 |
mistral 7B all F32 | 26.98 GiB | 7.24 B | CPU | 6 | tg 128 | 1.65 ± 0.01 |
So it's not just the gains from `sgemm`.
Running `perf` on `llama-bench` with a Mistral-7B-Instruct-v0.2 model showed that more than 95% of the time is spent in `ggml_vec_dot_bf16`. Here is the annotated disassembly:
```
Percent│
│
│ Disassembly of section .text:
│
│ 000000000002df60 <ggml_vec_dot_bf16>:
│ ggml_vec_dot_bf16():
0.05 │ cmp $0x1f,%edi
0.01 │ ↓ jle 410
0.00 │ lea -0x20(%rdi),%r8d
│ vxorps %xmm3,%xmm3,%xmm3
0.00 │ mov %r9,%rax
0.00 │ mov %rcx,%rdx
0.00 │ shr $0x5,%r8d
0.00 │ vmovaps %ymm3,%ymm4
0.05 │ vmovaps %ymm3,%ymm5
0.00 │ vmovaps %ymm3,%ymm2
0.00 │ mov %r8d,%r10d
│ shl $0x6,%r10
0.00 │ lea 0x40(%r9,%r10,1),%r10
0.00 │ data16 cs nopw 0x0(%rax,%rax,1)
0.00 │ xchg %ax,%ax
0.05 │ 40: vpmovzxwd (%rax),%ymm0
0.53 │ vpmovzxwd (%rdx),%ymm1
0.40 │ add $0x40,%rax
0.05 │ add $0x40,%rdx
0.04 │ vpslld $0x10,%ymm0,%ymm0
0.25 │ vpslld $0x10,%ymm1,%ymm1
0.04 │ vfmadd231ps %ymm0,%ymm1,%ymm2
14.70 │ vpmovzxwd -0x30(%rax),%ymm1
0.08 │ vpmovzxwd -0x30(%rdx),%ymm0
0.07 │ vpslld $0x10,%ymm1,%ymm1
0.20 │ vpslld $0x10,%ymm0,%ymm0
0.06 │ vfmadd231ps %ymm0,%ymm1,%ymm5
14.98 │ vpmovzxwd -0x20(%rax),%ymm1
0.07 │ vpmovzxwd -0x20(%rdx),%ymm0
34.90 │ vpslld $0x10,%ymm1,%ymm1
0.50 │ vpslld $0x10,%ymm0,%ymm0
0.36 │ vfmadd231ps %ymm0,%ymm1,%ymm4
15.21 │ vpmovzxwd -0x10(%rax),%ymm1
0.14 │ vpmovzxwd -0x10(%rdx),%ymm0
0.13 │ vpslld $0x10,%ymm1,%ymm1
0.44 │ vpslld $0x10,%ymm0,%ymm0
0.05 │ vfmadd231ps %ymm0,%ymm1,%ymm3
15.01 │ cmp %rax,%r10
0.04 │ ↑ jne 40
│ vaddps %ymm5,%ymm2,%ymm0
0.06 │ vaddps %ymm4,%ymm3,%ymm3
0.11 │ lea 0x1(%r8),%edx
0.00 │ shl $0x5,%edx
0.00 │ mov %edx,%r8d
│ vaddps %ymm3,%ymm0,%ymm0
0.14 │ cd: vmovaps %xmm0,%xmm1
0.00 │ vextractf128 $0x1,%ymm0,%xmm0
0.18 │ vaddps %xmm1,%xmm0,%xmm0
0.16 │ vmovhlps %xmm0,%xmm0,%xmm1
0.18 │ vaddps %xmm0,%xmm1,%xmm0
0.19 │ vmovshdup %xmm0,%xmm1
0.18 │ vaddss %xmm1,%xmm0,%xmm0
0.17 │ vcvtss2sd %xmm0,%xmm0,%xmm4
0.17 │ cmp %r8d,%edi
0.00 │ ↓ jle 400
```
The majority of the time is spent in the `vpmovzxwd` and `vpslld` instructions. My guess is that this has more to do with waiting for memory than with the instructions themselves, since the same instructions seem to run quite fast in other locations. I toyed around with different amounts of unrolling and with reordering the instructions, but that did not yield any real improvements.
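For reference, the hot loop above boils down to roughly the following AVX2/FMA intrinsics (a minimal sketch with my own names, without the four-way unrolling visible in the disassembly, and not the actual ggml source):

```c
#include <immintrin.h>
#include <stdint.h>

// Widen 8 bf16 values to f32: zero-extend each 16-bit value to 32 bits
// (vpmovzxwd) and shift it into the high half of the float (vpslld $16).
static inline __m256 bf16x8_to_f32(const uint16_t *p) {
    __m128i h = _mm_loadu_si128((const __m128i *)p);
    return _mm256_castsi256_ps(
        _mm256_slli_epi32(_mm256_cvtepu16_epi32(h), 16));
}

static float dot_bf16_avx2(int n, const uint16_t *x, const uint16_t *y) {
    __m256 acc = _mm256_setzero_ps();
    int i = 0;
    for (; i + 8 <= n; i += 8) {
        // vfmadd231ps in the disassembly
        acc = _mm256_fmadd_ps(bf16x8_to_f32(x + i), bf16x8_to_f32(y + i), acc);
    }
    // Horizontal sum (the vextractf128/vmovhlps/vmovshdup tail in the asm).
    __m128 s = _mm_add_ps(_mm256_castps256_ps128(acc),
                          _mm256_extractf128_ps(acc, 1));
    s = _mm_add_ps(s, _mm_movehl_ps(s, s));
    s = _mm_add_ss(s, _mm_movehdup_ps(s));
    float sum = _mm_cvtss_f32(s);
    // Scalar tail for the remaining n % 8 elements omitted for brevity.
    return sum;
}
```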
Just for fun, here is the result of leaving the vectorization of `ggml_vec_dot_bf16` to the compiler:
model | size | params | backend | threads | test | t/s |
---|---|---|---|---|---|---|
mistral 7B BF16 | 13.49 GiB | 7.24 B | CPU | 6 | pp 512 | 4.24 ± 0.02 |
mistral 7B BF16 | 13.49 GiB | 7.24 B | CPU | 6 | tg 128 | 3.13 ± 0.02 |
While the intrinsics version clearly helps with prompt processing, its impact on token generation is rather small for some reason.
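The profile below is consistent with a plain scalar loop along these lines (my sketch, not the actual ggml source; note the double-precision accumulator, which would explain the `vcvtps2pd`/`vaddpd` instructions):

```c
#include <stdint.h>

// bf16 is just the upper 16 bits of an IEEE f32, so conversion is a shift.
static inline float bf16_to_f32(uint16_t h) {
    union { uint32_t u; float f; } v = { .u = (uint32_t)h << 16 };
    return v.f;
}

static double dot_bf16_scalar(int n, const uint16_t *x, const uint16_t *y) {
    double sum = 0.0; // double accumulator -> vcvtps2pd/vaddpd once vectorized
    for (int i = 0; i < n; ++i) {
        sum += (double)(bf16_to_f32(x[i]) * bf16_to_f32(y[i]));
    }
    return sum;
}
```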
Here is the profile of the loop:
```
0.00 │ 30: vpmovzxwd (%rcx,%rax,1),%ymm1
19.45 │ vpmovzxwd (%r9,%rax,1),%ymm0
0.04 │ vmovdqu (%rcx,%rax,1),%ymm5
0.00 │ vmovdqu (%r9,%rax,1),%ymm6
0.33 │ add $0x20,%rax
0.00 │ vpslld $0x10,%ymm0,%ymm0
0.01 │ vpslld $0x10,%ymm1,%ymm1
6.67 │ vmulps %ymm0,%ymm1,%ymm1
0.54 │ vextracti128 $0x1,%ymm6,%xmm2
0.06 │ vextracti128 $0x1,%ymm5,%xmm0
0.01 │ vpmovzxwd %xmm0,%ymm0
0.09 │ vpmovzxwd %xmm2,%ymm2
0.09 │ vpslld $0x10,%ymm2,%ymm2
6.59 │ vpslld $0x10,%ymm0,%ymm0
0.11 │ vmulps %ymm2,%ymm0,%ymm0
1.17 │ vcvtps2pd %xmm1,%ymm2
0.01 │ vextractf128 $0x1,%ymm1,%xmm1
0.00 │ vcvtps2pd %xmm1,%ymm1
0.18 │ vaddpd %ymm1,%ymm2,%ymm1
6.81 │ vcvtps2pd %xmm0,%ymm2
1.26 │ vextractf128 $0x1,%ymm0,%xmm0
0.47 │ vcvtps2pd %xmm0,%ymm0
8.89 │ vaddpd %ymm0,%ymm2,%ymm0
11.95 │ vaddpd %ymm0,%ymm1,%ymm0
15.98 │ vaddpd %ymm0,%ymm3,%ymm3
18.85 │ cmp %rax,%rdx
0.01 │ ↑ jne 30
```