
Arm AArch64: optimized GEMV and GEMM kernels for q4_0_q8_0, and q8_0_q8_0 quantization #5780

Merged
28 commits merged on Jul 10, 2024

Conversation

Dibakar
Contributor

@Dibakar Dibakar commented Feb 28, 2024

For the Arm AArch64 architecture, this PR adds support for optimized GEMV (using dot instructions) and GEMM (using mmla instructions) kernels for the q4_0_q8_0 and q8_0_q8_0 quantization methods.

The feature is enabled if the platform supports __ARM_FEATURE_MATMUL_INT8 (for GEMM) and __ARM_NEON or __ARM_FEATURE_SVE (for GEMV).
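For reference, here is a minimal sketch (not code from the PR) of how the two paths can be gated at compile time with these standard Arm feature-test macros; the function below is purely illustrative:

const char * aarch64_kernel_path(void) {
#if defined(__ARM_FEATURE_MATMUL_INT8)
    return "GEMM path using mmla (i8mm) instructions";
#elif defined(__ARM_FEATURE_SVE) || defined(__ARM_NEON)
    return "GEMV path using dot-product instructions";
#else
    return "generic scalar fallback";
#endif
}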

On AWS Graviton3 processors, these kernels resulted in a 2.5x improvement in prompt evaluation over the existing GEMM mmla kernels, as well as a 2x improvement in text generation over the default vec_dot kernel (Feb 21 commit 973053d). Please see the table below.

Authors: David Mansell (david.mansell@arm.com) and Dibakar Gope (dibakar.gope@arm.com)

[image: performance comparison table]

@Jeximo
Contributor

Jeximo commented Feb 29, 2024

Hi, thanks for sharing,

I was excited about this, but it doesn't seem to do what I expected for my device. Perhaps I'm not building correctly? Here's my command: cmake -B build -DCMAKE_C_FLAGS=-march=armv8.4a -D

uname -a
Linux localhost 4.14.190-23725627-abG975WVLS8IWD1 #2 SMP PREEMPT Mon Apr 10 18:16:39 KST 2023 aarch64 Android
built with clang version 17.0.6 for aarch64-unknown-linux-android24

system_info: n_threads = 3 / 8 | AVX = 0 | AVX_VNNI = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 0 | SSSE3 = 0 | VSX = 0 | MATMUL_INT8 = 0 |

(master) test with gemma 2b 8_0:

llama_print_timings:        load time =    3336.88 ms
llama_print_timings:      sample time =     204.57 ms /    78 runs   (    2.62 ms per token,   381.28 tokens per second)
llama_print_timings: prompt eval time =    7224.29 ms /    83 tokens (   87.04 ms per token,    11.49 tokens per second)
llama_print_timings:        eval time =   11994.10 ms /    77 runs   (  155.77 ms per token,     6.42 tokens per second)
llama_print_timings:       total time =   36415.72 ms /   160 tokens

(PR) test:

llama_print_timings:        load time =   31002.49 ms
llama_print_timings:      sample time =     351.07 ms /   129 runs   (    2.72 ms per token,   367.45 tokens per second)
llama_print_timings: prompt eval time =    8281.97 ms /    58 tokens (  142.79 ms per token,     7.00 tokens per second)
llama_print_timings:        eval time =   19549.39 ms /   128 runs   (  152.73 ms per token,     6.55 tokens per second)
llama_print_timings:       total time =   34360.41 ms /   186 tokens

It's much slower. Maybe it only works with a certain build? It's possible I'm doing something wrong, but I don't see it.

ggml.c Outdated
@@ -1,3 +1,4 @@
// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd.
Collaborator

@cebtenzzre cebtenzzre Feb 29, 2024


llama.cpp's copyright owner is ARM now?

Contributor Author


@cebtenzzre Hi, these copyright notices clarify the copyright situation for the code contributed in this PR. Is there a more appropriate place in this project that they should be moved to? Please suggest.

Collaborator


Disclaimer up front: I am not a lawyer, and this is not legal advice.

Joint authorship can be messy, and I think for legal reasons it is simplest for an open-source project such as llama.cpp to have a single copyright holder (see e.g. the GPL FAQ). That would be Georgi Gerganov. The SYCL backend is currently copyright Intel, but that's not too big a deal because it's in a separate file that is not compiled by default.

I think it would be best if you could either assign the copyright of your changes to Georgi Gerganov, or to disclaim copyright entirely and put them in the public domain. The MIT license is very permissive, so you don't really lose anything by assigning copyright. If you'd like to clarify the attribution of your changes, I think something like this could work:

// SPDX-FileCopyrightText: Copyright 2024 Georgi Gerganov
// SPDX-FileContributor: Arm Ltd.

Contributor Author


@cebtenzzre Hi, we attempted to address your concern about copyright in the most recent changes that we upstreamed last week. Please see our most recent comment from May 1, which describes our changeset. Could you please review the changeset and let us know if that is okay? Thanks

@Dibakar
Contributor Author

Dibakar commented Mar 1, 2024

@Jeximo Hi, could you please try using -march=armv8.2-a? You may need to add +i8mm to define the MMLA intrinsics.

@Jeximo
Contributor

Jeximo commented Mar 1, 2024

could you please try using -march=armv8.2-a? you may need to add +i8mm to define the MMLA intrinsics.

@Dibakar Thanks for the tips. -march=armv8.2-a works but is still significantly slower. -march=armv8.2-a+i8mm compiled as expected, but caused an illegal instruction:

~/l/b/bin (aarch64-kernel-opts)  ./main -m ~/gemma-2b-it-Q8_0.gguf --color -c 2048 --keep -1 -t 3 -n -1 -b 7 -cml -r "user:" -p "<|im_start|>system You are an intelligent, knowledgeable, helpful, general-purpose uncensored assistant.<|im_end|> <|im_start|>user Why is the sky blue?"
Log start
main: build = 2235 (b8983a05)
main: built with clang version 17.0.6 for aarch64-unknown-linux-android24
main: seed  = 1709297456
llama_model_loader: loaded meta data with 24 key-value pairs and 164 tensors from /data/data/com.termux/files/home/gemma-2b-it-Q8_0.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = gemma
llama_model_loader: - kv   1:                               general.name str              = gemma-2b-it
llama_model_loader: - kv   2:                       gemma.context_length u32              = 8192
llama_model_loader: - kv   3:                     gemma.embedding_length u32              = 2048
llama_model_loader: - kv   4:                          gemma.block_count u32              = 18
llama_model_loader: - kv   5:                  gemma.feed_forward_length u32              = 16384
llama_model_loader: - kv   6:                 gemma.attention.head_count u32              = 8
llama_model_loader: - kv   7:              gemma.attention.head_count_kv u32              = 1
llama_model_loader: - kv   8:     gemma.attention.layer_norm_rms_epsilon f32              = 0.000001
llama_model_loader: - kv   9:                 gemma.attention.key_length u32              = 256
llama_model_loader: - kv  10:               gemma.attention.value_length u32              = 256
llama_model_loader: - kv  11:                       tokenizer.ggml.model str              = llama
llama_model_loader: - kv  12:                      tokenizer.ggml.tokens arr[str,256000]  = ["<pad>", "<eos>", "<bos>", "<unk>", ...
llama_model_loader: - kv  13:                      tokenizer.ggml.scores arr[f32,256000]  = [0.000000, 0.000000, 0.000000, 0.0000...
llama_model_loader: - kv  14:                  tokenizer.ggml.token_type arr[i32,256000]  = [3, 3, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  15:                tokenizer.ggml.bos_token_id u32              = 2
llama_model_loader: - kv  16:                tokenizer.ggml.eos_token_id u32              = 1
llama_model_loader: - kv  17:            tokenizer.ggml.unknown_token_id u32              = 3
llama_model_loader: - kv  18:            tokenizer.ggml.padding_token_id u32              = 0
llama_model_loader: - kv  19:               tokenizer.ggml.add_bos_token bool             = true
llama_model_loader: - kv  20:               tokenizer.ggml.add_eos_token bool             = false
llama_model_loader: - kv  21:                    tokenizer.chat_template str              = {% if messages[0]['role'] == 'system'...
llama_model_loader: - kv  22:               general.quantization_version u32              = 2
llama_model_loader: - kv  23:                          general.file_type u32              = 7
llama_model_loader: - type  f32:   37 tensors
llama_model_loader: - type q8_0:  127 tensors
llm_load_vocab: mismatch in special tokens definition ( 416/256000 vs 260/256000 ).
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = gemma
llm_load_print_meta: vocab type       = SPM
llm_load_print_meta: n_vocab          = 256000
llm_load_print_meta: n_merges         = 0
llm_load_print_meta: n_ctx_train      = 8192
llm_load_print_meta: n_embd           = 2048
llm_load_print_meta: n_head           = 8
llm_load_print_meta: n_head_kv        = 1
llm_load_print_meta: n_layer          = 18
llm_load_print_meta: n_rot            = 256
llm_load_print_meta: n_embd_head_k    = 256
llm_load_print_meta: n_embd_head_v    = 256
llm_load_print_meta: n_gqa            = 8
llm_load_print_meta: n_embd_k_gqa     = 256
llm_load_print_meta: n_embd_v_gqa     = 256
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-06
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: n_ff             = 16384
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 10000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_yarn_orig_ctx  = 8192
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: model type       = 2B
llm_load_print_meta: model ftype      = Q8_0
llm_load_print_meta: model params     = 2.51 B
llm_load_print_meta: model size       = 2.48 GiB (8.50 BPW)
llm_load_print_meta: general.name     = gemma-2b-it
llm_load_print_meta: BOS token        = 2 '<bos>'
llm_load_print_meta: EOS token        = 1 '<eos>'
llm_load_print_meta: UNK token        = 3 '<unk>'
llm_load_print_meta: PAD token        = 0 '<pad>'
llm_load_print_meta: LF token         = 227 '<0x0A>'
llm_load_tensors: ggml ctx size =    0.07 MiB
llm_load_tensors:        CPU buffer size =  2539.66 MiB
.............................................................
llama_new_context_with_model: n_ctx      = 2048
llama_new_context_with_model: freq_base  = 10000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:        CPU KV buffer size =    36.00 MiB
llama_new_context_with_model: KV self size  =   36.00 MiB, K (f16):   18.00 MiB, V (f16):   18.00 MiB
llama_new_context_with_model:        CPU input buffer size   =     0.13 MiB
llama_new_context_with_model:        CPU compute buffer size =     6.89 MiB
llama_new_context_with_model: graph splits (measure): 1
fish: Job 1, './main -m ~/gemma-2b-it-Q8_0.gg…' terminated by signal SIGILL (Illegal instruction)

I'm inexperienced, so if it's clear to you what's wrong, I'm open to suggestions. All good either way.

@slaren
Collaborator

slaren commented Mar 1, 2024

Maybe all the data rearrangement that this change requires could be implemented in a less intrusive way as a different backend. It is not really possible at the moment: supporting backends that only implement matrix multiplication would require changes to ggml_backend_sched, and the cost of launching the threads in the CPU backend would make the switches between backends too expensive. But we plan to address these issues in the future.

@Dibakar
Contributor Author

Dibakar commented Mar 1, 2024

@Jeximo what is the prompt (prompt eval time) and text generation (eval time) token/s rate now for armv8.2-a with and without the patch?

@Jeximo
Contributor

Jeximo commented Mar 1, 2024

@Dibakar To be as clear as possible, I build for CPU only: not OpenBLAS, not GPU, not Vulkan.

Here's build info:

cmake -B build -DCMAKE_C_FLAGS=-march=armv8.4-a -DLLAMA_VULKAN=OFF && cd build && cmake --build . --config Release
-- The C compiler identification is Clang 17.0.6
-- The CXX compiler identification is Clang 17.0.6
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Check for working C compiler: /data/data/com.termux/files/usr/bin/cc - skipped
-- Detecting C compile features
-- Detecting C compile features - done
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Check for working CXX compiler: /data/data/com.termux/files/usr/bin/c++ - skipped
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Found Git: /data/data/com.termux/files/usr/bin/git (found version "2.44.0")
-- Performing Test CMAKE_HAVE_LIBC_PTHREAD
-- Performing Test CMAKE_HAVE_LIBC_PTHREAD - Failed
-- Check if compiler accepts -pthread
-- Check if compiler accepts -pthread - yes
-- Found Threads: TRUE
-- ccache found, compilation results will be cached. Disable with LLAMA_CCACHE=OFF.
-- CMAKE_SYSTEM_PROCESSOR: aarch64                
-- ARM detected                                  
-- Performing Test COMPILER_SUPPORTS_FP16_FORMAT_I3E
-- Performing Test COMPILER_SUPPORTS_FP16_FORMAT_I3E - Failed
-- Configuring done (3.0s)
-- Generating done (0.2s)
-- Build files have been written to: /data/data/com.termux/files/home/llama.cpp/build

Master with -march=armv8.4-a

llama_print_timings:        load time =    1457.52 ms
llama_print_timings:      sample time =     151.63 ms /   121 runs   (    1.25 ms per token,   797.97 tokens per second)
llama_print_timings: prompt eval time =    7169.65 ms /    84 tokens (   85.35 ms per token,    11.72 tokens per second)
llama_print_timings:        eval time =   18688.17 ms /   120 runs   (  155.73 ms per token,     6.42 tokens per second)
llama_print_timings:       total time =   42734.25 ms /   204 tokens

PR with -march=armv8.4-a:

llama_print_timings:        load time =   35110.65 ms
llama_print_timings:      sample time =     153.11 ms /   124 runs   (    1.23 ms per token,   809.87 tokens per second)
llama_print_timings: prompt eval time =   12702.64 ms /    88 tokens (  144.35 ms per token,     6.93 tokens per second)
llama_print_timings:        eval time =   18753.10 ms /   123 runs   (  152.46 ms per token,     6.56 tokens per second)
llama_print_timings:       total time =   52990.14 ms /   211 tokens

command: ./main -m ~/gemma-2b-it-Q8_0.gguf -s 7 -e --temp 0 --repeat-penalty 1.0 --no-penalize-nl -c 4096 -gan 2 -gaw 1024 --keep -1 -t 3 -n -1 -b 7 -cml -r "user:" -p "<|im_start|>system You are an intelligent, knowledgeable, helpful, general-purpose assistant.<|im_end|> <|im_start|>user Why is the sky blue?"

For my device, --batch N (-b 7) affects print timings quite a bit, so let me know if you want me to try something else, or use different parameters. I tried another batch size, with similar timings:
-b 30 Master:

llama_print_timings:        load time =     492.74 ms
llama_print_timings:      sample time =     154.58 ms /   124 runs   (    1.25 ms per token,   802.16 tokens per second)
llama_print_timings: prompt eval time =    6814.38 ms /    84 tokens (   81.12 ms per token,    12.33 tokens per second)
llama_print_timings:        eval time =   19341.29 ms /   123 runs   (  157.25 ms per token,     6.36 tokens per second)
llama_print_timings:       total time =   38342.52 ms /   207 tokens

-b 30 PR:

llama_print_timings:        load time =   32652.41 ms
llama_print_timings:      sample time =     151.19 ms /   122 runs   (    1.24 ms per token,   806.94 tokens per second)
llama_print_timings: prompt eval time =   11788.72 ms /    83 tokens (  142.03 ms per token,     7.04 tokens per second)
llama_print_timings:        eval time =   18387.48 ms /   121 runs   (  151.96 ms per token,     6.58 tokens per second)
llama_print_timings:       total time =   46544.11 ms /   204 tokens

Admittedly, eval time in llama_print_timings is improved in every test, but other timings suffer for my device.

@USBhost

USBhost commented Mar 2, 2024

As of writing, current head:

llama_print_timings:        load time =    2141.24 ms
llama_print_timings:      sample time =      32.60 ms /   128 runs   (    0.25 ms per token,  3926.62 tokens per second)
llama_print_timings: prompt eval time =    1001.54 ms /     7 tokens (  143.08 ms per token,     6.99 tokens per second)
llama_print_timings:        eval time =   19410.93 ms /   127 runs   (  152.84 ms per token,     6.54 tokens per second)
llama_print_timings:       total time =   20484.00 ms /   134 tokens
Log end

PR branch:

llama_print_timings:        load time =   34172.00 ms
llama_print_timings:      sample time =      34.98 ms /   128 runs   (    0.27 ms per token,  3658.82 tokens per second)
llama_print_timings: prompt eval time =     655.01 ms /     7 tokens (   93.57 ms per token,    10.69 tokens per second)
llama_print_timings:        eval time =   14391.61 ms /   127 runs   (  113.32 ms per token,     8.82 tokens per second)
llama_print_timings:       total time =   15130.46 ms /   134 tokens
Log end

Tests done on my OnePlus 12 (Android) under Termux. Things I have noticed:

  1. I had to build with make CFLAGS="-march=armv8.2-a+dotprod" -j8. Aka I had to add dotprod or it would not compile. Edit: I tried -march=armv8.4-a like the others here and it works with identical results.
  2. This PR made the llm_load_tensors part load like 15 times slower.
  3. Memory required basically doubled for some reason... I needed close to 11GB of RAM for 4_0 7b
  4. It's actually legit faster: 6.54 t/s vs 8.82 t/s, not quite double but still a good speed increase.

I ran all tests with ./main -m ../westlake-7b-v2.Q4_0.gguf -t 4 -p "why is the sky blue?" --no-mmap -n 128
I did not choose that model for any particular reason; I just wanted a 4_0 to test, since this was doubling my required memory usage so I couldn't run my normal 8_0 lol

@ggerganov
Owner

These changes are too intrusive - there are a lot of new quantization structs introduced to the core library, which would be difficult to maintain. The approach proposed earlier in the discussion of implementing a dedicated backend and moving the entire matrix multiplication code there should be considered instead

Also, one additional thing to consider is that ggml_tensor should not be modified with backend-specific members. Such data can reside in extra if necessary.
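As a minimal sketch of what that could look like (the helper struct and function below are hypothetical, not part of ggml or the PR; only the extra pointer on ggml_tensor is assumed from the existing header):

#include <stdlib.h>
#include "ggml.h"

struct aarch64_extra {               /* hypothetical container */
    void * rearranged_weights;       /* hypothetical field     */
};

static int set_rearranged(struct ggml_tensor * t, void * rearranged) {
    struct aarch64_extra * ex = malloc(sizeof(*ex));
    if (ex == NULL) {
        return 1;
    }
    ex->rearranged_weights = rearranged;
    t->extra = ex;                   /* ggml_tensor already provides this generic pointer */
    return 0;
}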

@Dibakar
Contributor Author

Dibakar commented Mar 6, 2024

@ggerganov @slaren Sure, we should be able to place things in the extra field of ggml_tensor.
This patch requires rearranging the initially loaded weights (for example, bundling q4_0 blocks from multiple weight columns so they are ready for the optimized gemv/gemm). Hence we have a few rearrange-related functions. Since this is done with weights, it should ideally be done offline in advance. We do it here online, overwriting the weights when we load them just before the inference starts. Aside from that, our patch essentially proposes the addition of the two functions ggml_gemv_q4_0_q8_0 and ggml_gemm_q4_0_q8_0; the optimized variants are called if possible before falling back to the reference vec_dot in ggml_compute_forward_mul_mat.
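To make the rearrangement idea concrete, here is an illustrative sketch; the x4 super-block type and packing helper are hypothetical, not the PR's actual definitions:

#include <stdint.h>
#include <string.h>

#define QK4_0 32

typedef struct {
    uint16_t d;                    /* fp16 scale                    */
    uint8_t  qs[QK4_0 / 2];        /* 32 4-bit quants, two per byte */
} block_q4_0;

typedef struct {                   /* hypothetical interleaved super-block */
    uint16_t d[4];                 /* scales of the 4 source blocks        */
    uint8_t  qs[4 * QK4_0 / 2];    /* their quants stored back-to-back     */
} block_q4_0x4;

/* Bundle the same block index from 4 neighbouring weight columns so the
 * GEMM/GEMV kernel can read all of them with contiguous loads. */
static block_q4_0x4 pack_q4_0x4(const block_q4_0 * cols[4]) {
    block_q4_0x4 out;
    for (int c = 0; c < 4; c++) {
        out.d[c] = cols[c]->d;
        memcpy(out.qs + c * (QK4_0 / 2), cols[c]->qs, QK4_0 / 2);
    }
    return out;
}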

@Dibakar
Contributor Author

Dibakar commented Mar 6, 2024

@ggerganov @slaren Hi, as I mentioned in my previous response, the code is made up of two main parts and associated structures. One is weight rearrangement, and the other is actual optimized kernels. Please suggest if there is a specific backend file/location where you want our weight rearrangement code to be moved. The optimized matrix multiplication kernels ggml_gemv_q4_0_q8_0 and ggml_gemm_q4_0_q8_0 are currently defined in ggml-quants.h/c, where the other vec_dot/gemm kernels for other variants are defined. Should we keep them there? Please suggest.

@slaren
Collaborator

slaren commented Mar 6, 2024

@Dibakar Ideally, this would be implemented as a different backend that implements the ggml-backend interface. This would allow us to isolate all this code to a different file, without modifying the internals of ggml, and without requiring applications to call additional functions to convert the tensors.

However, it is not possible to do this at this moment due to the limitations that I mentioned earlier - we need to modify ggml_backend_sched to support this use case, and we need to make changes in the CPU backend to reduce the cost of switching between backends. Once that is done, we will also move the BLAS code to a different backend, and you could use that as a template for adapting this code to the ggml-backend interface.

@Dibakar
Contributor Author

Dibakar commented Mar 7, 2024

@slaren Thanks for the feedback.

I would like to clarify that the weight rearrangement step can be done completely offline when converting an HF file to a GGUF file. We can incorporate it into the current convert-hf-to-gguf.py for the Arm aarch64 optimized kernels, and make provisions for generating another GGUF in addition to the current GGUF file and using it for inference. This will remove all of the changes we made to ggml.c/.h and llama.cpp. We needed to do this once before the inference begins, and we placed the code in ggml.c/.h and llama.cpp. However, this code can be completely removed.

After that, we will be left with only our ggml_gemv_q4_0_q8_0 and ggml_gemm_q4_0_q8_0 kernels, and the quantize_row_q4_0 functions that we introduced in the ggml-quants.c file, where the vec-dot/gemm kernels for the other quantization methods are defined. Our quantize_row_q4_0 changes can easily be refactored into the existing quantize_q4_0 function. In addition, we need to keep a few lines of code in ggml.c's forward_mul_mat function to call our optimized kernels whenever possible. We can definitely refactor the code around the current vec_dot call in the forward_mul_mat function, or make it part of the vec_dot call, to introduce this in a non-intrusive manner (see the sketch below).

If we make these changes, I was wondering if we still need a separate backend for them. We can submit the changes for that in the coming days, making the overall changes significantly less intrusive than they appear.
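As a rough illustration of the dispatch described above (this is not the PR's code; the wrapper is hypothetical, and the kernel signatures are assumed to mirror the ggml_gemv_*_aarch64 prototypes that appear in a later build log in this thread):

/* Assumed prototypes for the PR's entry points (illustrative only). */
void ggml_gemv_q4_0_q8_0(int n, float * s, const void * vx, const void * vy,
                         int nr, int nc, int ith, int nth);
void ggml_gemm_q4_0_q8_0(int n, float * s, const void * vx, const void * vy,
                         int nr, int nc, int ith, int nth);

/* Hypothetical wrapper: prefer the mmla GEMM for batched rows, the dot-product
 * GEMV for a single row, and leave everything else to the existing vec_dot path. */
static int try_aarch64_mul_mat(int n, float * s, const void * vx, const void * vy,
                               int nr, int nc, int ith, int nth) {
#if defined(__ARM_FEATURE_MATMUL_INT8)
    if (nr > 1) {
        ggml_gemm_q4_0_q8_0(n, s, vx, vy, nr, nc, ith, nth);
        return 1;
    }
#endif
#if defined(__ARM_NEON) || defined(__ARM_FEATURE_SVE)
    ggml_gemv_q4_0_q8_0(n, s, vx, vy, nr, nc, ith, nth);
    return 1;
#else
    return 0;   /* caller falls back to the reference vec_dot */
#endif
}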

@slaren
Collaborator

slaren commented Mar 7, 2024

I think that could also work: you could add new data types for the rearranged Q4_0 and Q8_0, add the ability to quantize models to these data types, and add all the relevant functions for these data types in type_traits to support them. That would integrate better into the current ggml framework, and would be less intrusive. @ggerganov what do you think?

@ggerganov
Owner

Yes, rearranged data types should work. This way the logic in ggml_compute_forward_mul_mat could be delegated outside of ggml - it will be based on the data types, which would be determined upon model creation, based on the tensor sizes.

Try to remove the changes in llama.h and ggml.h and re-implement the rearrange_ functions as quantization function for the new types instead

@snadampal
Contributor

I have tested the PR on AWS Graviton3 (r7g.16xl instance). For q4_0 quantized inference, it improved the prompt evaluation performance by around 1.75x for the 64-thread configuration and 2.5x for the single-thread configuration.

The weights reordering/prepacking latency is the only concern I see; once it's addressed with the new quantization format, the changes should be fine.

command:

./main -m ./llama-2-7b.Q4_0.gguf -p "Building a visually appealing website can be done in ten simple steps:" -n 128 -t 64

64 vcpus results:

mainline

llama_print_timings:        load time =     807.92 ms
llama_print_timings:      sample time =      26.81 ms /   128 runs   (    0.21 ms per token,  4774.52 tokens per second)
llama_print_timings: prompt eval time =     106.22 ms /    16 tokens (    6.64 ms per token,   150.64 tokens per second)
llama_print_timings:        eval time =    4025.86 ms /   127 runs   (   31.70 ms per token,    31.55 tokens per second)
llama_print_timings:       total time =    4195.21 ms /   143 tokens

PR:

llama_print_timings:        load time =   40763.67 ms
llama_print_timings:      sample time =      12.95 ms /    64 runs   (    0.20 ms per token,  4942.47 tokens per second)
llama_print_timings: prompt eval time =      60.83 ms /    16 tokens (    3.80 ms per token,   263.05 tokens per second)
llama_print_timings:        eval time =    1970.96 ms /    63 runs   (   31.29 ms per token,    31.96 tokens per second)
llama_print_timings:       total time =    2062.69 ms /    79 tokens

single thread results:

mainline:

llama_print_timings:        load time =    1366.04 ms
llama_print_timings:      sample time =      25.40 ms /   128 runs   (    0.20 ms per token,  5039.97 tokens per second)
llama_print_timings: prompt eval time =    4005.19 ms /    16 tokens (  250.32 ms per token,     3.99 tokens per second)
llama_print_timings:        eval time =   47081.50 ms /   127 runs   (  370.72 ms per token,     2.70 tokens per second)
llama_print_timings:       total time =   51148.90 ms /   143 tokens

PR:

llama_print_timings:        load time =   41334.29 ms
llama_print_timings:      sample time =      26.32 ms /   128 runs   (    0.21 ms per token,  4862.48 tokens per second)
llama_print_timings: prompt eval time =    1557.81 ms /    16 tokens (   97.36 ms per token,    10.27 tokens per second)
llama_print_timings:        eval time =   22168.20 ms /   127 runs   (  174.55 ms per token,     5.73 tokens per second)
llama_print_timings:       total time =   23788.29 ms /   143 tokens

@akingoverlook

akingoverlook commented Mar 9, 2024

Yes, rearranged data types should work. This way the logic in ggml_compute_forward_mul_mat could be delegated outside of ggml - it will be based on the data types, which would be determined upon model creation, based on the tensor sizes.

Try to remove the changes in llama.h and ggml.h and re-implement the rearrange_ functions as quantization function for the new types instead

A generalized capability to re-process the weights on model load would be quite useful to a lot of people, I think. This patch was doing it for a narrow purpose, but considering how cheap/fast your quantization is, why couldn't you have "online quantization" (which could then handle this narrow case too)? After all, HF transformers can do that.

That would make the logistics of evaluating a bunch of different flavors much simpler. No need to produce and post a ton of GGUF files that will mostly ever be used once (per person evaluating them) and just clutter everyone's storage. It would be enough to just keep the fp16 flavor around, which is already being released by some newer original models (e.g., gemma).

If that does not put TheBloke out of business, it would lighten the load at least. Also would significantly reduce the problem of obsolete GGUF files spread everywhere when something changes in the code.

@akingoverlook

FYI, @Dibakar here's this PR (your branch + PR cherry-pick) on one of the latest snapdragons.
Verified that matmul int8 is available and enabled.

03-10 18:33:57.665  2136  2136 F DEBUG   : pid: 2122, tid: 2122, name: main  >>> ./llama-armopt/bin/main <<<
03-10 18:33:57.666  2136  2136 F DEBUG   : uid: 0
03-10 18:33:57.666  2136  2136 F DEBUG   : tagged_addr_ctrl: 0000000000000001 (PR_TAGGED_ADDR_ENABLE)
03-10 18:33:57.666  2136  2136 F DEBUG   : pac_enabled_keys: 000000000000000f (PR_PAC_APIAKEY, PR_PAC_APIBKEY, PR_PAC_APDAKEY, PR_PAC_APDBKEY)
03-10 18:33:57.666  2136  2136 F DEBUG   : signal 11 (SIGSEGV), code 1 (SEGV_MAPERR), fault addr 0x0000000000000000
03-10 18:33:57.666  2136  2136 F DEBUG   : Cause: null pointer dereference
03-10 18:33:57.666  2136  2136 F DEBUG   :     x0  0000007fd188d6c0  x1  0000000000000000  x2  0000000000000040  x3  0000000000000004
03-10 18:33:57.666  2136  2136 F DEBUG   :     x4  0000007fd188d6b0  x5  0000000000000000  x6  b400007b09aac010  x7  0000000000000000
03-10 18:33:57.667  2136  2136 F DEBUG   :     x8  0000000000000000  x9  0000000000000000  x10 0000007fd188d690  x11 0000000000000000
03-10 18:33:57.667  2136  2136 F DEBUG   :     x12 0000007fd188d6b0  x13 0000007fd188d690  x14 0000000000000000  x15 0000000000000004
03-10 18:33:57.667  2136  2136 F DEBUG   :     x16 0000005563d06590  x17 0000000000000048  x18 0000007c9beb6000  x19 0000000000000004
03-10 18:33:57.667  2136  2136 F DEBUG   :     x20 b4000077ae3b1fe0  x21 0000000000000001  x22 0000000000000088  x23 0000007fd188d670
03-10 18:33:57.667  2136  2136 F DEBUG   :     x24 0000000000000080  x25 0000000000004400  x26 0000007fd188d6b0  x27 0000000000004000
03-10 18:33:57.667  2136  2136 F DEBUG   :     x28 0000000000000000  x29 0000007fd188d760
03-10 18:33:57.667  2136  2136 F DEBUG   :     lr  0000005563cb687c  sp  0000007fd188d670  pc  0000005563cb68c8  pst 0000000060001000
03-10 18:33:57.668  2136  2136 F DEBUG   : 10 total frames
03-10 18:33:57.668  2136  2136 F DEBUG   : backtrace:
03-10 18:33:57.668  2136  2136 F DEBUG   :       #00 pc 00000000001208c8  /data/local/tmp/llm/llama-armopt/bin/main (ggml_gemm_q4_0_q8_0+672) (BuildId: 0fd84a10d690aa71299c4aa6154962fab6119584)
03-10 18:33:57.668  2136  2136 F DEBUG   :       #01 pc 00000000000f7dd4  /data/local/tmp/llm/llama-armopt/bin/main (ggml_compute_forward_mul_mat+4380) (BuildId: 0fd84a10d690aa71299c4aa6154962fab6119584)
03-10 18:33:57.668  2136  2136 F DEBUG   :       #02 pc 00000000000e87dc  /data/local/tmp/llm/llama-armopt/bin/main (ggml_graph_compute_thread+860) (BuildId: 0fd84a10d690aa71299c4aa6154962fab6119584)
03-10 18:33:57.668  2136  2136 F DEBUG   :       #03 pc 00000000000e82a4  /data/local/tmp/llm/llama-armopt/bin/main (ggml_graph_compute+228) (BuildId: 0fd84a10d690aa71299c4aa6154962fab6119584)
03-10 18:33:57.668  2136  2136 F DEBUG   :       #04 pc 0000000000110ff8  /data/local/tmp/llm/llama-armopt/bin/main (ggml_backend_cpu_graph_compute+104) (BuildId: 0fd84a10d690aa71299c4aa6154962fab6119584)
03-10 18:33:57.669  2136  2136 F DEBUG   :       #05 pc 000000000011043c  /data/local/tmp/llm/llama-armopt/bin/main (ggml_backend_sched_graph_compute+292) (BuildId: 0fd84a10d690aa71299c4aa6154962fab6119584)
03-10 18:33:57.669  2136  2136 F DEBUG   :       #06 pc 00000000000a2c00  /data/local/tmp/llm/llama-armopt/bin/main (llama_decode_internal(llama_context&, llama_batch)+3920) (BuildId: 0fd84a10d690aa71299c4aa6154962fab6119584)
03-10 18:33:57.669  2136  2136 F DEBUG   :       #07 pc 00000000000a340c  /data/local/tmp/llm/llama-armopt/bin/main (llama_decode+56) (BuildId: 0fd84a10d690aa71299c4aa6154962fab6119584)
03-10 18:33:57.669  2136  2136 F DEBUG   :       #08 pc 000000000005cc10  /data/local/tmp/llm/llama-armopt/bin/main (main+19640) (BuildId: 0fd84a10d690aa71299c4aa6154962fab6119584)


@akingoverlook

FYI, @Dibakar here's this PR (your branch + PR cherry-pick) on one of the latest snapdragons. Verified that matmul int8 is available and enabled.


Sorry, that was a build problem. The matmul int8 got enabled in ggml (where I did check it) but not in llama (where I did not). So the new functions were getting called with NULL for the "rearranged weights" buffer.

While debugging that, I did notice that all the "rearrange" functions just call malloc() unchecked, and are declared void, which would lead to some lovely crashes once you run out of memory. And that will happen, because it doesn't look like the rearranged buffers are ever freed. I ran llama-bench with multiple models, and that did happen.
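For illustration, the pattern being suggested would look roughly like this (the function name and shape are hypothetical, not the PR's actual rearrange code):

#include <stdlib.h>

/* Return a status instead of void, and check the allocation, so the caller can
 * fail gracefully instead of crashing when memory runs out. The buffer should
 * also be freed (or reused) once the tensor is released. */
static int rearrange_weights(void ** out, size_t nbytes) {
    void * buf = malloc(nbytes);
    if (buf == NULL) {
        return 1;
    }
    /* ... interleave the q4_0 blocks into buf here ... */
    *out = buf;
    return 0;
}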

Of course, if this all goes into some offline "new quantization types", it is all moot. So, now to the good part: the performance. It is pretty amazing what this PR does.

First, some baseline numbers for the current master branch:

| model         | size     | params | backend | threads | test   | t/s          |
| ------------- | -------- | ------ | ------- | ------- | ------ | ------------ |
| gemma 2B Q4_0 | 1.31 GiB | 2.51 B | CPU     | 4       | pp 512 | 24.38 ± 1.24 |
| gemma 2B Q4_0 | 1.31 GiB | 2.51 B | CPU     | 4       | tg 128 | 10.11 ± 0.10 |
| gemma 2B Q8_0 | 2.48 GiB | 2.51 B | CPU     | 4       | pp 512 | 25.01 ± 0.69 |
| gemma 2B Q8_0 | 2.48 GiB | 2.51 B | CPU     | 4       | tg 128 | 8.47 ± 0.12  |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | CPU     | 4       | pp 512 | 7.47 ± 0.40  |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | CPU     | 4       | tg 128 | 4.25 ± 0.02  |

And now, Dibakar's branch with this PR, on the same device.

| model         | size     | params | backend | threads | test   | t/s          |
| ------------- | -------- | ------ | ------- | ------- | ------ | ------------ |
| gemma 2B Q4_0 | 1.59 GiB | 3.03 B | CPU     | 4       | pp 512 | 55.32 ± 3.10 |
| gemma 2B Q4_0 | 1.59 GiB | 3.03 B | CPU     | 4       | tg 128 | 12.50 ± 0.01 |
| gemma 2B Q8_0 | 3.00 GiB | 3.03 B | CPU     | 4       | pp 512 | 45.52 ± 3.60 |
| gemma 2B Q8_0 | 3.00 GiB | 3.03 B | CPU     | 4       | tg 128 | 8.98 ± 0.05  |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | CPU     | 4       | pp 512 | 15.35 ± 1.13 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | CPU     | 4       | tg 128 | 6.25 ± 0.00  |

@akingoverlook

For the Arm AArch64 architecture, this PR adds support for optimized GEMV (using dot instructions) and GEMM (using mmla instructions) kernels for the q4_0_q8_0 and q8_0_q8_0 quantization methods.

The feature is enabled if the platform supports __ARM_FEATURE_MATMUL_INT8 (for GEMM) and __ARM_NEON or __ARM_FEATURE_SVE (for GEMV).

On AWS Graviton3 processors, these kernels resulted in a 2.5x improvement in prompt evaluation over the existing GEMM mmla kernels, as well as a 2x improvement in text generation over the default vec_dot kernel (Feb 21 commit 973053d). Please see the table below.

Authors: David Mansell (david.mansell@arm.com) and Dibakar Gope (dibakar.gope@arm.com)

[image: performance comparison table]

Can this be extended to support Q4_1?

@Dibakar
Contributor Author

Dibakar commented Mar 21, 2024

@akingoverlook Yes, it can easily be extended to support Q4_1 (it has the additional fp16 m field).
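For context, the standard ggml block layouts differ only by that extra field; sketched here with plain uint16_t standing in for the fp16 type:

#include <stdint.h>

#define QK4_1 32

typedef struct {
    uint16_t d;                  /* fp16 scale          */
    uint16_t m;                  /* fp16 minimum ("m")  */
    uint8_t  qs[QK4_1 / 2];      /* 32 4-bit quants     */
} block_q4_1;                    /* vs. block_q4_0, which has only d and qs */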

@Confetti-lxy

Hello, this is really fantastic work! I did get faster inference speed when trying this PR with online rearrangement, but at the same time its memory overhead also increased. I think this is the additional memory overhead caused by the rearrangement, so I want to try the offline rearrangement solution. However, I may not be able to modify the code of convert-hf-to-gguf.py myself. Could you provide example code? Thank you!

@Dibakar
Contributor Author

Dibakar commented May 2, 2024

@ggerganov @slaren @cebtenzzre Thank you all for your suggestions. We attempted to address all of these suggestions in the current changes.

We defined a new quantization data type, Q4_0_AARCH64, as suggested by the llama.cpp reviewers, added the ability to quantize models to this data type, and added all the relevant functions for it in type_traits to support it. We removed the changes to the main llama.cpp files (llama.cpp, ggml.c) and re-implemented the rearrange_ functions as quantization functions for the new type.

We added two new files, ggml-aarch64.cpp and ggml-aarch64.h, to llama.cpp to hold our Arm optimized kernels, and made minor changes to the main files (ggml.c, llama.cpp) that interact with our kernels. We have added the copyright notice only to the ggml-aarch64.cpp and ggml-aarch64.h files where we have placed our optimized kernels, and removed it from the remaining main llama.cpp files where we have made minor changes simply to interact with the kernels. Note that Arm is not assigning copyright in these contributions.

We rearrange the weights offline while the model is quantized to the Q4_0_AARCH64 format. As a result, when we run inference later, the weight loading time is identical to loading a Q4_0 model file. Our changes automatically detect the underlying hardware platform (for example, Graviton2/3/4, and so on) and generate the appropriate aarch64-compliant weight format, which is then used in inference. In comparison to Q4_0, using the Q4_0_AARCH64 gguf format does not result in additional memory overhead; we have addressed that in the changeset as well.
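As an illustration of how such run-time detection is typically done on Linux/Android aarch64 (this is an assumption about the mechanism, not the PR's actual code):

#include <sys/auxv.h>
#include <asm/hwcap.h>

/* Pick a kernel variant based on the CPU features the OS reports. */
static const char * pick_aarch64_variant(void) {
    unsigned long hwcap  = getauxval(AT_HWCAP);
    unsigned long hwcap2 = getauxval(AT_HWCAP2);
    if (hwcap2 & HWCAP2_I8MM)   return "mmla (i8mm) GEMM kernels";
    if (hwcap  & HWCAP_SVE)     return "SVE GEMV kernels";
    if (hwcap  & HWCAP_ASIMDDP) return "NEON dot-product GEMV kernels";
    return "generic kernels";
}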

We have optimized our kernels further for different Arm CPU architectures, which resulted in higher performance than we showed last time. Please see below for the updated performance numbers. We also ran the perplexity test to verify that the perplexity remains consistent with the original Q4_0 model, as we expected. We have included provisions for converting both an fp16 HF model to Q4_0_AARCH64 and an existing Q4_0 from HF to Q4_0_AARCH64. Please see below for the required command lines to convert a model to the Q4_0_AARCH64 gguf format before using it for inference.

Single inference:
[image: single-inference performance table]
command lines:
baseline llama.cpp: ./llama-bench -m llama-2-7b.Q4_0.gguf -p 128 -n 128 -t 1,2,4,8,16,32,64 -fa 1
llama.cpp with our patches:
(a) If you want to re-quantize from the huggingface Q4_0 4b weight gguf:
./quantize --allow-requantize llama-2-7b.Q4_0.gguf llama-2-7b-Q4_0_aarch64.gguf Q4_0_AARCH64
If you want to quantize from the huggingface 16b weight gguf:
./quantize llama-2-7B-hf.gguf llama-2-7b-Q4_0_aarch64.gguf Q4_0_AARCH64
(b) ./llama-bench -m llama-2-7b-Q4_0_aarch64.gguf -p 128 -n 128 -t 1,2,4,8,16,32,64 -fa 1

Batched inference:
[image: batched-inference performance table]
command lines:
baseline llama.cpp: ./batched-bench llama-2-7b.Q4_0.gguf 4096 4096 1024 1 0 99 128 128 4,8,16
llama.cpp with our patches: ./batched-bench llama-2-7b-Q4_0_aarch64.gguf 4096 4096 1024 1 0 99 128 128 4,8,16

@Dibakar
Contributor Author

Dibakar commented May 6, 2024

I think that could also work, you could add new data types for the rearranged Q4_0 and Q8_0, add the ability to quantize models to these data types, and add all the relevant functions for these data types in type_traits to support them. That would integrate better in the current ggml framework, and would be less intrusive. @ggerganov what do you think?

Yes, rearranged data types should work. This way the logic in ggml_compute_forward_mul_mat could be delegated outside of ggml - it will be based on the data types, which would be determined upon model creation, based on the tensor sizes.

Try to remove the changes in llama.h and ggml.h and re-implement the rearrange_ functions as quantization function for the new types instead

@ggerganov @slaren @cebtenzzre Hi, we attempted to address your suggestions for this PR in the most recent changes, which we upstreamed last week. Please see our most recent comment from May 1, which describes our changeset. Could you please review the changeset and let us know your feedback? Thanks

@Jeximo
Contributor

Jeximo commented May 6, 2024

I'm failing to build the latest optimizations:

cmake PR build log
cmake -B build -DCMAKE_C_FLAGS=-march=armv8.4-a+dotprod+i8mm && cd build && cmake --build . --config Release --target server --target main --target llama-bench --target quantize && cd bin/

-- The C compiler identification is Clang 18.1.4
-- The CXX compiler identification is Clang 18.1.4
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Check for working C compiler: /data/data/com.termux/files/usr/bin/cc - skipped
-- Detecting C compile features
-- Detecting C compile features - done
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Check for working CXX compiler: /data/data/com.termux/files/usr/bin/c++ - skipped
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Found Git: /data/data/com.termux/files/usr/bin/git (found version "2.45.0")
-- Performing Test CMAKE_HAVE_LIBC_PTHREAD
-- Performing Test CMAKE_HAVE_LIBC_PTHREAD - Failed
-- Check if compiler accepts -pthread
-- Check if compiler accepts -pthread - yes
-- Found Threads: TRUE
-- ccache found, compilation results will be cached. Disable with LLAMA_CCACHE=OFF.
-- CMAKE_SYSTEM_PROCESSOR: aarch64
-- ARM detected
-- Performing Test COMPILER_SUPPORTS_FP16_FORMAT_I3E
-- Performing Test COMPILER_SUPPORTS_FP16_FORMAT_I3E - Failed
-- Configuring done (2.9s)
-- Generating done (0.3s)
-- Build files have been written to: /data/data/com.termux/files/home/llama2/build

[ 5%] Building CXX object common/CMakeFiles/build_info.dir/build-info.cpp.o
[ 11%] Built target build_info
[ 11%] Building C object CMakeFiles/ggml.dir/ggml.c.o
/data/data/com.termux/files/home/llama2/ggml.c:1588:5: warning: implicit conversion increases floating-point precision: 'float32_t' (aka 'float') to 'ggml_float' (aka 'double') [-Wdouble-promotion]
 1588 | GGML_F16_VEC_REDUCE(sumf, sum);
/data/data/com.termux/files/home/llama2/ggml.c:1008:41: note: expanded from macro 'GGML_F16_VEC_REDUCE'
 1008 | #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
/data/data/com.termux/files/home/llama2/ggml.c:998:38: note: expanded from macro 'GGML_F32Cx4_REDUCE'
  998 | #define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE
/data/data/com.termux/files/home/llama2/ggml.c:928:11: note: expanded from macro 'GGML_F32x4_REDUCE'
  928 | res = GGML_F32x4_REDUCE_ONE(x[0]);
/data/data/com.termux/files/home/llama2/ggml.c:913:34: note: expanded from macro 'GGML_F32x4_REDUCE_ONE'
  913 | #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
/data/data/com.termux/files/home/llama2/ggml.c:1636:9: warning: implicit conversion increases floating-point precision: 'float32_t' (aka 'float') to 'ggml_float' (aka 'double') [-Wdouble-promotion]
 1636 | GGML_F16_VEC_REDUCE(sumf[k], sum[k]);
(with the same chain of macro-expansion notes as above)
2 warnings generated.
[ 16%] Building C object CMakeFiles/ggml.dir/ggml-alloc.c.o
[ 22%] Building C object CMakeFiles/ggml.dir/ggml-backend.c.o
[ 22%] Building C object CMakeFiles/ggml.dir/ggml-quants.c.o
/data/data/com.termux/files/home/llama2/ggml-quants.c:3412:46: warning: arithmetic on a pointer to void is a GNU extension [-Wgnu-pointer-arith]
 3412 | const block_q4_0 * restrict vx1 = vx + bx;
/data/data/com.termux/files/home/llama2/ggml-quants.c:3415:46: warning: arithmetic on a pointer to void is a GNU extension [-Wgnu-pointer-arith]
 3415 | const block_q8_0 * restrict vy1 = vy + by;
/data/data/com.termux/files/home/llama2/ggml-quants.c:3779:46: warning: arithmetic on a pointer to void is a GNU extension [-Wgnu-pointer-arith]
 3779 | const block_q4_1 * restrict vx1 = vx + bx;
/data/data/com.termux/files/home/llama2/ggml-quants.c:3781:46: warning: arithmetic on a pointer to void is a GNU extension [-Wgnu-pointer-arith]
 3781 | const block_q8_1 * restrict vy1 = vy + by;
/data/data/com.termux/files/home/llama2/ggml-quants.c:4592:46: warning: arithmetic on a pointer to void is a GNU extension [-Wgnu-pointer-arith]
 4592 | const block_q8_0 * restrict vx1 = vx + bx;
/data/data/com.termux/files/home/llama2/ggml-quants.c:4594:46: warning: arithmetic on a pointer to void is a GNU extension [-Wgnu-pointer-arith]
 4594 | const block_q8_0 * restrict vy1 = vy + by;
6 warnings generated.
[ 27%] Building CXX object CMakeFiles/ggml.dir/sgemm.cpp.o
[ 33%] Building CXX object CMakeFiles/ggml.dir/ggml-aarch64.cpp.o
In file included from /data/data/com.termux/files/home/llama2/ggml-aarch64.cpp:5:
In file included from /data/data/com.termux/files/home/llama2/ggml-quants.h:4:
/data/data/com.termux/files/home/llama2/ggml-common.h:154:9: warning: anonymous structs are a GNU extension [-Wgnu-anonymous-struct]
/data/data/com.termux/files/home/llama2/ggml-common.h:154:9: warning: anonymous types declared in an anonymous union are an extension [-Wnested-anon-types]
(the same pair of warnings repeats for ggml-common.h lines 175, 196, 242, 287 and 314)
/data/data/com.termux/files/home/llama2/ggml-aarch64.cpp:318:45: warning: unused parameter 'n' [-Wunused-parameter]
  318 | void ggml_gemv_q4_0_q8_0_aarch64_sve256(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) {
(the same warning repeats for parameters 's', 'vx', 'vy', 'nr', 'nc', 'ith' and 'nth')
/data/data/com.termux/files/home/llama2/ggml-aarch64.cpp:548:45: warning: unused parameter 'n' [-Wunused-parameter]
  548 | void ggml_gemv_q8_0_q8_0_aarch64_sve256(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) {
(the same warning repeats for parameters 's', 'vx', 'vy', 'nr', 'nc', 'ith' and 'nth')
/data/data/com.termux/files/home/llama2/ggml-aarch64.cpp:669:25: error: always_inline function 'vdotq_s32' requires target feature 'dotprod', but would be inlined into function 'ggml_gemv_q8_0_q8_0_aarch64_neon' that is compiled without support for 'dotprod'
  669 | iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_0_0, lhs_vec_0, 0);
/data/data/com.termux/files/usr/lib/clang/18/include/arm_neon.h:62995:15: note: expanded from macro 'vdotq_laneq_s32'
62995 | __ret_749 = vdotq_s32(__s0_749, __s1_749, *(int8x16_t *) &__reint1_749); \
(the same error repeats for the vdotq_laneq_s32 calls at ggml-aarch64.cpp lines 670, 672, 673, 675, 676, 678, 679, 681, 682 and 684)
62995 | __ret_749 = vdotq_s32(__s0_749, __s1_749, *(int8x16_t *) &__reint1_749); \ | ^
/data/data/com.termux/files/home/llama2/ggml-aarch64.cpp:685:25: error: always_inline function 'vdotq_s32' requires target feature 'dotprod', but would be inlined into function 'ggml_gemv_q8_0_q8_0_aarch64_neon' that is compiled without support for 'dotprod' 685 | iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_2_1, lhs_vec_1, 2);
| ^ /data/data/com.termux/files/usr/lib/clang/18/include/arm_neon.h:62995:15: note: expanded from macro 'vdotq_laneq_s32' 62995 | __ret_749 = vdotq_s32(__s0_749, __s1_749, *(int8x16_t *) &__reint1_749); \ | ^
/data/data/com.termux/files/home/llama2/ggml-aarch64.cpp:687:25: error: always_inline function 'vdotq_s32' requires target feature 'dotprod', but would be inlined into function 'ggml_gemv_q8_0_q8_0_aarch64_neon' that is compiled without support for 'dotprod' 687 | iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_3_0, lhs_vec_0, 3); | ^
/data/data/com.termux/files/usr/lib/clang/18/include/arm_neon.h:62995:15: note: expanded from macro 'vdotq_laneq_s32' 62995 | __ret_749 = vdotq_s32(__s0_749, __s1_749, *(int8x16_t *) &__reint1_749);
| ^ /data/data/com.termux/files/home/llama2/ggml-aarch64.cpp:688:25: error: always_inline function 'vdotq_s32' requires target feature 'dotprod', but would be inlined into function 'ggml_gemv_q8_0_q8_0_aarch64_neon' that is compiled without support for 'dotprod'
688 | iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_3_1, lhs_vec_1, 3); | ^
/data/data/com.termux/files/usr/lib/clang/18/include/arm_neon.h:62995:15: note: expanded from macro 'vdotq_laneq_s32' 62995 | __ret_749 = vdotq_s32(__s0_749, __s1_749, *(int8x16_t *) &__reint1_749); \ | ^ /data/data/com.termux/files/home/llama2/ggml-aarch64.cpp:690:25: error: always_inline function 'vdotq_s32' requires target feature 'dotprod', but would be inlined into function 'ggml_gemv_q8_0_q8_0_aarch64_neon' that is compiled without support for 'dotprod' 690 | iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_3_0, lhs_vec_0, 3);
| ^ /data/data/com.termux/files/usr/lib/clang/18/include/arm_neon.h:62995:15: note: expanded from macro 'vdotq_laneq_s32' 62995 | __ret_749 = vdotq_s32(__s0_749, __s1_749, *(int8x16_t *) &__reint1_749);
| ^ /data/data/com.termux/files/home/llama2/ggml-aarch64.cpp:691:25: error: always_inline function 'vdotq_s32' requires target feature 'dotprod', but would be inlined into function 'ggml_gemv_q8_0_q8_0_aarch64_neon' that is compiled without support for 'dotprod' 691 | iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_3_1, lhs_vec_1, 3); | ^ /data/data/com.termux/files/usr/lib/clang/18/include/arm_neon.h:62995:15: note: expanded from macro 'vdotq_laneq_s32' 62995 | __ret_749 = vdotq_s32(__s0_749, __s1_749, *(int8x16_t *) &__reint1_749);
| ^ /data/data/com.termux/files/home/llama2/ggml-aarch64.cpp:704:45: warning: unused parameter 'n' [-Wunused-parameter]
704 | void ggml_gemm_q4_0_q8_0_aarch64_sve256(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) { | ^ /data/data/com.termux/files/home/llama2/ggml-aarch64.cpp:704:70: warning: unused parameter 's' [-Wunused-parameter] 704 | void ggml_gemm_q4_0_q8_0_aarch64_sve256(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) { | ^ /data/data/com.termux/files/home/llama2/ggml-aarch64.cpp:704:100: warning: unused parameter 'vx' [-Wunused-parameter] 704 | void ggml_gemm_q4_0_q8_0_aarch64_sve256(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) { | ^
/data/data/com.termux/files/home/llama2/ggml-aarch64.cpp:704:131: warning: unused parameter 'vy' [-Wunused-parameter] 704 | void ggml_gemm_q4_0_q8_0_aarch64_sve256(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) {
| ^
/data/data/com.termux/files/home/llama2/ggml-aarch64.cpp:704:139: warning: unused parameter 'nr' [-Wunused-parameter]
704 | void ggml_gemm_q4_0_q8_0_aarch64_sve256(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) { | ^ /data/data/com.termux/files/home/llama2/ggml-aarch64.cpp:704:147: warning: unused parameter 'nc' [-Wunused-parameter] 704 | void ggml_gemm_q4_0_q8_0_aarch64_sve256(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) { | ^
/data/data/com.termux/files/home/llama2/ggml-aarch64.cpp:704:155: warning: unused parameter 'ith' [-Wunused-parameter] 704 | void ggml_gemm_q4_0_q8_0_aarch64_sve256(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) { | ^ /data/data/com.termux/files/home/llama2/ggml-aarch64.cpp:704:164: warning: unused parameter 'nth' [-Wunused-parameter]
704 | void ggml_gemm_q4_0_q8_0_aarch64_sve256(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) { | ^ /data/data/com.termux/files/home/llama2/ggml-aarch64.cpp:1130:43: warning: unused parameter 'n' [-Wunused-parameter] 1130 | void ggml_gemm_q4_0_q8_0_aarch64_neon(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) {
| ^ /data/data/com.termux/files/home/llama2/ggml-aarch64.cpp:1130:68: warning: unused parameter 's' [-Wunused-parameter] 1130 | void ggml_gemm_q4_0_q8_0_aarch64_neon(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) {
| ^ /data/data/com.termux/files/home/llama2/ggml-aarch64.cpp:1130:98: warning: unused parameter 'vx' [-Wunused-parameter] 1130 | void ggml_gemm_q4_0_q8_0_aarch64_neon(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) { | ^ /data/data/com.termux/files/home/llama2/ggml-aarch64.cpp:1130:129: warning: unused parameter 'vy' [-Wunused-parameter]
1130 | void ggml_gemm_q4_0_q8_0_aarch64_neon(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) {
| ^
/data/data/com.termux/files/home/llama2/ggml-aarch64.cpp:1130:137: warning: unused parameter 'nr' [-Wunused-parameter] 1130 | void ggml_gemm_q4_0_q8_0_aarch64_neon(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) { | ^ /data/data/com.termux/files/home/llama2/ggml-aarch64.cpp:1130:145: warning: unused parameter 'nc' [-Wunused-parameter] 1130 | void ggml_gemm_q4_0_q8_0_aarch64_neon(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) { | ^
/data/data/com.termux/files/home/llama2/ggml-aarch64.cpp:1130:153: warning: unused parameter 'ith' [-Wunused-parameter]
1130 | void ggml_gemm_q4_0_q8_0_aarch64_neon(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) { | ^ /data/data/com.termux/files/home/llama2/ggml-aarch64.cpp:1130:162: warning: unused parameter 'nth' [-Wunused-parameter] 1130 | void ggml_gemm_q4_0_q8_0_aarch64_neon(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) {
| ^ /data/data/com.termux/files/home/llama2/ggml-aarch64.cpp:2010:38: warning: unused parameter 'n' [-Wunused-parameter] 2010 | void ggml_gemm_q8_0_q8_0_aarch64(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) { | ^
/data/data/com.termux/files/home/llama2/ggml-aarch64.cpp:2010:63: warning: unused parameter 's' [-Wunused-parameter]
2010 | void ggml_gemm_q8_0_q8_0_aarch64(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) { | ^
/data/data/com.termux/files/home/llama2/ggml-aarch64.cpp:2010:93: warning: unused parameter 'vx' [-Wunused-parameter] 2010 | void ggml_gemm_q8_0_q8_0_aarch64(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) { | ^ /data/data/com.termux/files/home/llama2/ggml-aarch64.cpp:2010:124: warning: unused parameter 'vy' [-Wunused-parameter] 2010 | void ggml_gemm_q8_0_q8_0_aarch64(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) { | ^ /data/data/com.termux/files/home/llama2/ggml-aarch64.cpp:2010:132: warning: unused parameter 'nr' [-Wunused-parameter]
2010 | void ggml_gemm_q8_0_q8_0_aarch64(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) {
| ^
/data/data/com.termux/files/home/llama2/ggml-aarch64.cpp:2010:140: warning: unused parameter 'nc' [-Wunused-parameter] 2010 | void ggml_gemm_q8_0_q8_0_aarch64(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) { | ^ /data/data/com.termux/files/home/llama2/ggml-aarch64.cpp:2010:148: warning: unused parameter 'ith' [-Wunused-parameter] 2010 | void ggml_gemm_q8_0_q8_0_aarch64(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) { | ^
/data/data/com.termux/files/home/llama2/ggml-aarch64.cpp:2010:157: warning: unused parameter 'nth' [-Wunused-parameter]
2010 | void ggml_gemm_q8_0_q8_0_aarch64(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) { | ^
52 warnings and 16 errors generated. make[3]: *** [CMakeFiles/ggml.dir/build.make:146: CMakeFiles/ggml.dir/ggml-aarch64.cpp.o]
Error 1 make[2]: *** [CMakeFiles/Makefile2:820: CMakeFiles/ggml.dir/all]
Error 2 make[1]: *** [CMakeFiles/Makefile2:3187: examples/server/CMakeFiles/server.dir/rule] Error 2
make: *** [Makefile:1284: server] Error 2

Maybe my device is now incompatible..? Same instruction works on master:

master cmake
cmake -B build -DCMAKE_C_FLAGS=-march=armv8.4-a+dotprod+i8mm && cd build && cmake --build . --config Release --target server --target main --target llama-bench --target quantize && cd bin/

-- The C compiler identification is Clang 18.1.4
-- The CXX compiler identification is Clang 18.1.4
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Check for working C compiler: /data/data/com.termux/files/usr/bin/cc - skipped
-- Detecting C compile features
-- Detecting C compile features - done
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Check for working CXX compiler: /data/data/com.termux/files/usr/bin/c++ - skipped
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Found Git: /data/data/com.termux/files/usr/bin/git (found version "2.45.0")
-- Performing Test CMAKE_HAVE_LIBC_PTHREAD
-- Performing Test CMAKE_HAVE_LIBC_PTHREAD - Failed
-- Check if compiler accepts -pthread
-- Check if compiler accepts -pthread - yes
-- Found Threads: TRUE
-- ccache found, compilation results will be cached. Disable with LLAMA_CCACHE=OFF.
-- CMAKE_SYSTEM_PROCESSOR: aarch64
-- ARM detected
-- Performing Test COMPILER_SUPPORTS_FP16_FORMAT_I3E
-- Performing Test COMPILER_SUPPORTS_FP16_FORMAT_I3E - Failed
-- Configuring done (3.2s)
-- Generating done (0.3s)
-- Build files have been written to: /data/data/com.termux/files/home/llama3/build

[ 6%] Generating build details from Git
-- Found Git: /data/data/com.termux/files/usr/bin/git (found version "2.45.0")
[ 12%] Building CXX object common/CMakeFiles/build_info.dir/build-info.cpp.o
[ 12%] Built target build_info
[ 12%] Building C object CMakeFiles/ggml.dir/ggml.c.o
/data/data/com.termux/files/home/llama3/ggml.c:1564:5: warning: implicit conversion increases floating-point precision: 'float32_t' (aka 'float') to 'ggml_float' (aka 'double') [-Wdouble-promotion]
 1564 |     GGML_F16_VEC_REDUCE(sumf, sum);
/data/data/com.termux/files/home/llama3/ggml.c:1612:9: warning: implicit conversion increases floating-point precision: 'float32_t' (aka 'float') to 'ggml_float' (aka 'double') [-Wdouble-promotion]
 1612 |         GGML_F16_VEC_REDUCE(sumf[k], sum[k]);
2 warnings generated.
[ 18%] Building C object CMakeFiles/ggml.dir/ggml-alloc.c.o
[ 25%] Building C object CMakeFiles/ggml.dir/ggml-backend.c.o
[ 25%] Building C object CMakeFiles/ggml.dir/ggml-quants.c.o
/data/data/com.termux/files/home/llama3/ggml-quants.c:3412:46: warning: arithmetic on a pointer to void is a GNU extension [-Wgnu-pointer-arith]
 3412 |     const block_q4_0 * restrict vx1 = vx + bx;
/data/data/com.termux/files/home/llama3/ggml-quants.c:3415:46: warning: arithmetic on a pointer to void is a GNU extension [-Wgnu-pointer-arith]
 3415 |     const block_q8_0 * restrict vy1 = vy + by;
/data/data/com.termux/files/home/llama3/ggml-quants.c:3779:46: warning: arithmetic on a pointer to void is a GNU extension [-Wgnu-pointer-arith]
 3779 |     const block_q4_1 * restrict vx1 = vx + bx;
/data/data/com.termux/files/home/llama3/ggml-quants.c:3781:46: warning: arithmetic on a pointer to void is a GNU extension [-Wgnu-pointer-arith]
 3781 |     const block_q8_1 * restrict vy1 = vy + by;
/data/data/com.termux/files/home/llama3/ggml-quants.c:4592:46: warning: arithmetic on a pointer to void is a GNU extension [-Wgnu-pointer-arith]
 4592 |     const block_q8_0 * restrict vx1 = vx + bx;
/data/data/com.termux/files/home/llama3/ggml-quants.c:4594:46: warning: arithmetic on a pointer to void is a GNU extension [-Wgnu-pointer-arith]
 4594 |     const block_q8_0 * restrict vy1 = vy + by;
6 warnings generated.
[ 31%] Building CXX object CMakeFiles/ggml.dir/sgemm.cpp.o
[ 31%] Built target ggml
[ 31%] Building CXX object CMakeFiles/llama.dir/llama.cpp.o
[ 37%] Building CXX object CMakeFiles/llama.dir/unicode.cpp.o
[ 43%] Building CXX object CMakeFiles/llama.dir/unicode-data.cpp.o
[ 43%] Linking CXX static library libllama.a
[ 43%] Built target llama
[ 43%] Building CXX object common/CMakeFiles/common.dir/common.cpp.o
[ 50%] Building CXX object common/CMakeFiles/common.dir/sampling.cpp.o
[ 56%] Building CXX object common/CMakeFiles/common.dir/console.cpp.o
[ 56%] Building CXX object common/CMakeFiles/common.dir/grammar-parser.cpp.o
[ 62%] Building CXX object common/CMakeFiles/common.dir/json-schema-to-grammar.cpp.o
[ 68%] Building CXX object common/CMakeFiles/common.dir/train.cpp.o
[ 68%] Building CXX object common/CMakeFiles/common.dir/ngram-cache.cpp.o
[ 75%] Linking CXX static library libcommon.a
[ 75%] Built target common
[ 81%] Generating json-schema-to-grammar.mjs.hpp
[ 87%] Generating completion.js.hpp
[ 93%] Generating index.html.hpp
[ 93%] Generating index.js.hpp
[ 93%] Building CXX object examples/server/CMakeFiles/server.dir/server.cpp.o
[100%] Linking CXX executable ../../bin/server
[100%] Built target server
[ 15%] Built target build_info
[ 38%] Built target ggml
[ 53%] Built target llama
[ 92%] Built target common
[ 92%] Building CXX object examples/main/CMakeFiles/main.dir/main.cpp.o
[100%] Linking CXX executable ../../bin/main
[100%] Built target main
[ 14%] Built target build_info
[ 35%] Built target ggml
[ 50%] Built target llama
[ 85%] Built target common
[ 92%] Building CXX object examples/llama-bench/CMakeFiles/llama-bench.dir/llama-bench.cpp.o
[100%] Linking CXX executable ../../bin/llama-bench
[100%] Built target llama-bench
[ 15%] Built target build_info
[ 38%] Built target ggml
[ 53%] Built target llama
[ 92%] Built target common
[ 92%] Building CXX object examples/quantize/CMakeFiles/quantize.dir/quantize.cpp.o
[100%] Linking CXX executable ../../bin/quantize
[100%] Built target quantize

@akingoverlook
Copy link

I'm failing to build the latest optimizations:
cmake PR build log
cmake -B build -DCMAKE_C_FLAGS=-march=armv8.4-a+dotprod+i8mm && cd build && cmake --build . --config Release --target server --target main --target llama-bench --target quantize && cd bin/

Maybe my device is now incompatible..? Same instruction works on master:
master cmake
cmake -B build -DCMAKE_C_FLAGS=-march=armv8.4-a+dotprod+i8mm && cd build && cmake --build . --config Release --target server --target main --target llama-bench --target quantize && cd bin/

C_FLAGS aren't passed to the C++ compiler - you need to set the equivalent CXX_FLAGS to enable dotprod+i8mm there as well.
I had been burned by this in a more subtle way with the original version of these changes, where things would compile fine only to segfault at runtime ;)
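
For example, a cmake invocation along these lines (an illustrative sketch - any equivalent way of passing the same -march flags to both the C and C++ compilers should work) avoids the problem:

cmake -B build -DCMAKE_C_FLAGS="-march=armv8.4-a+dotprod+i8mm" -DCMAKE_CXX_FLAGS="-march=armv8.4-a+dotprod+i8mm"
cmake --build build --config Release --target server --target main --target llama-bench --target quantize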

ggml-aarch64.cpp Outdated
Copy link

@akingoverlook akingoverlook May 7, 2024

The performance improvements are lovely, but the logistical aspect of this can be improved.

It should be easy enough to detect the availability of SVE/NEON/I8MM at runtime - no need to tell people working for Arm how. That logic should be placed into ggml_cpu_has_neon() and ggml_cpu_has_matmul_int8(), which right now just depend on the same __ARM_NEON/__ARM_FEATURE_MATMUL_INT8 flags that are set at compile time.

Since most people will be cross-compiling for Arm, at least half won't set the correct flags and will never get to see the benefits. Those who do set them will face the problem of targeting multiple device types (as people targeting Arm often have to) and needing to produce, package, and deploy multiple matching libraries/binaries.

Removing that headache should be worth the trouble of doing the dynamic detection.

Also, the new __AARCH64 quantization types would depend on which compile flags were set, so the quantized models would not be interchangeable between Arm targets built with different sets of flags. I would need a different quantized model for each of my Arm devices, yet they would all report the same type. That is unmanageable in any practical sense. The quantization type needs to convey compatibility with the inference hardware, so you would need one per ISA combination.
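
For illustration, a minimal sketch of what such runtime detection could look like on Linux/Android AArch64 (the helper names here are hypothetical and this is not the current ggml implementation, which still keys off the compile-time macros):

#if defined(__aarch64__) && defined(__linux__)
#include <stdbool.h>
#include <sys/auxv.h>    // getauxval, AT_HWCAP, AT_HWCAP2
#include <asm/hwcap.h>   // HWCAP_ASIMDDP, HWCAP_SVE, HWCAP2_I8MM (needs reasonably recent kernel headers)

// Hypothetical helpers: ask the kernel which features the CPU actually has,
// instead of relying on what the binary happened to be compiled with.
static bool cpu_has_dotprod(void) { return (getauxval(AT_HWCAP)  & HWCAP_ASIMDDP) != 0; }
static bool cpu_has_sve(void)     { return (getauxval(AT_HWCAP)  & HWCAP_SVE)     != 0; }
static bool cpu_has_i8mm(void)    { return (getauxval(AT_HWCAP2) & HWCAP2_I8MM)   != 0; }
#endif

ggml_cpu_has_neon() / ggml_cpu_has_matmul_int8() could then dispatch between the optimized and reference kernels based on these checks rather than on compile-time flags.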

Copy link

@akingoverlook akingoverlook left a comment

Change the ISA detection method to runtime detection.

@ggerganov
Copy link
Owner

The 32-bit Armv7 build is failing:

https://github.com/ggerganov/llama.cpp/actions/runs/9857479218/job/27217025373?pr=5780#step:5:147

Probably have to check #if defined(__ARM_NEON) && defined(__aarch64__) in ggml-aarch64.c instead of just #if defined(__ARM_NEON)

Also, take a look at this comment: #5780 (comment)

@Dibakar
Copy link
Contributor Author

Dibakar commented Jul 9, 2024

The 32-bit Armv7 build is failing:

https://github.com/ggerganov/llama.cpp/actions/runs/9857479218/job/27217025373?pr=5780#step:5:147

Probably have to check #if defined(__ARM_NEON) && defined(__aarch64__) in ggml-aarch64.c instead of just #if defined(__ARM_NEON)

Also, take a look at this comment: #5780 (comment)

We added defined(__aarch64__) to the defined(__ARM_NEON) check to guard the 64-bit Neon kernels in the latest commit.
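
The resulting guard is roughly of this shape (a sketch of the pattern, not the literal file contents):

#if defined(__ARM_NEON) && defined(__aarch64__)
    // 64-bit NEON / dotprod / i8mm GEMV and GEMM kernels
#else
    // 32-bit Arm (and non-Arm) builds fall back to the reference scalar kernels
#endif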

@github-actions github-actions bot added the documentation label Jul 9, 2024
@ggerganov ggerganov merged commit 0f1a39f into ggerganov:master Jul 10, 2024
53 checks passed
@ggerganov
Copy link
Owner

@Dibakar Thank you for the contribution - it's now merged. It would be great to add AArch64-optimized kernels for Q8 models in the future - contributions are welcome

@Dibakar
Copy link
Contributor Author

Dibakar commented Jul 10, 2024

@ggerganov Thank you so much for reviewing our changes and merging our PR into the mainline. Yes, we will discuss contributing Q8 code internally.

arthw pushed a commit to arthw/llama.cpp that referenced this pull request Jul 13, 2024
* Arm AArch64: optimized GEMV and GEMM kernels for q4_0_q8_0, and q8_0_q8_0 quantization

* Arm AArch64: add optimized GEMV and GEMM asm kernels for q4_0_q8_0 quantization and refactor code to address llama.cpp pr#5780 suggestions

* Arm AArch64: add optimized GEMV and GEMM asm kernels for q4_0_q8_0 quantization and refactor code to address llama.cpp pr#5780 suggestions

* Arm AArch64: add optimized GEMV and GEMM asm kernels for q4_0_q8_0 quantization and refactor code to address llama.cpp pr#5780 suggestions

* Arm AArch64: add optimized GEMV and GEMM asm kernels for q4_0_q8_0 quantization and refactor code to address llama.cpp pr#5780 suggestions

* Arm AArch64: add copyright claim only to ggml-aarch64.cpp and ggml-aarch64.h files

* Arm AArch64: minor code refactoring for rebase

* Arm AArch64: minor code refactoring for resolving a build issue with cmake

* Arm AArch64: minor code refactoring to split the Q4_0_AARCH64 type into three separate types: Q4_0_4_4, Q4_0_4_8, and Q4_0_8_8

* Arm AArch64: minor code change for resolving a build issue with server-windows

* retrigger checks

* Arm AArch64: minor code changes for rebase

* Arm AArch64: minor changes to skip the pr#7433 vec_dot code for arm cpus with SVE VL not equal to 256 bits

* Arm AArch64: remove stale LLAMA_QKK_64 from CMakeLists.txt and delete build.zig

* Arm AArch64: add reference scalar gemm and gemv, and avoid dynamic memory allocations during quantization for Q4_0_4_4, Q4_0_4_8, and Q4_0_8_8

* Arm AArch64: add multithreaded quantization support for the new types: Q4_0_4_4, Q4_0_4_8, and Q4_0_8_8

* Arm AArch64: minor code refactoring

* Arm AArch64: simplify logic for calling gemm and gemv functions in ggml_compute_forward_mul_mat

* Arm AArch64: minimize changes in ggml_compute_forward_mul_mat

* Arm AArch64: minor code refactoring, and add reference scalar code to quantize routines for new quant types

* Arm AArch64: minor code refactoring

* Arm AArch64: minor code refactoring

* Arm AArch64: minor code refactoring

* rebase on the latest master commit 3fd62a6 and adapt to the new directory structure

* Arm AArch64: remove a redundant comment

* Arm AArch64: add pragma in ggml-aarch64.c to turn -Woverlength-strings warning off

* Arm AArch64: use __aarch64__ check to guard 64-bit neon kernels

* Arm AArch64: update docs/build.md README to include compile time flags for building the Q4_0_4_4 quant type
@AndreasKunar
Copy link
Contributor

AndreasKunar commented Jul 14, 2024

FYI, I got the Q4_0_4_4 variant to work for me on M2 Macs, on Snapdragon X / Windows 11 24H2 / clang, and on Snapdragon X / Windows 11 24H2 / WSL2-Ubuntu 24.04 / gcc.

On the Mac I had to build with -march=armv8.5-a and -D LLAMA_NO_ACCELERATE=1; otherwise it tends to segfault.

The speed-improvement results are very impressive: nearly a 2.5x speedup for PP! TG is largely memory-bandwidth-bound rather than compute-bound, so the improvements for TG are minimal. The Snapdragon X should in theory have 33% higher memory bandwidth than the M2, but that does not show in the TG results.

M2 MacBook Air (llama.cpp version: 3387 (fa79495) built with Apple clang version 15.0.0 (clang-1500.3.9.4) for arm64-apple-darwin23.5.0):

model size params backend threads test t/s
llama 7B Q4_0 3.56 GiB 6.74 B CPU 4 pp512 29.39 ± 0.65
llama 7B Q4_0 3.56 GiB 6.74 B CPU 4 tg128 14.29 ± 0.13
llama 7B Q4_0_4_4 3.56 GiB 6.74 B CPU 4 pp512 64.23 ± 1.24
llama 7B Q4_0_4_4 3.56 GiB 6.74 B CPU 4 tg128 16.17 ± 0.04

Snapdragon X Plus / Windows 11 24H2 / clang (lama.cpp version: 3388 (e236528) built with (clang) for aarch64-pc-windows-msvc):

model size params backend threads test t/s
llama 7B Q4_0 3.56 GiB 6.74 B CPU 10 pp512 47.35 ± 1.37
llama 7B Q4_0 3.56 GiB 6.74 B CPU 10 tg128 13.26 ± 0.74
llama 7B Q4_0_4_4 3.56 GiB 6.74 B CPU 10 pp512 116.73 ± 16.45
llama 7B Q4_0_4_4 3.56 GiB 6.74 B CPU 10 tg128 18.08 ± 3.87

Snapdragon X Plus / Windows 11 24H2 / WSL2-Ubuntu24.04 / gcc (llama.cpp version: 3388 (e236528) built with cc (Ubuntu 13.2.0-23ubuntu4) 13.2.0 for aarch64-linux-gnu):

model size params backend threads test t/s
llama 7B Q4_0 3.56 GiB 6.74 B CPU 10 pp512 47.00 ± 6.86
llama 7B Q4_0 3.56 GiB 6.74 B CPU 10 tg128 13.85 ± 2.74
llama 7B Q4_0_4_4 3.56 GiB 6.74 B CPU 10 pp512 111.52 ± 6.93
llama 7B Q4_0_4_4 3.56 GiB 6.74 B CPU 10 tg128 17.67 ± 6.07

P.S.: the new AArch64-optimized GEMV and GEMM kernels for Q4_0_4_4 in ggml-aarch64.c do not work with MSVC (MSVC has no inline __asm__ on any 64-bit architecture) - please use clang for Windows on Arm.
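
For reference, requantizing an existing model to the new type and benchmarking it looks roughly like this (the model paths are placeholders; quantize and llama-bench are the build targets mentioned above):

./quantize ./models/llama-7b-f16.gguf ./models/llama-7b-Q4_0_4_4.gguf Q4_0_4_4
./llama-bench -m ./models/llama-7b-Q4_0_4_4.gguf -t 4 -p 512 -n 128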

@AndreasKunar
Copy link
Contributor

AndreasKunar commented Jul 14, 2024

FYI, another potential use-case for the Q4_0_4_4 optimizations is ARM virtual machines, where there is no GPU-virtualization - e.g. with Apple silicon VMs and containers.

Virtualization on macOS seems much slower than on Windows, where WSL2 adds almost no CPU overhead. I was able to get good acceleration with Q4_0_4_4 in M2 VMs (although starting from a poor baseline due to the virtualization penalty):

virtual 4-cpu Windows 11 24H2 in Parallels (clang compiler):

model size params backend threads test t/s
llama 7B Q4_0 3.56 GiB 6.74 B CPU 4 pp512 24.04 ± 1.43
llama 7B Q4_0 3.56 GiB 6.74 B CPU 4 tg128 11.66 ± 0.84
llama 7B Q4_0_4_4 3.56 GiB 6.74 B CPU 4 pp512 49.06 ± 1.50
llama 7B Q4_0_4_4 3.56 GiB 6.74 B CPU 4 tg128 12.67 ± 2.03

virtual 4-cpu Ubuntu24.04 in Parallels:

model size params backend threads test t/s
llama 7B Q4_0 3.56 GiB 6.74 B CPU 4 pp512 21.45 ± 0.39
llama 7B Q4_0 3.56 GiB 6.74 B CPU 4 tg128 12.17 ± 0.15
llama 7B Q4_0_4_4 3.56 GiB 6.74 B CPU 4 pp512 50.63 ± 1.46
llama 7B Q4_0_4_4 3.56 GiB 6.74 B CPU 4 tg128 14.01 ± 1.17

Note:

  • 4 CPUs is the optimum on M2; 5+ is slower because the e-cores get mixed in
  • this acceleration is larger than the improvement you get from virtualizing the GPU into the VM/container via Vulkan/krunkit

tc-wolf added a commit to tc-wolf/llama-cpp-python that referenced this pull request Sep 10, 2024
- scikit-build -> scikit-build-core
- Remove -DGGML_BLAS=ON, OpenBLAS cmake tags from build
  - Needed (see ggerganov/llama.cpp#5780 (review)) to get this to work properly
- Built, deployed, and tested with llama3.1 with q4_0_4_4 quantization
Labels
build, documentation, enhancement, examples, ggml, Review Complexity : Medium