diff --git a/.github/workflows/package_wheel_release.yml b/.github/workflows/package_wheel_release.yml
index 93e5f38..f04ee07 100644
--- a/.github/workflows/package_wheel_release.yml
+++ b/.github/workflows/package_wheel_release.yml
@@ -29,11 +29,6 @@ jobs:
- { os: ubuntu-20.04, pyver: '3.12', cuda: '12.1.1', torch: '2.4.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX512', torch_cu: '121'}
- { os: ubuntu-20.04, pyver: '3.12', cuda: '12.1.1', torch: '2.4.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX2', torch_cu: '121'}
- { os: ubuntu-20.04, pyver: '3.12', cuda: '12.5.1', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'FANCY', torch_cu: '124'}
- - { os: ubuntu-20.04, pyver: '3.12', cuda: '12.5.1', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX512', torch_cu: '124'}
- - { os: ubuntu-20.04, pyver: '3.12', cuda: '12.5.1', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX2', torch_cu: '124'}
- - { os: ubuntu-20.04, pyver: '3.12', cuda: '12.4.1', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'FANCY', torch_cu: '124'}
- - { os: ubuntu-20.04, pyver: '3.12', cuda: '12.4.1', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX512', torch_cu: '124'}
- - { os: ubuntu-20.04, pyver: '3.12', cuda: '12.4.1', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX2', torch_cu: '124'}
- { os: ubuntu-20.04, pyver: '3.12', cuda: '12.2.2', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'FANCY', torch_cu: '121'}
- { os: ubuntu-20.04, pyver: '3.12', cuda: '12.2.2', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX512', torch_cu: '121'}
- { os: ubuntu-20.04, pyver: '3.12', cuda: '12.2.2', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX2', torch_cu: '121'}
@@ -52,12 +47,6 @@ jobs:
- { os: ubuntu-20.04, pyver: '3.11', cuda: '12.1.1', torch: '2.4.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'FANCY', torch_cu: '121'}
- { os: ubuntu-20.04, pyver: '3.11', cuda: '12.1.1', torch: '2.4.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX512', torch_cu: '121'}
- { os: ubuntu-20.04, pyver: '3.11', cuda: '12.1.1', torch: '2.4.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX2', torch_cu: '121'}
- - { os: ubuntu-20.04, pyver: '3.11', cuda: '12.5.1', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'FANCY', torch_cu: '124'}
- - { os: ubuntu-20.04, pyver: '3.11', cuda: '12.5.1', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX512', torch_cu: '124'}
- - { os: ubuntu-20.04, pyver: '3.11', cuda: '12.5.1', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX2', torch_cu: '124'}
- - { os: ubuntu-20.04, pyver: '3.11', cuda: '12.4.1', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'FANCY', torch_cu: '124'}
- - { os: ubuntu-20.04, pyver: '3.11', cuda: '12.4.1', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX512', torch_cu: '124'}
- - { os: ubuntu-20.04, pyver: '3.11', cuda: '12.4.1', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX2', torch_cu: '124'}
- { os: ubuntu-20.04, pyver: '3.11', cuda: '12.2.2', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'FANCY', torch_cu: '121'}
- { os: ubuntu-20.04, pyver: '3.11', cuda: '12.2.2', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX512', torch_cu: '121'}
- { os: ubuntu-20.04, pyver: '3.11', cuda: '12.2.2', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX2', torch_cu: '121'}
@@ -76,12 +65,6 @@ jobs:
- { os: ubuntu-20.04, pyver: '3.10', cuda: '12.1.1', torch: '2.4.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'FANCY', torch_cu: '121'}
- { os: ubuntu-20.04, pyver: '3.10', cuda: '12.1.1', torch: '2.4.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX512', torch_cu: '121'}
- { os: ubuntu-20.04, pyver: '3.10', cuda: '12.1.1', torch: '2.4.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX2', torch_cu: '121'}
- - { os: ubuntu-20.04, pyver: '3.10', cuda: '12.5.1', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'FANCY', torch_cu: '124'}
- - { os: ubuntu-20.04, pyver: '3.10', cuda: '12.5.1', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX512', torch_cu: '124'}
- - { os: ubuntu-20.04, pyver: '3.10', cuda: '12.5.1', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX2', torch_cu: '124'}
- - { os: ubuntu-20.04, pyver: '3.10', cuda: '12.4.1', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'FANCY', torch_cu: '124'}
- - { os: ubuntu-20.04, pyver: '3.10', cuda: '12.4.1', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX512', torch_cu: '124'}
- - { os: ubuntu-20.04, pyver: '3.10', cuda: '12.4.1', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX2', torch_cu: '124'}
- { os: ubuntu-20.04, pyver: '3.10', cuda: '12.2.2', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'FANCY', torch_cu: '121'}
- { os: ubuntu-20.04, pyver: '3.10', cuda: '12.2.2', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX512', torch_cu: '121'}
- { os: ubuntu-20.04, pyver: '3.10', cuda: '12.2.2', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX2', torch_cu: '121'}
@@ -98,10 +81,6 @@ jobs:
- { os: windows-2022, pyver: '3.12', cuda: '12.2.2', torch: '2.4.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX2', torch_cu: '121'}
- { os: windows-2022, pyver: '3.12', cuda: '12.1.1', torch: '2.4.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX512', torch_cu: '121'}
- { os: windows-2022, pyver: '3.12', cuda: '12.1.1', torch: '2.4.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX2', torch_cu: '121'}
- - { os: windows-2022, pyver: '3.12', cuda: '12.5.1', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX512', torch_cu: '124'}
- - { os: windows-2022, pyver: '3.12', cuda: '12.5.1', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX2', torch_cu: '124'}
- - { os: windows-2022, pyver: '3.12', cuda: '12.4.1', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX512', torch_cu: '124'}
- - { os: windows-2022, pyver: '3.12', cuda: '12.4.1', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX2', torch_cu: '124'}
- { os: windows-2022, pyver: '3.12', cuda: '12.2.2', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX512', torch_cu: '121'}
- { os: windows-2022, pyver: '3.12', cuda: '12.2.2', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX2', torch_cu: '121'}
- { os: windows-2022, pyver: '3.12', cuda: '12.1.1', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX512', torch_cu: '121'}
@@ -114,10 +93,6 @@ jobs:
- { os: windows-2022, pyver: '3.11', cuda: '12.2.2', torch: '2.4.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX2', torch_cu: '121'}
- { os: windows-2022, pyver: '3.11', cuda: '12.1.1', torch: '2.4.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX512', torch_cu: '121'}
- { os: windows-2022, pyver: '3.11', cuda: '12.1.1', torch: '2.4.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX2', torch_cu: '121'}
- - { os: windows-2022, pyver: '3.11', cuda: '12.5.1', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX512', torch_cu: '124'}
- - { os: windows-2022, pyver: '3.11', cuda: '12.5.1', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX2', torch_cu: '124'}
- - { os: windows-2022, pyver: '3.11', cuda: '12.4.1', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX512', torch_cu: '124'}
- - { os: windows-2022, pyver: '3.11', cuda: '12.4.1', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX2', torch_cu: '124'}
- { os: windows-2022, pyver: '3.11', cuda: '12.2.2', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX512', torch_cu: '121'}
- { os: windows-2022, pyver: '3.11', cuda: '12.2.2', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX2', torch_cu: '121'}
- { os: windows-2022, pyver: '3.11', cuda: '12.1.1', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX512', torch_cu: '121'}
@@ -130,10 +105,6 @@ jobs:
- { os: windows-2022, pyver: '3.10', cuda: '12.2.2', torch: '2.4.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX2', torch_cu: '121'}
- { os: windows-2022, pyver: '3.10', cuda: '12.1.1', torch: '2.4.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX512', torch_cu: '121'}
- { os: windows-2022, pyver: '3.10', cuda: '12.1.1', torch: '2.4.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX2', torch_cu: '121'}
- - { os: windows-2022, pyver: '3.10', cuda: '12.5.1', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX512', torch_cu: '124'}
- - { os: windows-2022, pyver: '3.10', cuda: '12.5.1', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX2', torch_cu: '124'}
- - { os: windows-2022, pyver: '3.10', cuda: '12.4.1', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX512', torch_cu: '124'}
- - { os: windows-2022, pyver: '3.10', cuda: '12.4.1', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX2', torch_cu: '124'}
- { os: windows-2022, pyver: '3.10', cuda: '12.2.2', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX512', torch_cu: '121'}
- { os: windows-2022, pyver: '3.10', cuda: '12.2.2', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX2', torch_cu: '121'}
- { os: windows-2022, pyver: '3.10', cuda: '12.1.1', torch: '2.3.0', cudaarch: '8.0;8.6;8.7;8.9;9.0+PTX', instruct: 'AVX512', torch_cu: '121'}
@@ -219,6 +190,11 @@ jobs:
$env:CUDA_PATH = "$env:CUDA_PATH/Library"
$env:CUDA_HOME = $env:CUDA_PATH
$env:PATH = "$env:CUDA_PATH/bin;" + $env:PATH
+ $directory = "$env:CUDA_PATH/lib/x64/"
+ if (-not (Test-Path -Path $directory)) {
+ New-Item -ItemType Directory -Path $directory
+ Write-Output "Directory '$directory' created."
+ }
cp $env:CUDA_PATH/lib/*.lib $env:CUDA_PATH/lib/x64/
$env:INCLUDE =$env:CUDA_PATH + "/include/targets/x64;" + $env:INCLUDE
diff --git a/.gitignore b/.gitignore
index 1bb8666..5d72e80 100644
--- a/.gitignore
+++ b/.gitignore
@@ -17,4 +17,5 @@ compile_commands.json
*dist/
ktransformers/server/local_store/
ktransformers/server_test1.db
-*.patch
\ No newline at end of file
+*.patch
+local_chat_djw.py
\ No newline at end of file
diff --git a/README.md b/README.md
index f04a159..a3a6792 100644
--- a/README.md
+++ b/README.md
@@ -1,18 +1,17 @@
+https://github.com/user-attachments/assets/a865e5e4-bca3-401e-94b8-af3c080e6c12
+
+* **1M Context InternLM 2.5 7B**: Operates at full bf16 precision, utilizing 24GB VRAM and 150GB DRAM, which is feasible on a local desktop setup. It achieves a 92.88% success rate on the 1M "Needle In a Haystack" test and 100% on the 128K NIAH test.
+
+
Click To Show how to run other examples
-
* Qwen2-57B
```sh
@@ -208,6 +249,7 @@ python -m ktransformers.local_chat --model_name Qwen/Qwen2-57B-A14B-Instruct --g
```
* DeepseekV2
+
```sh
mkdir DeepSeek-V2-Chat-0628-GGUF && cd DeepSeek-V2-Chat-0628-GGUF
# Download weights
@@ -221,8 +263,11 @@ cd ..
python -m ktransformers.local_chat --model_name deepseek-ai/DeepSeek-V2-Chat-0628 --gguf_path ./DeepSeek-V2-Chat-0628-GGUF
# If you see “OSError: We couldn't connect to 'https://huggingface.co' to load this file”, try:
+
# GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/deepseek-ai/DeepSeek-V2-Chat-0628
+
# python -m ktransformers.local_chat --model_path ./DeepSeek-V2-Chat-0628 --gguf_path ./DeepSeek-V2-Chat-0628-GGUF
+
```
| model name | weights download link |
@@ -245,11 +290,15 @@ Start without website:
```sh
ktransformers --model_path deepseek-ai/DeepSeek-V2-Lite-Chat --gguf_path /path/to/DeepSeek-V2-Lite-Chat-GGUF --port 10002
```
+
Start with website:
+
```sh
ktransformers --model_path deepseek-ai/DeepSeek-V2-Lite-Chat --gguf_path /path/to/DeepSeek-V2-Lite-Chat-GGUF --port 10002 --web True
```
+
Or you want to start server with transformers, the model_path should include safetensors
+
```bash
ktransformers --type transformers --model_path /mnt/data/model/Qwen2-0.5B-Instruct --port 10002 --web True
```
@@ -264,10 +313,9 @@ Access website with url [http://localhost:10002/web/index.html#/chat](http://loc
More information about the RESTful API server can be found [here](doc/en/api/server/server.md). You can also find an example of integrating with Tabby [here](doc/en/api/server/tabby.md).
-
📃 Brief Injection Tutorial
At the heart of KTransformers is a user-friendly, template-based injection framework.
-This allows researchers to easily replace original torch modules with optimized variants. It also simplifies the process of combining multiple optimizations, allowing the exploration of their synergistic effects.
+This allows researchers to easily replace original torch modules with optimized variants. It also simplifies the process of combining multiple optimizations, allowing the exploration of their synergistic effects.
diff --git a/doc/assets/Framework_effect.png b/doc/assets/Framework_effect.png
new file mode 100644
index 0000000..fbf1be1
Binary files /dev/null and b/doc/assets/Framework_effect.png differ
diff --git a/doc/assets/InfLLM_equation.jpg b/doc/assets/InfLLM_equation.jpg
new file mode 100644
index 0000000..54d236c
Binary files /dev/null and b/doc/assets/InfLLM_equation.jpg differ
diff --git a/doc/assets/InfLLM_framework.png b/doc/assets/InfLLM_framework.png
new file mode 100644
index 0000000..a2d78b2
Binary files /dev/null and b/doc/assets/InfLLM_framework.png differ
diff --git a/doc/assets/KTransformers_long_context_v1.png b/doc/assets/KTransformers_long_context_v1.png
new file mode 100644
index 0000000..2aeea8a
Binary files /dev/null and b/doc/assets/KTransformers_long_context_v1.png differ
diff --git a/doc/assets/KTransformers_long_context_v2.png b/doc/assets/KTransformers_long_context_v2.png
new file mode 100644
index 0000000..6273ac5
Binary files /dev/null and b/doc/assets/KTransformers_long_context_v2.png differ
diff --git a/doc/assets/Quest_framework.png b/doc/assets/Quest_framework.png
new file mode 100644
index 0000000..4654e4d
Binary files /dev/null and b/doc/assets/Quest_framework.png differ
diff --git a/doc/assets/SnapKV_framework.png b/doc/assets/SnapKV_framework.png
new file mode 100644
index 0000000..a192a1d
Binary files /dev/null and b/doc/assets/SnapKV_framework.png differ
diff --git a/doc/assets/SparQ_attention.png b/doc/assets/SparQ_attention.png
new file mode 100644
index 0000000..de3c829
Binary files /dev/null and b/doc/assets/SparQ_attention.png differ
diff --git a/doc/assets/internlm_memory.png b/doc/assets/internlm_memory.png
new file mode 100644
index 0000000..be6de92
Binary files /dev/null and b/doc/assets/internlm_memory.png differ
diff --git a/doc/assets/long_context_generate.png b/doc/assets/long_context_generate.png
new file mode 100644
index 0000000..0e54fbc
Binary files /dev/null and b/doc/assets/long_context_generate.png differ
diff --git a/doc/assets/long_context_prefill.png b/doc/assets/long_context_prefill.png
new file mode 100644
index 0000000..271e6c4
Binary files /dev/null and b/doc/assets/long_context_prefill.png differ
diff --git a/doc/assets/needle_128K.png b/doc/assets/needle_128K.png
new file mode 100644
index 0000000..9c26a02
Binary files /dev/null and b/doc/assets/needle_128K.png differ
diff --git a/doc/assets/needle_1M.png b/doc/assets/needle_1M.png
new file mode 100644
index 0000000..b67e8a7
Binary files /dev/null and b/doc/assets/needle_1M.png differ
diff --git a/doc/en/long_context_tutorial.md b/doc/en/long_context_tutorial.md
new file mode 100644
index 0000000..e11467d
--- /dev/null
+++ b/doc/en/long_context_tutorial.md
@@ -0,0 +1,316 @@
+# KVCache Long Context
+
+## TL;DR
+
+Training larger models and supporting longer text sequences are currently the two most widely agreed-upon directions toward achieving AGI. After lowering the barrier for local inference with trillion-parameter MoE models, the second showcase scenario for KTransformers is reducing the inference barrier for ultra-long context sequences. Recently, both ChatGLM and InternLM have released open-source models supporting 1M tokens of context. This article will use InternLM2.5-7B-Chat-1M as an example to introduce a method that leverages the sparsity of attention to accelerate long-text inference on heterogeneous CPU/GPU systems.
+
+After optimization, KTransformers has achieved native-precision inference for 128K and even 1M tokens of context on a single 24GB GPU with CPU/DRAM support. In the 128K context scenario, the generation speed is 7.1 times faster than llama.cpp, while also achieving 100% accuracy on relatively simple test sets like "needle in haystack" and "passkey". On the more challenging dataset kvretrieval, through flexible framework configurations, we can achieve a **6.22x speedup** during inference while obtaining even higher scores than running the original model directly (**21.2 -> 24.4**). In the 1M context scenario on a single 24GB GPU, KTransformers can similarly achieve an inference speed of 16 tokens/s, nearly 10 times faster than llama.cpp under the same conditions, with the "needle in haystack" evaluation score even surpassing the original model (**89.31 -> 92.88**).
+
+Project url: https://github.com/kvcache-ai/ktransformers
+
+## Mathematical Principle: The Computational Overhead of Long-Text Inference and the Sparsity in Attention Caused by Softmax
+
+As the demand for longer context windows increases, not only have commercial large models like Kimi and Claude/Gemini started supporting increasingly longer context windows, but open-source models have also begun to catch up. Notably, both ChatGLM 4 and InternLM 2.5 have released versions that are under 10 billion parameters but support up to 1 million tokens of context. However, despite the relatively small size of these models, the enormous KVCache required for such ultra-long contexts still prevents local users from practically running these models. As shown in the figure below, while the InternLM2.5-7B-Chat-1M model weights only require 15.49GB of GPU memory, an additional 145.49GB is needed to store the entire 1M-token KVCache, which is clearly beyond the memory capacity of local users. Even when using the KVCache Offload feature of llama.cpp to offload the KVCache to CPU/DRAM, barely making the model runnable, performance remains unacceptable due to the need to fully scan the entire KVCache each time a single token is generated.
+
+| ![InternLM2.5-7B-Chat-1M memory breakdown](../assets/internlm_memory.png) | ![SparQ attention sparsity statistics](../assets/SparQ_attention.png) |
+| -------------------------------------------------------------------------- | ---------------------------------------------------------------------- |
+
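+The KVCache footprint quoted above can be sanity-checked from the model shape alone. The short sketch below assumes InternLM2.5-7B-Chat-1M's published configuration (32 layers, 8 KV heads under GQA, head dim 128) and an fp16 cache; these constants are illustrative assumptions, not values read from this repository.
+
+```python
+def kvcache_bytes(seq_len, layers=32, kv_heads=8, head_dim=128, dtype_bytes=2):
+    # K and V each store layers * kv_heads * head_dim values per token.
+    return 2 * layers * kv_heads * head_dim * dtype_bytes * seq_len
+
+for tokens in (128 * 1024, 1024 * 1024):
+    print(f"{tokens:>8} tokens -> {kvcache_bytes(tokens) / 1024**3:6.1f} GiB of KVCache")
+# Roughly 16 GiB at 128K tokens and 128 GiB at 1M tokens, the same order of
+# magnitude as the figures cited above once overheads are included.
+```
+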
+Fortunately, many studies have noticed that the attention distribution during the inference phase tends to be **sparse**. For example, the right figure above shows SparQ's experimental statistics on LLaMa 7B, where less than 1% of the tokens in a 3K context receive relatively high attention scores. Similar conclusions appear in many other papers, such as H2O, Quest, InfLLM, and SnapKV, and we have further validated them through long-text experiments with InternLM2.5-7B-1M. Although the proportion we observe is not as extreme as 1%, softmax inherently concentrates attention weight on a small set of tokens, so in principle, if we could identify in advance which tokens have high attention scores, scanning less than 5% of the tokens would suffice to essentially replicate the original result.
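+
+As a quick way to check this kind of sparsity on traces from your own model, a small helper such as the hypothetical one below measures how much of the post-softmax attention mass the top few percent of tokens capture; it is illustrative only and not part of KTransformers.
+
+```python
+import torch
+
+def topk_attention_coverage(attn_weights: torch.Tensor, fraction: float = 0.05) -> float:
+    """Share of the attention mass captured by the top `fraction` of tokens.
+
+    attn_weights: post-softmax weights over the context, shape (..., seq_len).
+    """
+    seq_len = attn_weights.shape[-1]
+    k = max(1, int(seq_len * fraction))
+    top_mass = attn_weights.topk(k, dim=-1).values.sum(dim=-1)
+    return (top_mass / attn_weights.sum(dim=-1)).mean().item()
+
+# Synthetic, deliberately peaked scores; in practice you would pass attention
+# maps recorded while decoding a long prompt.
+weights = torch.softmax(torch.randn(8, 3000) * 4.0, dim=-1)
+print(f"top 5% of tokens hold {topk_attention_coverage(weights):.1%} of the mass")
+```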
+
+Thus, the problem narrows down to how to quickly identify these tokens with high attention scores without scanning them all. In the following sections, we will first briefly survey several key related papers, then summarize and propose a general framework we designed and implemented within KTransformers—a highly efficient sparse attention operator for CPUs.
+
+## Related Papers and Conclusions
+
+### Prune or Retrieval?
+
+Based on the aforementioned points, we studied papers from recent years related to sparse selection in KVCache. The earliest of these is the paper H2O, which suggested that the attention distribution during inference is sparse and that only 5% of the KVCache is needed during inference. Following this, a series of works built on H2O's approach by designing more complex methods for selecting tokens that perform better in different scenarios. These methods are quite reasonable when the KVCache only has to serve a single question. However, as we previously explored in the Mooncake project, **we believe that the future trend is to precompute reusable KVCache as much as possible, and then use it to answer different questions.** This "compute once, use many" approach aims to reduce computational costs. Therefore, with this goal in mind, we prefer not to delete any tokens from the KVCache, or at least not remove a significant portion of them, to ensure that different questions can focus on different parts of the context in the future.
+
+![InfLLM Framework](../assets/InfLLM_framework.png)
+
+We further investigated related research, among which InfLLM proposed a very promising framework. Not only does it recognize that attention is sparse, but it also suggests that overly long contexts can cause attention to be dispersed into irrelevant noise, thereby reducing the model's ability to focus on key information. To address this issue, InfLLM introduces an external memory module (Memory Units) to store the context's KVCache. In each computation step, the most relevant semantic information is retrieved from this external memory module to participate in the calculation, thus enhancing the model's ability to handle long-context inference.
+
+Specifically, InfLLM organizes the external memory module using semantic blocks composed of neighboring tokens and employs a sliding window mechanism during computation. In each step, it selects only the semantic blocks at the head of the context (Initial Tokens), the blocks near the current token (Local Tokens), and a few blocks with the highest semantic similarity to the current token to participate in the attention calculation. As shown in Equation 1, to efficiently retrieve the blocks with the highest similarity, InfLLM selects within each block the few representative tokens whose scores $r_m$ are the highest, and then uses Equation 2 to calculate the semantic similarity between the current token and each semantic block.
+
+![InfLLM Equation](../assets/InfLLM_equation.jpg)
+
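+The block-retrieval step itself is easy to picture in code. The sketch below is a simplified, hypothetical rendering of the idea, scoring each block through its representative keys and keeping the top-k blocks; it is not InfLLM's reference implementation, and the tensor shapes are made up for illustration.
+
+```python
+import torch
+
+def select_context_blocks(query, block_repr_keys, top_k):
+    """Pick the context blocks most relevant to the current query.
+
+    query:           (head_dim,) query vector of the token being decoded.
+    block_repr_keys: (num_blocks, repr_per_block, head_dim) representative keys.
+    """
+    # Score each block by the summed similarity to its representative keys.
+    sims = torch.einsum("d,brd->br", query, block_repr_keys)
+    block_scores = sims.sum(dim=-1)
+    return block_scores.topk(top_k).indices
+
+query = torch.randn(128)
+block_repr_keys = torch.randn(64, 4, 128)  # 64 blocks, 4 representative keys each
+print(select_context_blocks(query, block_repr_keys, top_k=8))
+```
+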
+Compared to the previously mentioned H2O, the differences in InfLLM are as follows:
+
+1. The KVCache is not discarded but stored in memory and dynamically loaded onto the GPU during inference.
+
+2. KVCache is managed at the granularity of blocks rather than tokens, with each block selecting a few tokens as its representative index tokens.
+
+InfLLM's proposed method aligns with our "compute once, use many" approach of reusing KVCache. The external memory units in this method can be offloaded to CPU/DRAM or even SSD storage, allowing different parts to be selected for computation based on the specific question. This significantly improves the efficiency of attention computation.
+
+### Other Improvements
+
+Similarly, after InfLLM, Quest also manages tokens at the granularity of blocks. Quest analyzed the recall rate of key tokens in H2O and full attention, finding that the Top-10 attention score token recall rate for the H2O algorithm is around 50%, which indicates that too much key information was lost. To improve the recall rate of key tokens, Quest chooses two "representative tokens" from each block for retrieval. In the prefill stage, each KVCache block records the maximum and minimum values for each channel, as shown in the figure below under "Reduced Keys," which contains the element-wise min key and element-wise max key.
+
+During the attention computation stage, the dot product is computed between the current query vector and the max key and min key of each KVCache block, respectively. Then, for each channel, the maximum value between the two resulting product vectors is selected and summed to serve as the upper bound of the relevance score for that KVCache block, as shown in stage 1 of the diagram. Based on the relevance scores, the top-k KVCache blocks are selected to participate in the attention computation, as illustrated in stage 2 of the diagram.
+
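+Stage 1 of this scheme boils down to a per-channel maximum over two dot products. The snippet below is a small, hypothetical sketch of that upper-bound score for a single query head, meant only to illustrate the idea rather than reproduce Quest's kernels.
+
+```python
+import torch
+
+def quest_upper_bound_scores(query, block_min_keys, block_max_keys):
+    """Upper bound on q.k for any key inside each block (Quest stage 1).
+
+    query:          (head_dim,)
+    block_min_keys: (num_blocks, head_dim) element-wise min key per block
+    block_max_keys: (num_blocks, head_dim) element-wise max key per block
+    """
+    lo = query * block_min_keys
+    hi = query * block_max_keys
+    # For min_c <= k_c <= max_c, we have q_c * k_c <= max(q_c * min_c, q_c * max_c).
+    return torch.maximum(lo, hi).sum(dim=-1)
+
+query = torch.randn(128)
+keys = torch.randn(64, 128, 128)  # 64 blocks of 128 keys each
+scores = quest_upper_bound_scores(query, keys.min(dim=1).values, keys.max(dim=1).values)
+print(scores.topk(8).indices)     # stage 2: keep only the highest-scoring blocks
+```
+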
+![Quest Framework](../assets/Quest_framework.png)
+
+Compared to InfLLM, Quest does not take heterogeneous architectures into account. Instead, it assumes that all KVCache can still fit into memory, simply leveraging sparse attention to accelerate the inference process. Ultimately, Quest achieves a 7.03x speedup in attention computation and a 2.23x improvement in end-to-end inference latency.
+
+Going further, SnapKV proposes retaining two parts of the tokens during the prefill stage, as shown in the diagram below with the orange and green segments. The difference from InfLLM lies only in the method of selecting the middle tokens. SnapKV selects tokens at the token level rather than the block level, with the score calculation being similar to H2O, i.e., $\mathrm{softmax}(\frac{qk^T}{\sqrt{d_k}})$. However, when summing across columns, only the rows within the final green window are selected for computation, corresponding to the Local Tokens section in InfLLM. Additionally, SnapKV introduces a pooling operation on top of attention, which the paper explains as ensuring that the recalled tokens retain more complete semantic information.
+
+This approach in SnapKV involves a one-time selection during the inference phase, after which only the selected tokens are used for attention computation, while the rest of the KVCache is discarded.
+
+![SnapKV Framework](../assets/SnapKV_framework.png)
+
+
+Other related papers include PyramidKV, which observed that attention scores exhibit a pyramid-shaped distribution across attention layers. In lower attention layers, attention is widely distributed, while in higher layers, the attention scores for a few key tokens become increasingly prominent. Therefore, PyramidKV allocates more KVCache storage space to lower layers and less space to higher layers.
+
+MagicPiG, based on Locality-Sensitive Hashing (LSH), proposes a dynamic KVCache management strategy. First, it uses SnapKV to select a portion of important tokens to be stored on the GPU, while the KVCache of the other tokens is placed in CPU memory. By leveraging the efficiency of LSH for high-dimensional similarity search and the multithreading capabilities of CPUs, MagicPiG retrieves from CPU memory the KVCache entries most similar to the current query and loads them for inference. Compared to earlier methods like InfLLM, Quest, and SnapKV, MagicPiG does not need to scan all representative tokens and select the top-k KVCache blocks. Instead, it exploits the mathematical properties of LSH, which not only approximates attention scores but also identifies important KVCache entries with low overhead and high speed.
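+
+The LSH idea underneath this can be illustrated with plain random-hyperplane (SimHash) signatures: vectors pointing in similar directions agree on most sign bits, so a cheap Hamming-distance comparison can shortlist candidate keys without scanning them all. The sketch below is a generic illustration of that principle, not MagicPiG's actual sampling scheme.
+
+```python
+import torch
+
+torch.manual_seed(0)
+head_dim, n_keys, n_bits = 128, 10000, 64
+planes = torch.randn(head_dim, n_bits)           # random hyperplanes shared by all vectors
+
+def simhash(x):
+    """Sign-bit signature of x under the shared random projections."""
+    return (x @ planes) > 0                       # (..., n_bits) boolean signature
+
+keys = torch.randn(n_keys, head_dim)
+query = keys[1234] + 0.1 * torch.randn(head_dim)  # a query close to key 1234
+
+# Shortlist keys whose signatures are close to the query's in Hamming distance.
+hamming = (simhash(keys) != simhash(query)).sum(dim=-1)
+candidates = hamming.topk(50, largest=False).indices
+print(1234 in candidates.tolist())                # the true neighbour is usually shortlisted
+```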
+
+The above are just descriptions of some key points. For more detailed explanations, you can refer to the existing articles on Zhihu in Chinese:
+
+- https://zhuanlan.zhihu.com/p/701580870
+
+- https://zhuanlan.zhihu.com/p/714288577
+
+## KTransformers CPU Sparse Attn Framework
+
+### Framework Prototype
+
+Based on the introduction of the above papers, we have distilled the following key points:
+
+- The distribution of attention weights is sparse, and useless KVCache may introduce noise, which could actually reduce performance during the inference stage.
+
+- For the KVCache eviction strategy during the inference stage, the common approach is to retain the tokens from the beginning and the end of the prompt, while designing algorithms to select the tokens from the middle portion. One of the main factors affecting the model's performance is the ability to accurately identify the key tokens.
+
+- Managing the middle portion of tokens in blocks can improve memory swapping and attention computation efficiency, and smaller blocks do not seem to perform worse than token-level granularity.
+
+- The tokens that each attention layer focuses on during inference differ, and even the allocated KVCache capacity for different layers should vary.
+
+Based on these insights and inspirations, we developed a general framework for implementing sparse CPU attention operators during the inference phase. In the prefill stage, we use chunked prefill, loading only one layer of KVCache into GPU memory at a time for computation. Once completed, the KVCache is stored on CPU/DRAM. In the subsequent decode stage, instead of swapping KVCache in and out, the sparse attention operator runs directly on the CPU. **This significantly reduces the minimum GPU memory requirements, making local 128K or even 1M token contexts possible.**
+
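+The prefill side can be summarised as a per-layer loop: only the layer currently being computed keeps its KVCache on the GPU, and that cache is flushed to CPU/DRAM before the next layer is processed. The toy sketch below captures just that looping-and-offloading structure with made-up projection weights; it omits attention itself and is not the KTransformers implementation.
+
+```python
+import torch
+
+def layerwise_chunked_prefill(hidden, layers, chunk_size=4096):
+    """Toy layer-wise chunked prefill: build one layer's K/V in chunks, then offload."""
+    cpu_kvcache = []
+    for wk, wv in layers:                       # one (K, V) projection pair per layer
+        k_parts, v_parts = [], []
+        for start in range(0, hidden.shape[0], chunk_size):
+            chunk = hidden[start:start + chunk_size]
+            k_parts.append(chunk @ wk)
+            v_parts.append(chunk @ wv)
+        # Flush this layer's cache to CPU/DRAM before touching the next layer.
+        cpu_kvcache.append((torch.cat(k_parts).cpu(), torch.cat(v_parts).cpu()))
+    return cpu_kvcache
+
+hidden = torch.randn(10000, 128)
+layers = [(torch.randn(128, 128), torch.randn(128, 128)) for _ in range(4)]
+print(len(layerwise_chunked_prefill(hidden, layers)), "layers cached on CPU")
+```
+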
+Specifically during the generation phase, we implemented the entire framework as shown in the diagram below.
+
+![KTransformers long context v1](../assets/KTransformers_long_context_v1.png)
+
+We organized the KVCache in units of blocks. Specifically:
+
+- **KVCache Partitioning:** A complete input prompt is divided into three configurable parts: Initial, Context, and Local. During the computation process, the Initial/Local parts will be fully attended to, while the Context part will be sparsely retrieved. This approach is based on findings from many papers (such as StreamingLLM and MInference) which mention the existence of "attention sinks," where higher attention weights are often found at the beginning and the end of the sequence.
+
+- **Context Block Partitioning:** For the middle Context, we follow the InfLLM approach by dividing it into blocks based on a configurable fixed number of tokens. Each block can select 1 to k tokens as its representative tokens. During the actual inference phase, the Context blocks that require attention are selected based on these representative tokens.
+
+ - Specifically, we have implemented the following methods for selecting representative tokens, based on the approaches outlined in various papers.
+
+ - Max: The maximum values of multiple tokens within a block, across each channel, are concatenated to form the representative token for the current block.
+
+ - Mean: The average values of multiple tokens within a block, across each channel, are concatenated to form the representative token for the current block.
+
+ - Quest: A combination of the previous two methods: the maximum and minimum values of multiple tokens within a block, across each channel, are taken as the representative tokens for the block. Under this method, the number of representative tokens is fixed at 2
+
+ - Dynamic: By calculating the cumulative attention score for each token using a specific method, each block selects the top-k tokens with the highest scores as the representative tokens for the block. This is similar to InfLLM but with some simplifications.
+
+ - Fix: Select tokens at fixed intervals within the block.
+
+  - Once the representative tokens for each block are determined, Equation 2 from InfLLM is used to calculate the similarity between the input $X$ and the $k$ representative tokens of each block $B$, and only the top $r_k$ blocks are selected for attention computation, where $l_P$ denotes the length of the historical tokens in Equation 2. A short sketch of these representative-token selectors is given right after this list.
+
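+As referenced above, the representative tokens themselves are cheap reductions over each block's keys. The snippet below sketches the Max, Mean, Quest, and Fix variants for a single layer's key cache (the Dynamic variant needs attention scores from prefill and is omitted); it is an illustrative rendering of the options listed above, not the operator shipped in KTransformers.
+
+```python
+import torch
+
+def representative_tokens(block_keys, method="MAX", fix_stride=32):
+    """block_keys: (num_blocks, block_size, head_dim) keys of the Context blocks.
+    Returns (num_blocks, repr_num, head_dim) representative tokens per block."""
+    if method == "MAX":
+        return block_keys.max(dim=1, keepdim=True).values
+    if method == "MEAN":
+        return block_keys.mean(dim=1, keepdim=True)
+    if method == "QUEST":   # element-wise min and max, so repr_num is fixed at 2
+        return torch.stack(
+            (block_keys.min(dim=1).values, block_keys.max(dim=1).values), dim=1
+        )
+    if method == "FIX":     # tokens sampled at a fixed stride inside each block
+        return block_keys[:, ::fix_stride, :]
+    raise ValueError(f"unknown method {method}")
+
+block_keys = torch.randn(64, 128, 128)   # 64 blocks of 128 tokens, head_dim 128
+for m in ("MAX", "MEAN", "QUEST", "FIX"):
+    print(m, tuple(representative_tokens(block_keys, m).shape))
+```
+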
+Since InfLLM requires calculating a representative score for each token during the prefill stage and then selecting a representative token for each block based on these scores, this operation involves invasive modifications to the prefill implementation, making it difficult to integrate with other methods. Furthermore, in actual testing, we found that in most scenarios, similar or even better results can be achieved through a combination of other methods. Therefore, we ultimately decided not to integrate this method into the framework.
+
+## Further Optimizations
+
+After implementing the above framework, we conducted a series of evaluations based on LongBench and InfiniteBench.
+
+At the beginning of the experiment, we designed the architecture so that for each inference token, the most relevant KVCache blocks would be reselected. On the one hand, this strategy incurred significant overhead during the retrieval process. On the other hand, we found that in some scenarios, **frequently changing the selection of retrieved blocks did not lead to better results**. For example, in the kvretrieval dataset, we observed that the model's responses were often correct in the first half but incorrect in the second half. Since the answers to kvretrieval questions consist of long and meaningless strings, this indicates that the correct KVCache blocks were selected during the inference of the earlier tokens but incorrect blocks were chosen during the later stages of inference.
+
+To address this issue, we further integrated the method proposed in SnapKV. Before starting the inference, we preselect relevant KVCache blocks by analyzing the attention scores of the context tokens, based on the question. During the subsequent inference stages, the selection of KVCache blocks is restricted to this preselected range. This approach allowed us to select the block containing the correct answer 100% of the time in the kvretrieval dataset.
+
+However, it should be noted that this method strictly relies on the structure of the Benchmark Prompt and **does not necessarily guarantee optimal performance in other scenarios, such as complex document understanding and generation tasks.** Therefore, we have integrated it into our framework as an optional module. The final framework and configurable parameters are as follows:
+
+![KTransformers long context v2](../assets/KTransformers_long_context_v2.png)
+
+
+Configuration:
+
+- **threads_num:** Number of CPU Threads
+
+- **block_size:** KVCache Block Size
+
+- **local_windows_len:** Prompt End Window Size
+
+- **preselect_block_count:** Number of Preselected Blocks
+
+- **second_block_count:** Number of Blocks Selected After Preselection
+
+- **preselect_block:** Whether to Enable Preselection
+
+- **token_step:** Interval Between Token Selections for KVCache
+
+- **layer_step:** Interval Between Layer Selections for KVCache
+
+- **dense_layer_num:** Number of Initial Layers Without KVCache Selection, Importing All KVCache
+
+- **head_select_mode:** SEPARATE (in the GQA scenario, blocks are selected separately for each kv_head) / SHARED (all kv_heads select blocks together)
+
+- **representative_type:** Method of Selecting Representative Tokens
+
+- **representative_num:** Number of Representative Tokens
+
+By modifying configuration options, various KVCache eviction or compression methods can be easily reproduced within our framework. For example:
+
+- Setting `block_size` to 1 and `preselect_block` to True results in a version of SnapKV without the pooling operation.
+
+- Setting `representative_type` to Quest, `preselect_block` to False, and `head_select_mode` to SEPARATE replicates the Quest method.
+
+Below is the pseudocode for the framework:
+
+```python
+import math
+import torch
+from torch import nn
+
+def preselect_block(local_q, kvcache):
+ key_states = kvcache.keycache
+ attn_scores = torch.matmul(
+ local_q, key_states.transpose(2, 3)
+ ) / math.sqrt(head_dim)
+ attn_scores += attn_mask
+ attn_scores = nn.functional.softmax(
+ attn_scores, dim=-1, dtype=torch.float32
+    ).to(local_q.dtype)
+ vote = attn_scores[..., initial_size:-local_size:, :].sum(dim=-2)
+    # pool1d: 1-D pooling over the votes, as in SnapKV (e.g. an average pool)
+    pool_vote = pool1d(vote, kernel_size=kernel_size, padding=kernel_size//2, stride=1)
+ indices = pool_vote.topk(max_capacity_prompt - local_size, dim=-1).indices
+ kv_cache_block_indices = find_representative_tokens_block(indices)
+ kvcache_after_preselected = kvcache[kv_cache_block_indices]
+ ...
+ return kvcache_after_preselected
+def get_representative_tokens(kvcache, representative_type):
+    """Calculate the representative token(s) for each block based on representative_type."""
+    return ...
+def decode_attention(query, key, value):
+ # Select once every token_steps tokens.
+ token_steps = 4
+ # Select once every layer_steps layers.
+ layer_steps = 4
+ for token_idx in range(max_new_tokens):
+ for layer_idx in range(config.num_hidden_layers):
+ if token_idx % token_steps != 0 or layer_idx % layer_steps != 0:
+ # If the attention of the current layer in this round does not require reselection, the historical selection results from the kvcache will be retained.
+ kvcache_after_retrieval = history_kvcache_after_retrieval[layer_idx//layer_steps]
+ else:
+ # Otherwise, use the query from the current round's current layer to reselect the kvcache.
+ kvcache_after_retrieval = retrieval_kvcache(query, kvcache)
+ # Save it to the kvcache historical selection results.
+ history_kvcache_after_retrieval[layer_idx//layer_steps] = kvcache_after_retrieval
+ # calculate attention
+ output = attn(query, kvcache_after_retrieval)
+ yield output
+
+# Model prefill; if preselection is required, local_q also needs to be saved.
+local_q, kvcache = model.prefill(input_ids)
+if config.preselect_block:
+    # Preselection round
+    kvcache = preselect_block(local_q, kvcache)
+# Find the representative token(s) for each block.
+block_representative_tokens = get_representative_tokens(
+    kvcache,
+    config.representative_type
+)
+
+# model generate (surrounding generation logic elided)
+for output in decode_attention(query, key, value):
+    ...
+```
+
+## Experiment
+
+For the initial tests, we use the following basic configuration, which will then be further optimized through the extended framework.
+
+```yaml
+max_seq_len: 256000 # KVCache length
+block_size: 128 # KVCache block size
+local_windows_len: 4096 # The KVCache of length local_windows_len is stored on the GPU.
+second_block_count: 96 # Number of KVCache blocks selected at each step after preselection; if >= preselect_block_count, the preselected blocks are used directly.
+threads_num: 64 # CPU thread num
+representative_type: DYNAMIC # KVCache block representative token selection method.
+kv_type: FP16
+dense_layer_num: 0 # Number of initial layers that use the full KVCache without selection
+representative_num: 1 # The number of representative tokens within a KVCache block.
+preselect_block: False # Whether to preselect.
+head_select_mode: SHARED # All kv_heads jointly select.
+preselect_block_count: 0 # Number of preselected blocks.
+layer_step: 1 # Select every few layers.
+token_step: 1 # Select every few tokens.
+```
+
+Under our framework, the comparison between the original model and the accelerated KTransformers on datasets such as the 128K Needle-in-a-Haystack, passkey, and kvretrieval is as follows. The passkey dataset inserts a short sequence of digits at varying depths within otherwise irrelevant text; kvretrieval asks the model to find the matching item among randomly generated key-value pairs. All tests were conducted under the OpenCompass framework:
+
+![needle_128K.png](../assets/needle_128K.png)
+
+|                                                              | Single needle retrieval zh 128k | passkey | kvretrieval |
+| ------------------------------------------------------------ | ------------------------------- | ------- | ----------- |
+| Original model                                               | 99.89                           | 100     | 21.0        |
+| KTransformers (reselect KVCache blocks for each generation)  | 100                             | 100     | 15.40       |
+
+We can see that both the original model and the accelerated KTransformers achieve perfect scores on the relatively simpler datasets, such as Single Needle Retrieval and passkey. At the same time, the generation speed has significantly improved, increasing from 4.86 tokens/s with llama.cpp to 27.49 tokens/s with KTransformers, achieving up to a 5.65x speedup. Although the current configuration shows a noticeable drop in performance on the more challenging kvretrieval dataset, in the next section, we will address this by implementing a more optimized selection strategy to compensate for or even surpass the original model's accuracy.
+
+Additionally, we tested the performance of the KTransformers-based configuration framework in reproducing the results of Quest. However, since InternLM2.5-7B-Chat-1M uses GQA (Grouped Query Attention) while the Quest paper primarily focuses on optimizing MHA (Multi-Head Attention) models, the actual testing results were not particularly favorable. Quest's authors have also mentioned that further support for GQA models is needed, so we will not discuss this in detail for now.
+
+### Further improve performance
+
+By modifying certain configurations within our flexible framework on top of these reproductions, **we can actually achieve better results than those reported in the papers above**, as shown in the figure below:
+
+![](../assets/Framework_effect.png)
+
+As mentioned earlier, the goal of the kvretrieval dataset is to find a matching key-value pair within a long sequence of semantically meaningless pairs. If tokens are generated by reselecting blocks based on the current query each time, the likelihood of deviation increases as the text grows, leading to the selection of KVCache blocks that differ from previous selections. To address this, we introduced a preselection mechanism that uses SnapKV-style scoring of representative tokens to preselect a portion of the KVCache blocks; during the subsequent inference process, the selection is limited to these blocks. After one round of preselection, the score increased from 15.4 to 24.2, **surpassing the original model + full attention's performance of 21 points.** Further research indicates that attention over the KVCache in the first few layers of LLMs is not particularly sparse. Therefore, we set the first two layers to use the full KVCache, ultimately achieving a score of **24.4**.
+
+Similarly, when testing the needle-in-a-haystack task on the 1M dataset, we not only reproduced the original model's reported score but also further improved accuracy (**from 89.31 to 92.88**) by using the KTransformers CPU Sparse Attn Framework to selectively compute only certain KVCache blocks. Additionally, the inference speed **reached nearly 10 times that of llama.cpp**.
+
+![needle 1M.png](../assets/needle_1M.png)
+
+### More comparisons
+
+As shown in the two figures below, using the Single Needle Retrieval dataset as an example, we configured llama.cpp to store the KVCache on CPU/DRAM while keeping the remaining computation on the GPU. On a 4090D server, we compared the KTransformers CPU Sparse Attn Framework with llama.cpp. While maintaining **100% answer accuracy**, we achieved a 20.6 to 94.1 times prefill speed increase and a **1.2 to 7.1 times inference speed boost**.
+
+| ![long context prefill.png](../assets/long_context_prefill.png) | ![long context generate.png](../assets/long_context_generate.png) |
+| --------------------------------------------------------------- | ----------------------------------------------------------------- |
+
+The main reason for the significant gap in prefill speed is that after enabling KVCache offload, llama.cpp performs the attention (attn) computation on the CPU. In long-text scenarios, attention not only requires heavy computation but also takes up the majority of the computation time. In contrast, KTransformers leverages a flexible template injection framework to implement GPU Chunk Prefill layer by layer. Moving forward, we plan to further integrate high-performance sparse prefill methods such as MInference to boost speed even further.
+
+Additionally, as a key focus of this article, the right-hand graph shows that as the prompt length increases, the inference speed of KTransformers remains stable, hovering near a horizontal line, whereas llama.cpp slows down as the prompt length increases. By selecting only the most important 16K tokens' worth of KVCache blocks to participate in the decoding computation, KTransformers maintains an inference speed comparable to llama.cpp processing a 16K prompt, without any degradation in answer quality (at least on these test datasets).
+
+## How to Use
+
+Currently, long context is only supported by our **local_chat.py** interface, and the integration with the server interface is under development.
+
+For convenience, we have uploaded the model config, GGUF weights, and tokenizer to a single repo. URL: https://huggingface.co/nilv234/internlm2_5_to_llama_1m/tree/main
+
+By setting the model_path and gguf_path in the local_chat function to **/path/to/repo** and setting the mode to **"long_context"**, you can use the InternLM2.5-7B-Chat-1M model with 1M-context functionality on a GPU with 24GB of VRAM.
+
+After running local_chat.py for the first time, a config.yaml file will be automatically created under **~/.ktransformers**. The relevant configurations for long context are as follows:
+
+```yaml
+chunk_size: 4096 # prefill chunk size
+max_seq_len: 100000 # KVCache length
+block_size: 128 # KVCache block size
+local_windows_len: 4096 # The KVCache of length local_windows_len is stored on the GPU.
+second_select_num: 96 # Number of KVCache blocks selected at each step after preselection; if >= preselect_block_count, the preselected blocks are used directly.
+threads_num: 64 # CPU thread num
+anchor_type: DYNAMIC # KVCache block representative token selection method.
+kv_type: FP16
+dense_layer_num: 0 # Number of initial layers that use the full KVCache without selection
+anchor_num: 1 # The number of representative tokens within a KVCache block.
+preselect_block: False # Whether to preselect.
+head_select_mode: SHARED # All kv_heads jointly select.
+preselect_block_count: 96 # Number of preselected blocks.
+layer_step: 1 # Select every few layers.
+token_step: 1 # Select every few tokens.
+```
+
+The memory required for different context lengths is shown in the table below:
+
+| | 4K | 32K | 64K | 128K | 512K | 1M |
+| -------------- | --- | ---- | ---- | ---- | ---- | ------ |
+| DRAM Size (GB) | 0.5 | 4.29 | 8.58 | 17.1 | 68.7 | 145.49 |
+
+Please choose an appropriate max_seq_len based on your DRAM size.
+For example:
+```sh
+python local_chat.py --model_path="/data/model/internlm2_5_to_llama_1m" --gguf_path="/data/model/internlm2_5_to_llama_1m" --max_new_tokens=500 --cpu_infer=10 --use_cuda_graph=True --mode="long_context" --prompt_file="/path/to/file"
+```
+
diff --git a/ktransformers/__init__.py b/ktransformers/__init__.py
index 48fef32..a833c84 100644
--- a/ktransformers/__init__.py
+++ b/ktransformers/__init__.py
@@ -1 +1,11 @@
-__version__ = "0.1.2"
\ No newline at end of file
+#!/usr/bin/env python
+# coding=utf-8
+'''
+Description :
+Author : kkk1nak0
+Date : 2024-08-15 07:34:46
+Version : 1.0.0
+LastEditors : chenxl
+LastEditTime : 2024-08-28 15:19:03
+'''
+__version__ = "0.1.3"
diff --git a/ktransformers/configs/config.yaml b/ktransformers/configs/config.yaml
index 3a4816c..4078e24 100644
--- a/ktransformers/configs/config.yaml
+++ b/ktransformers/configs/config.yaml
@@ -34,4 +34,20 @@ web:
open_cross_domain: True
ext:
- cpu_infer: 10
\ No newline at end of file
+ cpu_infer: 10
+
+long_context:
+ chunk_size: 4096
+ max_seq_len: 32000
+ block_size: 128
+ local_windows_len: 4096
+ second_select_num: 32
+ anchor_type: DYNAMIC
+ kv_type: FP16
+ dense_layer_num: 2
+ anchor_num: 1
+ preselect_block: True
+ head_select_mode: SHARED
+ preselect_block_count: 32
+ layer_step: 1
+ token_step: 100
\ No newline at end of file
diff --git a/ktransformers/ktransformers_ext/CMakeLists.txt b/ktransformers/ktransformers_ext/CMakeLists.txt
index e6e0518..1ef9823 100644
--- a/ktransformers/ktransformers_ext/CMakeLists.txt
+++ b/ktransformers/ktransformers_ext/CMakeLists.txt
@@ -1,6 +1,7 @@
cmake_minimum_required(VERSION 3.16)
project(cpuinfer_ext VERSION 0.1.0)
+
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -ffast-math")
set(CMAKE_BUILD_TYPE "Release")
@@ -215,7 +216,8 @@ aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} SOURCE_DIR1)
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/cpu_backend SOURCE_DIR2)
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/llamafile SOURCE_DIR3)
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llamafile SOURCE_DIR4)
-set(ALL_SOURCES ${SOURCE_DIR1} ${SOURCE_DIR2} ${SOURCE_DIR3} ${SOURCE_DIR4})
+aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/kvcache SOURCE_DIR5)
+set(ALL_SOURCES ${SOURCE_DIR1} ${SOURCE_DIR2} ${SOURCE_DIR3} ${SOURCE_DIR4} ${SOURCE_DIR5})
message(STATUS "ALL_SOURCES: ${ALL_SOURCES}")
pybind11_add_module(${PROJECT_NAME} MODULE ${ALL_SOURCES})
@@ -223,5 +225,8 @@ target_link_libraries(${PROJECT_NAME} PRIVATE llama)
if(WIN32)
target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_PATH}/lib/x64/cudart.lib")#CUDA::cudart
elseif(UNIX)
+ if(NOT DEFINED ENV{CUDA_HOME} OR "$ENV{CUDA_HOME}" STREQUAL "")
+ set(ENV{CUDA_HOME} "/usr/local/cuda")
+ endif()
target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_HOME}/lib64/libcudart.so")
-endif()
\ No newline at end of file
+endif()
diff --git a/ktransformers/ktransformers_ext/bench/bench_attention.py b/ktransformers/ktransformers_ext/bench/bench_attention.py
new file mode 100644
index 0000000..1b8b9b8
--- /dev/null
+++ b/ktransformers/ktransformers_ext/bench/bench_attention.py
@@ -0,0 +1,178 @@
+#!/usr/bin/env python
+# coding=utf-8
+"""
+Description :
+Author : Jianwei Dong
+Date : 2024-08-28 10:32:05
+Version : 1.0.0
+LastEditors : Jianwei Dong
+LastEditTime : 2024-08-28 10:32:05
+Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
+"""
+import os, sys
+import time
+
+sys.path.append(os.path.dirname(__file__) + "/../build")
+import cpuinfer_ext
+import torch
+
+layer_num = 10
+kv_head_num = 8
+q_head_num = 32
+head_dim = 128
+block_len = 128
+anchor_num = 1
+
+anchor_type = cpuinfer_ext.kvcache.AnchorType.DYNAMIC
+kv_type = cpuinfer_ext.kvcache.ggml_type.FP16
+retrieval_type = cpuinfer_ext.kvcache.RetrievalType.LAYER
+layer_step: int = 1
+token_step: int = 1
+layer_offset: int = 0
+max_thread_num: int = 64
+max_batch_size: int = 1
+max_block_num: int = 1024
+CPUInfer = cpuinfer_ext.CPUInfer(max_thread_num)
+
+warm_up_iter = 1000
+test_iter = 10000
+
+
+def bench_linear(cache_seqlen: int):
+ with torch.inference_mode(mode=True):
+ cache_seqlens = torch.tensor([cache_seqlen], dtype=torch.int32, device="cpu")
+ seqlens_zero = torch.zeros((1,), dtype=torch.int32, device="cpu")
+
+ config = cpuinfer_ext.kvcache.KVCacheConfig(
+ layer_num,
+ kv_head_num,
+ q_head_num,
+ head_dim,
+ block_len,
+ anchor_num,
+ anchor_type,
+ kv_type,
+ retrieval_type,
+ layer_step,
+ token_step,
+ layer_offset,
+ max_block_num,
+ max_batch_size,
+ max_thread_num,
+ )
+ local_kvcache = cpuinfer_ext.kvcache.KVCache(config)
+ block_table = (
+ torch.arange(max_block_num, dtype=torch.int32, device="cpu")
+ .contiguous()
+ .view(1, -1)
+ )
+
+ for layer_idx in range(layer_num):
+ k_cache = torch.randn(
+ (1, cache_seqlen, kv_head_num, head_dim),
+ dtype=torch.float16,
+ device="cpu",
+ ).contiguous()
+ v_cache = torch.randn(
+ (1, cache_seqlen, kv_head_num, head_dim),
+ dtype=torch.float16,
+ device="cpu",
+ ).contiguous()
+
+ CPUInfer.submit(
+ local_kvcache.update_kvcache_fp16(
+ k_cache.data_ptr(),
+ v_cache.data_ptr(),
+ layer_idx,
+ block_table.data_ptr(),
+ 1,
+ max_block_num,
+ seqlens_zero.data_ptr(),
+ cache_seqlen,
+ )
+ )
+ CPUInfer.sync()
+
+ input = torch.randn(
+ (1, 1, q_head_num, head_dim), dtype=torch.float16, device="cpu"
+ ).contiguous()
+ output = torch.empty(
+ (1, 1, q_head_num, head_dim), dtype=torch.float16, device="cpu"
+ ).contiguous()
+
+ # attn_lse: (bsz, q_len, q_head_num)
+ attn_lse = torch.empty(
+ (1, 1, q_head_num), dtype=torch.float32, device="cpu"
+ ).contiguous()
+ input = input / 100
+
+ # warm up
+ for i in range(warm_up_iter):
+ CPUInfer.submit(
+ local_kvcache.attn(
+ input.data_ptr(),
+ output.data_ptr(),
+ attn_lse.data_ptr(),
+ i % layer_num,
+ 0,
+ 1,
+ 1,
+ max_block_num,
+ block_table.data_ptr(),
+ cache_seqlens.data_ptr(),
+ -1,
+ -1,
+ -1,
+ )
+ )
+ CPUInfer.sync()
+
+ # test
+ start = time.perf_counter()
+ for i in range(test_iter):
+ CPUInfer.submit(
+ local_kvcache.attn(
+ input.data_ptr(),
+ output.data_ptr(),
+ attn_lse.data_ptr(),
+ i % layer_num,
+ 0,
+ 1,
+ 1,
+ max_block_num,
+ block_table.data_ptr(),
+ cache_seqlens.data_ptr(),
+ -1,
+ -1,
+ -1,
+ )
+ )
+ CPUInfer.sync()
+ end = time.perf_counter()
+ total_time = end - start
+ print("cache sequence length: ", cache_seqlen)
+ print("Time(s): ", total_time)
+ print("Iteration: ", test_iter)
+ print("Time(us) per iteration: ", total_time / test_iter * 1000000)
+ print(
+ "Bandwidth: ",
+ cache_seqlen
+ * kv_head_num
+ * head_dim
+ * 2
+ * 2
+ * test_iter
+ / total_time
+ / 1000
+ / 1000
+ / 1000,
+ "GB/s",
+ )
+ print("")
+
+
+bench_linear(1024)
+bench_linear(4096)
+bench_linear(16384)
+bench_linear(32768)
+bench_linear(65536)
diff --git a/ktransformers/ktransformers_ext/bench/bench_attention_torch.py b/ktransformers/ktransformers_ext/bench/bench_attention_torch.py
new file mode 100644
index 0000000..25f20e7
--- /dev/null
+++ b/ktransformers/ktransformers_ext/bench/bench_attention_torch.py
@@ -0,0 +1,94 @@
+#!/usr/bin/env python
+# coding=utf-8
+"""
+Description :
+Author : Jianwei Dong
+Date : 2024-08-28 10:32:05
+Version : 1.0.0
+LastEditors : Jianwei Dong
+LastEditTime : 2024-08-28 10:32:05
+Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
+"""
+import os, sys
+import time
+
+sys.path.append(os.path.dirname(__file__) + "/../build")
+import cpuinfer_ext
+import torch
+
+layer_num = 10
+kv_head_num = 8
+q_head_num = 32
+head_dim = 128
+block_len = 128
+anchor_num = 1
+warm_up_iter = 1000
+test_iter = 10000
+
+
+def bench_linear(cache_seqlen: int, device):
+ with torch.inference_mode(mode=True):
+
+ kvcaches = []
+
+ for layer_idx in range(layer_num):
+ k_cache = torch.randn(
+ (1, 32, cache_seqlen, head_dim),
+ dtype=torch.float16,
+ device=device,
+ ).contiguous()
+ v_cache = torch.randn(
+ (1, 32, cache_seqlen, head_dim),
+ dtype=torch.float16,
+ device=device,
+ ).contiguous()
+
+ kvcaches.append((k_cache, v_cache))
+
+ input = torch.randn(
+ (1, q_head_num, 1, head_dim), dtype=torch.float16, device=device
+ ).contiguous()
+ input = input / 100
+
+ # warm up
+ for i in range(warm_up_iter):
+ k_cache = kvcaches[i % layer_num][0]
+ v_cache = kvcaches[i % layer_num][1]
+ torch.nn.functional.scaled_dot_product_attention(input, k_cache, v_cache)
+
+ # test
+ start = time.perf_counter()
+ for i in range(test_iter):
+ k_cache = kvcaches[i % layer_num][0]
+ v_cache = kvcaches[i % layer_num][1]
+ torch.nn.functional.scaled_dot_product_attention(input, k_cache, v_cache)
+ end = time.perf_counter()
+ total_time = end - start
+ print("cache sequence length: ", cache_seqlen)
+ print("Time(s): ", total_time)
+ print("Iteration: ", test_iter)
+ print("Time(us) per iteration: ", total_time / test_iter * 1000000)
+ print(
+ "Bandwidth: ",
+ cache_seqlen
+ * q_head_num
+ * head_dim
+ * 2
+ * 2
+ * test_iter
+ / total_time
+ / 1000
+ / 1000
+ / 1000,
+ "GB/s",
+ )
+ print("")
+
+
+bench_linear(1024, "cpu")
+bench_linear(4096, "cpu")
+bench_linear(1024, "cuda")
+bench_linear(4096, "cuda")
+bench_linear(16384, "cuda")
+bench_linear(32768, "cuda")
+bench_linear(65536, "cuda")
diff --git a/ktransformers/ktransformers_ext/cpu_backend/backend.cpp b/ktransformers/ktransformers_ext/cpu_backend/backend.cpp
index 5707505..16693f0 100644
--- a/ktransformers/ktransformers_ext/cpu_backend/backend.cpp
+++ b/ktransformers/ktransformers_ext/cpu_backend/backend.cpp
@@ -3,93 +3,125 @@
* @Author : chenht2022
* @Date : 2024-07-22 02:03:05
* @Version : 1.0.0
- * @LastEditors : chenht2022
+ * @LastEditors : chenht2022
* @LastEditTime : 2024-07-25 10:33:34
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
+
#include "backend.h"
-Backend::Backend(int thread_num) {
- thread_num_ = thread_num;
- thread_state_.resize(thread_num);
- for (int i = 0; i < thread_num; i++) {
+thread_local int Backend::thread_local_id = -1;
+
+Backend::Backend(int max_thread_num) {
+ max_thread_num_ = max_thread_num;
+ thread_state_.resize(max_thread_num_);
+ for (int i = 0; i < max_thread_num_; i++) {
         thread_state_[i].curr = std::make_unique<std::atomic<int>>();
-        thread_state_[i].status = std::make_unique<std::atomic<ThreadStatus>>(ThreadStatus::WAITING);
+        thread_state_[i].status =
+            std::make_unique<std::atomic<ThreadStatus>>(ThreadStatus::WAITING);
}
- workers_.resize(thread_num);
- for (int i = 1; i < thread_num; i++) {
+ workers_.resize(max_thread_num_);
+ for (int i = 1; i < max_thread_num_; i++) {
workers_[i] = std::thread(&Backend::worker_thread, this, i);
}
}
Backend::~Backend() {
- for (int i = 0; i < thread_num_; i++) {
- thread_state_[i].status->store(ThreadStatus::EXIT, std::memory_order_release);
+ for (int i = 0; i < max_thread_num_; i++) {
+ thread_state_[i].status->store(ThreadStatus::EXIT,
+ std::memory_order_release);
}
- for (int i = 1; i < thread_num_; i++) {
+ for (int i = 1; i < max_thread_num_; i++) {
if (workers_[i].joinable()) {
workers_[i].join();
}
}
}
-int Backend::get_thread_num() {
- return thread_num_;
-}
+int Backend::get_thread_num() { return max_thread_num_; }
-void Backend::do_work_stealing_job(int task_num, std::function<void(int)> func) {
- func_ = func;
+void Backend::do_work_stealing_job(int task_num,
+ std::function<void(int)> init_func,
+ std::function<void(int)> compute_func,
+ std::function<void(int)> finalize_func) {
+ init_func_ = init_func;
+ compute_func_ = compute_func;
+ finalize_func_ = finalize_func;
+ thread_num_ = std::min(max_thread_num_, task_num);
int base = task_num / thread_num_;
int remain = task_num % thread_num_;
thread_state_[0].end = base + (0 < remain);
+
+ // Set thread_local_id for the main (calling) thread
+ thread_local_id = 0;
+
for (int i = 1; i < thread_num_; i++) {
- thread_state_[i].curr->store(thread_state_[i - 1].end, std::memory_order_relaxed);
+ thread_state_[i].curr->store(thread_state_[i - 1].end,
+ std::memory_order_relaxed);
thread_state_[i].end = thread_state_[i - 1].end + base + (i < remain);
- thread_state_[i].status->store(ThreadStatus::WORKING, std::memory_order_release);
+ thread_state_[i].status->store(ThreadStatus::WORKING,
+ std::memory_order_release);
}
thread_state_[0].curr->store(0, std::memory_order_relaxed);
- thread_state_[0].status->store(ThreadStatus::WORKING, std::memory_order_release);
+ thread_state_[0].status->store(ThreadStatus::WORKING,
+ std::memory_order_release);
process_tasks(0);
for (int i = 1; i < thread_num_; i++) {
- while (thread_state_[i].status->load(std::memory_order_acquire) == ThreadStatus::WORKING) {
+ while (thread_state_[i].status->load(std::memory_order_acquire) ==
+ ThreadStatus::WORKING) {
}
}
}
void Backend::process_tasks(int thread_id) {
+ if (init_func_ != nullptr) {
+ init_func_(thread_id);
+ }
while (true) {
- int task_id = thread_state_[thread_id].curr->fetch_add(1, std::memory_order_acq_rel);
+ int task_id = thread_state_[thread_id].curr->fetch_add(
+ 1, std::memory_order_acq_rel);
if (task_id >= thread_state_[thread_id].end) {
break;
}
- func_(task_id);
+ compute_func_(task_id);
}
for (int t_offset = 1; t_offset < thread_num_; t_offset++) {
int t_i = (thread_id + t_offset) % thread_num_;
- if (thread_state_[t_i].status->load(std::memory_order_acquire) != ThreadStatus::WORKING) {
+ if (thread_state_[t_i].status->load(std::memory_order_acquire) !=
+ ThreadStatus::WORKING) {
continue;
}
while (true) {
- int task_id = thread_state_[t_i].curr->fetch_add(1, std::memory_order_acq_rel);
+ int task_id = thread_state_[t_i].curr->fetch_add(
+ 1, std::memory_order_acq_rel);
if (task_id >= thread_state_[t_i].end) {
break;
}
- func_(task_id);
+ compute_func_(task_id);
}
}
- thread_state_[thread_id].status->store(ThreadStatus::WAITING, std::memory_order_release);
+ if (finalize_func_ != nullptr) {
+ finalize_func_(thread_id);
+ }
+ thread_state_[thread_id].status->store(ThreadStatus::WAITING,
+ std::memory_order_release);
}
void Backend::worker_thread(int thread_id) {
auto start = std::chrono::steady_clock::now();
+ thread_local_id = thread_id; // record this worker's thread-local id
while (true) {
- ThreadStatus status = thread_state_[thread_id].status->load(std::memory_order_acquire);
+ ThreadStatus status =
+ thread_state_[thread_id].status->load(std::memory_order_acquire);
if (status == ThreadStatus::WORKING) {
process_tasks(thread_id);
start = std::chrono::steady_clock::now();
} else if (status == ThreadStatus::WAITING) {
auto now = std::chrono::steady_clock::now();
- auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(now - start).count();
+ auto duration =
+ std::chrono::duration_cast<std::chrono::milliseconds>(
+ now - start)
+ .count();
if (duration > 50) {
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
diff --git a/ktransformers/ktransformers_ext/cpu_backend/backend.h b/ktransformers/ktransformers_ext/cpu_backend/backend.h
index be3d45b..80ff7f9 100644
--- a/ktransformers/ktransformers_ext/cpu_backend/backend.h
+++ b/ktransformers/ktransformers_ext/cpu_backend/backend.h
@@ -3,7 +3,7 @@
* @Author : chenht2022
* @Date : 2024-07-22 02:03:05
* @Version : 1.0.0
- * @LastEditors : chenht2022
+ * @LastEditors : chenht2022
* @LastEditTime : 2024-07-25 10:33:38
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
@@ -31,20 +31,25 @@ struct ThreadState {
};
class Backend {
- public:
+ public:
Backend(int);
~Backend();
int get_thread_num();
- void do_work_stealing_job(int, std::function<void(int)>);
+ void do_work_stealing_job(int, std::function<void(int)>,
+ std::function<void(int)>,
+ std::function<void(int)>);
+ static thread_local int thread_local_id;
- private:
+ private:
int thread_num_;
- std::vector<ThreadState> thread_state_; // [thread_num]
- std::function<void(int)> func_;
+ int max_thread_num_;
+ std::vector<ThreadState> thread_state_; // [thread_num]
+ std::function<void(int)> init_func_;
+ std::function<void(int)> compute_func_;
+ std::function<void(int)> finalize_func_;
 std::vector<std::thread> workers_;
void process_tasks(int);
void worker_thread(int);
};
-
#endif
\ No newline at end of file
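The reworked do_work_stealing_job(task_num, init_func, compute_func, finalize_func) clamps the worker count to min(max_thread_num_, task_num), gives thread i a contiguous range of base + (i < remain) task ids, runs init_func(thread_id) once per participating thread, lets finished threads steal remaining task ids from the other threads' atomic counters, and ends each thread with finalize_func(thread_id). A small Python model of just the partitioning step (nothing here beyond what the C++ above shows):

def partition(task_num, max_thread_num):
    # Mirror of Backend's split: each thread gets `base` tasks, and the first
    # `remain` threads get one extra, as contiguous [start, end) ranges.
    thread_num = min(max_thread_num, task_num)
    base, remain = divmod(task_num, thread_num)
    ranges, start = [], 0
    for i in range(thread_num):
        end = start + base + (1 if i < remain else 0)
        ranges.append((start, end))
        start = end
    return ranges

# partition(10, 4) -> [(0, 3), (3, 6), (6, 8), (8, 10)]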
diff --git a/ktransformers/ktransformers_ext/cpu_backend/task_queue.cpp b/ktransformers/ktransformers_ext/cpu_backend/task_queue.cpp
index 0ca865b..fb7ac4f 100644
--- a/ktransformers/ktransformers_ext/cpu_backend/task_queue.cpp
+++ b/ktransformers/ktransformers_ext/cpu_backend/task_queue.cpp
@@ -54,4 +54,4 @@ void TaskQueue::processTasks() {
}
mutex.unlock();
}
-}
+}
\ No newline at end of file
diff --git a/ktransformers/ktransformers_ext/cpu_backend/task_queue.h b/ktransformers/ktransformers_ext/cpu_backend/task_queue.h
index 13836b7..a633a40 100644
--- a/ktransformers/ktransformers_ext/cpu_backend/task_queue.h
+++ b/ktransformers/ktransformers_ext/cpu_backend/task_queue.h
@@ -4,7 +4,7 @@
* @Date : 2024-07-16 10:43:18
* @Version : 1.0.0
* @LastEditors : chenxl
- * @LastEditTime : 2024-08-12 12:28:25
+ * @LastEditTime : 2024-08-08 04:23:51
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#ifndef CPUINFER_TASKQUEUE_H
diff --git a/ktransformers/ktransformers_ext/examples/test_attention.py b/ktransformers/ktransformers_ext/examples/test_attention.py
new file mode 100644
index 0000000..5627a0e
--- /dev/null
+++ b/ktransformers/ktransformers_ext/examples/test_attention.py
@@ -0,0 +1,142 @@
+#!/usr/bin/env python
+# coding=utf-8
+"""
+Description :
+Author : Jianwei Dong
+Date : 2024-08-28 10:32:05
+Version : 1.0.0
+LastEditors : chenht2022
+LastEditTime : 2024-08-28 10:32:05
+Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
+"""
+import os, sys
+import time
+
+sys.path.append(os.path.dirname(__file__) + "/../build")
+import cpuinfer_ext
+from flash_attn import flash_attn_with_kvcache
+import torch
+
+layer_num = 10
+kv_head_num = 8
+q_head_num = 32
+head_dim = 128
+block_len = 128
+anchor_num = 1
+cache_seqlen = 8192
+cache_seqlens = torch.tensor([cache_seqlen], dtype=torch.int32, device="cpu")
+seqlens_zero = torch.zeros((1,), dtype=torch.int32, device="cpu")
+anchor_type = cpuinfer_ext.kvcache.AnchorType.DYNAMIC
+kv_type = cpuinfer_ext.kvcache.ggml_type.FP16
+retrieval_type = cpuinfer_ext.kvcache.RetrievalType.LAYER
+layer_step: int = 1
+token_step: int = 1
+layer_offset: int = 0
+max_thread_num: int = 2
+max_batch_size: int = 1
+max_block_num: int = 512
+CPUInfer = cpuinfer_ext.CPUInfer(max_thread_num)
+validation_iter = 100
+
+with torch.inference_mode(mode=True):
+ config = cpuinfer_ext.kvcache.KVCacheConfig(
+ layer_num,
+ kv_head_num,
+ q_head_num,
+ head_dim,
+ block_len,
+ anchor_num,
+ anchor_type,
+ kv_type,
+ retrieval_type,
+ layer_step,
+ token_step,
+ layer_offset,
+ max_block_num,
+ max_batch_size,
+ max_thread_num,
+ )
+ local_kvcache = cpuinfer_ext.kvcache.KVCache(config)
+
+ kvcaches = []
+ block_table = (
+ torch.arange(max_block_num, dtype=torch.int32, device="cpu")
+ .contiguous()
+ .view(1, -1)
+ )
+
+ for layer_idx in range(layer_num):
+ k_cache = torch.randn(
+ (1, cache_seqlen, kv_head_num, head_dim), dtype=torch.float16, device="cpu"
+ ).contiguous()
+ v_cache = torch.randn(
+ (1, cache_seqlen, kv_head_num, head_dim), dtype=torch.float16, device="cpu"
+ ).contiguous()
+
+ CPUInfer.submit(
+ local_kvcache.update_kvcache_fp16(
+ k_cache.data_ptr(),
+ v_cache.data_ptr(),
+ layer_idx,
+ block_table.data_ptr(),
+ 1,
+ max_block_num,
+ seqlens_zero.data_ptr(),
+ cache_seqlen,
+ )
+ )
+ CPUInfer.sync()
+
+ kvcaches.append((k_cache.to("cuda"), v_cache.to("cuda")))
+
+ # validation
+ for i in range(validation_iter):
+
+ k_cache = kvcaches[i % layer_num][0]
+ v_cache = kvcaches[i % layer_num][1]
+ input = torch.randn(
+ (1, 1, q_head_num, head_dim), dtype=torch.float16, device="cpu"
+ ).contiguous()
+ output = torch.empty(
+ (1, 1, q_head_num, head_dim), dtype=torch.float16, device="cpu"
+ ).contiguous()
+
+ # attn_lse: (bsz, q_len, q_head_num)
+ attn_lse = torch.empty(
+ (1, 1, q_head_num), dtype=torch.float32, device="cpu"
+ ).contiguous()
+ input = input / 100
+
+ CPUInfer.submit(
+ local_kvcache.attn(
+ input.data_ptr(),
+ output.data_ptr(),
+ attn_lse.data_ptr(),
+ i % layer_num,
+ 0,
+ 1,
+ 1,
+ max_block_num,
+ block_table.data_ptr(),
+ cache_seqlens.data_ptr(),
+ -1,
+ -1,
+ -1,
+ )
+ )
+ CPUInfer.sync()
+ # print("cpuinfer output", output)
+
+ t_output = flash_attn_with_kvcache(
+ q=input.to("cuda"),
+ k_cache=k_cache,
+ v_cache=v_cache,
+ cache_seqlens=cache_seqlens.to("cuda"),
+ )
+ # print("torch output", t_output)
+
+ diff = torch.mean(torch.abs(output.to("cuda") - t_output)) / torch.mean(
+ torch.abs(t_output)
+ )
+ print("diff = ", diff)
+ assert diff < 0.001
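The assertion above compares the CPU backend against flash-attn with a mean relative L1 error; the same check as a small standalone helper (only torch assumed):

import torch

def mean_rel_err(actual: torch.Tensor, expected: torch.Tensor) -> float:
    # mean(|actual - expected|) / mean(|expected|), as in the validation loop above
    return (torch.mean(torch.abs(actual - expected))
            / torch.mean(torch.abs(expected))).item()

# usage, matching the test: assert mean_rel_err(output.to("cuda"), t_output) < 1e-3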
diff --git a/ktransformers/ktransformers_ext/ext_bindings.cpp b/ktransformers/ktransformers_ext/ext_bindings.cpp
index c220a9b..902d427 100644
--- a/ktransformers/ktransformers_ext/ext_bindings.cpp
+++ b/ktransformers/ktransformers_ext/ext_bindings.cpp
@@ -1,19 +1,17 @@
/**
* @Description :
- * @Author : chenht2022
+ * @Author : chenht2022, Jianwei Dong
* @Date : 2024-07-22 02:03:22
* @Version : 1.0.0
- * @LastEditors : chenht2022
- * @LastEditTime : 2024-08-07 10:39:37
+ * @LastEditors : Jianwei Dong
+ * @LastEditTime : 2024-08-26 22:47:06
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
// Python bindings
-#include
-#include
-#include
#include "cpu_backend/cpuinfer.h"
#include "device_launch_parameters.h"
#include "llamafile/flags.h"
+#include "operators/kvcache/kvcache.h"
#include "operators/llamafile/linear.h"
#include "operators/llamafile/mlp.h"
#include "operators/llamafile/moe.h"
@@ -21,119 +19,541 @@
#include "pybind11/operators.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
+#include
+#include
+#include
namespace py = pybind11;
using namespace pybind11::literals;
+// Binding functions for the KVCache class
+class KVCacheBindings {
+ public:
+ class AttnBindings {
+ public:
+ struct Args {
+ CPUInfer *cpuinfer;
+ KVCache *kv_cache;
+ const ggml_fp16_t *q_in;
+ ggml_fp16_t *output;
+ float *attn_lse;
+ int layer_idx;
+ int generate_token_idx;
+ int q_len;
+ int batch_size;
+ int max_block_num;
+ int *block_table;
+ int *cache_seqlens;
+ int pick_block_num;
+ int init_block_num;
+ int local_block_num;
+ };
+ static void inner(void *args) {
+ Args *args_ = (Args *)args;
+ args_->cpuinfer->enqueue(
+ &KVCache::attn, args_->kv_cache, args_->q_in, args_->output,
+ args_->attn_lse, args_->layer_idx, args_->generate_token_idx,
+ args_->q_len, args_->batch_size, args_->max_block_num,
+ args_->block_table, args_->cache_seqlens, args_->pick_block_num,
+ args_->init_block_num, args_->local_block_num);
+ }
+ static std::pair<intptr_t, intptr_t>
+ cpuinfer_interface(KVCache &kv_cache, intptr_t q_in, intptr_t output,
+ intptr_t attn_lse, int layer_idx,
+ int generate_token_idx, int q_len, int batch_size,
+ int max_block_num, intptr_t block_table,
+ intptr_t cache_seqlens, int pick_block_num,
+ int init_block_num, int local_block_num) {
+ Args *args = new Args{nullptr,
+ &kv_cache,
+ (const ggml_fp16_t *)q_in,
+ (ggml_fp16_t *)output,
+ (float *)attn_lse,
+ layer_idx,
+ generate_token_idx,
+ q_len,
+ batch_size,
+ max_block_num,
+ (int *)block_table,
+ (int *)cache_seqlens,
+ pick_block_num,
+ init_block_num,
+ local_block_num};
+ return std::make_pair((intptr_t)&inner, (intptr_t)args);
+ }
+ };
+
+ class GetAllKVCacheOneLayerBindings {
+ public:
+ struct Args {
+ CPUInfer *cpuinfer;
+ KVCache *kv_cache;
+ int layer_id;
+ ggml_fp16_t *k_in;
+ ggml_fp16_t *v_in;
+ };
+ static void inner(void *args) {
+ Args *args_ = (Args *)args;
+ args_->cpuinfer->enqueue(&KVCache::get_all_kvcache_one_layer,
+ args_->kv_cache, args_->layer_id,
+ args_->k_in, args_->v_in);
+ }
+ static std::pair<intptr_t, intptr_t>
+ cpuinfer_interface(KVCache &kv_cache, intptr_t k_in, intptr_t v_in,
+ int layer_id) {
+ Args *args = new Args{nullptr, &kv_cache, layer_id,
+ (ggml_fp16_t *)k_in, (ggml_fp16_t *)v_in};
+ return std::make_pair((intptr_t)&inner, (intptr_t)args);
+ }
+ };
+
+ class GetAndUpdateKVCacheFp16Bindings {
+ public:
+ struct Args {
+ CPUInfer *cpuinfer;
+ KVCache *kv_cache;
+ ggml_fp16_t *k_in;
+ ggml_fp16_t *v_in;
+ int layer_id;
+ int *block_table;
+ int batch_size;
+ int max_block_num;
+ int *cache_seqlens;
+ int q_len;
+ };
+ static void inner(void *args) {
+ Args *args_ = (Args *)args;
+ args_->cpuinfer->enqueue(&KVCache::get_and_update_kvcache_fp16,
+ args_->kv_cache, args_->k_in, args_->v_in,
+ args_->layer_id, args_->block_table,
+ args_->batch_size, args_->max_block_num,
+ args_->cache_seqlens, args_->q_len);
+ }
+ static std::pair<intptr_t, intptr_t>
+ cpuinfer_interface(KVCache &kv_cache, intptr_t k_in, intptr_t v_in,
+ int layer_id, intptr_t block_table, int batch_size,
+ int max_block_num, intptr_t cache_seqlens,
+ int q_len) {
+ Args *args = new Args{nullptr,
+ &kv_cache,
+ (ggml_fp16_t *)k_in,
+ (ggml_fp16_t *)v_in,
+ layer_id,
+ (int *)block_table,
+ batch_size,
+ max_block_num,
+ (int *)cache_seqlens,
+ q_len};
+ return std::make_pair((intptr_t)&inner, (intptr_t)args);
+ }
+ };
+ class GetKVCacheFp16Bindings {
+ public:
+ struct Args {
+ CPUInfer *cpuinfer;
+ KVCache *kv_cache;
+ ggml_fp16_t *k_in;
+ ggml_fp16_t *v_in;
+ int layer_id;
+ int *block_table;
+ int batch_size;
+ int max_block_num;
+ int *cache_seqlens;
+ };
+ static void inner(void *args) {
+ Args *args_ = (Args *)args;
+ args_->cpuinfer->enqueue(
+ &KVCache::get_kvcache_fp16, args_->kv_cache, args_->k_in,
+ args_->v_in, args_->layer_id, args_->block_table,
+ args_->batch_size, args_->max_block_num, args_->cache_seqlens);
+ }
+ static std::pair<intptr_t, intptr_t>
+ cpuinfer_interface(KVCache &kv_cache, intptr_t k_in, intptr_t v_in,
+ int layer_id, intptr_t block_table, int batch_size,
+ int max_block_num, intptr_t cache_seqlens) {
+ Args *args = new Args{nullptr,
+ &kv_cache,
+ (ggml_fp16_t *)k_in,
+ (ggml_fp16_t *)v_in,
+ layer_id,
+ (int *)block_table,
+ batch_size,
+ max_block_num,
+ (int *)cache_seqlens};
+ return std::make_pair((intptr_t)&inner, (intptr_t)args);
+ }
+ };
+
+ class UpdateKVCacheFp16Bindings {
+ public:
+ struct Args {
+ CPUInfer *cpuinfer;
+ KVCache *kv_cache;
+ ggml_fp16_t *k_in;
+ ggml_fp16_t *v_in;
+ int layer_id;
+ int *block_table;
+ int batch_size;
+ int max_block_num;
+ int *cache_seqlens;
+ int q_len;
+ };
+ static void inner(void *args) {
+ Args *args_ = (Args *)args;
+ args_->cpuinfer->enqueue(&KVCache::update_kvcache_fp16,
+ args_->kv_cache, args_->k_in, args_->v_in,
+ args_->layer_id, args_->block_table,
+ args_->batch_size, args_->max_block_num,
+ args_->cache_seqlens, args_->q_len);
+ }
+ static std::pair<intptr_t, intptr_t>
+ cpuinfer_interface(KVCache &kv_cache, intptr_t k_in, intptr_t v_in,
+ int layer_id, intptr_t block_table, int batch_size,
+ int max_block_num, intptr_t cache_seqlens,
+ int q_len) {
+ Args *args = new Args{nullptr,
+ &kv_cache,
+ (ggml_fp16_t *)k_in,
+ (ggml_fp16_t *)v_in,
+ layer_id,
+ (int *)block_table,
+ batch_size,
+ max_block_num,
+ (int *)cache_seqlens,
+ q_len};
+ return std::make_pair((intptr_t)&inner, (intptr_t)args);
+ }
+ };
+
+ class UpdateImportanceBindings {
+ public:
+ struct Args {
+ CPUInfer *cpuinfer;
+ KVCache *kv_cache;
+ const ggml_fp16_t *importance;
+ int layer_id;
+ int *block_table;
+ int batch_size;
+ int max_block_num;
+ int *offset;
+ int width;
+ };
+ static void inner(void *args) {
+ Args *args_ = (Args *)args;
+ args_->cpuinfer->enqueue(
+ &KVCache::update_importance, args_->kv_cache, args_->importance,
+ args_->layer_id, args_->block_table, args_->batch_size,
+ args_->max_block_num, args_->offset, args_->width);
+ }
+ static std::pair<intptr_t, intptr_t>
+ cpuinfer_interface(KVCache &kv_cache, intptr_t importance, int layer_id,
+ intptr_t block_table, int batch_size,
+ int max_block_num, intptr_t offset, int width) {
+ Args *args = new Args{nullptr,
+ &kv_cache,
+ (const ggml_fp16_t *)importance,
+ layer_id,
+ (int *)block_table,
+ batch_size,
+ max_block_num,
+ (int *)offset,
+ width};
+ return std::make_pair((intptr_t)&inner, (intptr_t)args);
+ }
+ };
+
+ class AttnWithKVCacheBindings {
+ public:
+ struct Args {
+ CPUInfer *cpuinfer;
+ KVCache *kv_cache;
+ const ggml_fp16_t *q_in;
+ const ggml_fp16_t *k_in;
+ const ggml_fp16_t *v_in;
+ ggml_fp16_t *output;
+ float *attn_lse;
+ int layer_idx;
+ int generate_token_idx;
+ int q_len;
+ int batch_size;
+ int max_block_num;
+ int *block_table;
+ int *cache_seqlens;
+ int topk;
+ int local;
+ };
+ static void inner(void *args) {
+ Args *args_ = (Args *)args;
+ args_->cpuinfer->enqueue(
+ &KVCache::attn_with_kvcache, args_->kv_cache, args_->q_in,
+ args_->k_in, args_->v_in, args_->output, args_->attn_lse,
+ args_->layer_idx, args_->generate_token_idx, args_->q_len,
+ args_->batch_size, args_->max_block_num, args_->block_table,
+ args_->cache_seqlens, args_->topk, args_->local);
+ }
+ static std::pair<intptr_t, intptr_t>
+ cpuinfer_interface(KVCache &kv_cache, intptr_t q_in, intptr_t k_in,
+ intptr_t v_in, intptr_t output, intptr_t attn_lse,
+ int layer_idx, int generate_token_idx, int q_len,
+ int batch_size, int max_block_num,
+ intptr_t block_table, intptr_t cache_seqlens,
+ int topk, int local) {
+ Args *args = new Args{nullptr,
+ &kv_cache,
+ (const ggml_fp16_t *)q_in,
+ (const ggml_fp16_t *)k_in,
+ (const ggml_fp16_t *)v_in,
+ (ggml_fp16_t *)output,
+ (float *)attn_lse,
+ layer_idx,
+ generate_token_idx,
+ q_len,
+ batch_size,
+ max_block_num,
+ (int *)block_table,
+ (int *)cache_seqlens,
+ topk,
+ local};
+ return std::make_pair((intptr_t)&inner, (intptr_t)args);
+ }
+ };
+
+ class ClearImportanceAllLayersBindings {
+ public:
+ struct Args {
+ CPUInfer *cpuinfer;
+ KVCache *kv_cache;
+ int *block_table;
+ int *cache_seqlens;
+ int batch_size;
+ int max_block_num;
+ };
+ static void inner(void *args) {
+ Args *args_ = (Args *)args;
+ args_->cpuinfer->enqueue(&KVCache::clear_importance_all_layers,
+ args_->kv_cache, args_->block_table,
+ args_->cache_seqlens, args_->batch_size,
+ args_->max_block_num);
+ }
+ static std::pair<intptr_t, intptr_t>
+ cpuinfer_interface(KVCache &kv_cache, intptr_t block_table,
+ intptr_t cache_seqlens, int batch_size,
+ int max_block_num) {
+ Args *args = new Args{nullptr,
+ &kv_cache,
+ (int *)block_table,
+ (int *)cache_seqlens,
+ batch_size,
+ max_block_num};
+ return std::make_pair((intptr_t)&inner, (intptr_t)args);
+ }
+ };
+
+ class CalcAnchorAllLayersBindinds {
+ public:
+ struct Args {
+ CPUInfer *cpuinfer;
+ KVCache *kv_cache;
+ int *block_table;
+ int *cache_seqlens;
+ int batch_size;
+ int max_block_num;
+ };
+ static void inner(void *args) {
+ Args *args_ = (Args *)args;
+ args_->cpuinfer->enqueue(&KVCache::calc_anchor_all_layers,
+ args_->kv_cache, args_->block_table,
+ args_->cache_seqlens, args_->batch_size,
+ args_->max_block_num);
+ }
+ static std::pair<intptr_t, intptr_t>
+ cpuinfer_interface(KVCache &kv_cache, intptr_t block_table,
+ intptr_t cache_seqlens, int batch_size,
+ int max_block_num) {
+ Args *args = new Args{nullptr,
+ &kv_cache,
+ (int *)block_table,
+ (int *)cache_seqlens,
+ batch_size,
+ max_block_num};
+ return std::make_pair((intptr_t)&inner, (intptr_t)args);
+ }
+ };
+
+ class LoadKVCacheBindings {
+ public:
+ struct Args {
+ CPUInfer *cpuinfer;
+ KVCache *kv_cache;
+ std::string tensor_file_path;
+ };
+ static void inner(void *args) {
+ Args *args_ = (Args *)args;
+ args_->cpuinfer->enqueue(&KVCache::load_kvcache, args_->kv_cache,
+ args_->tensor_file_path);
+ }
+ static std::pair<intptr_t, intptr_t>
+ cpuinfer_interface(KVCache &kv_cache, std::string tensor_file_path) {
+ Args *args =
+ new Args{nullptr, &kv_cache, (std::string)tensor_file_path};
+ return std::make_pair((intptr_t)&inner, (intptr_t)args);
+ }
+ };
+ class DumpKVCacheBindings {
+ public:
+ struct Args {
+ CPUInfer *cpuinfer;
+ KVCache *kv_cache;
+ int *block_table;
+ int cache_total_len;
+ std::string tensor_file_path;
+ };
+ static void inner(void *args) {
+ Args *args_ = (Args *)args;
+ args_->cpuinfer->enqueue(&KVCache::dump_kvcache, args_->kv_cache,
+ args_->block_table, args_->cache_total_len,
+ args_->tensor_file_path);
+ }
+ static std::pair<intptr_t, intptr_t>
+ cpuinfer_interface(KVCache &kv_cache, intptr_t block_table,
+ int cache_total_len, std::string tensor_file_path) {
+ Args *args =
+ new Args{nullptr, &kv_cache, (int *)block_table,
+ cache_total_len, (std::string)tensor_file_path};
+ return std::make_pair((intptr_t)&inner, (intptr_t)args);
+ }
+ };
+};
+
class LinearBindings {
- public:
+ public:
class WarmUpBindinds {
- public:
+ public:
struct Args {
- CPUInfer* cpuinfer;
- Linear* linear;
+ CPUInfer *cpuinfer;
+ Linear *linear;
};
- static void inner(void* args) {
- Args* args_ = (Args*)args;
+ static void inner(void *args) {
+ Args *args_ = (Args *)args;
args_->cpuinfer->enqueue(&Linear::warm_up, args_->linear);
}
- static std::pair<intptr_t, intptr_t> cpuinfer_interface(Linear& linear) {
- Args* args = new Args{nullptr, &linear};
+ static std::pair<intptr_t, intptr_t>
+ cpuinfer_interface(Linear &linear) {
+ Args *args = new Args{nullptr, &linear};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
class ForwardBindings {
- public:
+ public:
struct Args {
- CPUInfer* cpuinfer;
- Linear* linear;
+ CPUInfer *cpuinfer;
+ Linear *linear;
int qlen;
- const void* input;
- void* output;
+ const void *input;
+ void *output;
};
- static void inner(void* args) {
- Args* args_ = (Args*)args;
- args_->cpuinfer->enqueue(&Linear::forward, args_->linear, args_->qlen, args_->input, args_->output);
+ static void inner(void *args) {
+ Args *args_ = (Args *)args;
+ args_->cpuinfer->enqueue(&Linear::forward, args_->linear,
+ args_->qlen, args_->input, args_->output);
}
- static std::pair<intptr_t, intptr_t> cpuinfer_interface(Linear& linear, int qlen, intptr_t input, intptr_t output) {
- Args* args = new Args{nullptr, &linear, qlen, (const void*)input, (void*)output};
+ static std::pair<intptr_t, intptr_t>
+ cpuinfer_interface(Linear &linear, int qlen, intptr_t input,
+ intptr_t output) {
+ Args *args = new Args{nullptr, &linear, qlen, (const void *)input,
+ (void *)output};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
};
class MLPBindings {
- public:
+ public:
class WarmUpBindinds {
- public:
+ public:
struct Args {
- CPUInfer* cpuinfer;
- MLP* mlp;
+ CPUInfer *cpuinfer;
+ MLP *mlp;
};
- static void inner(void* args) {
- Args* args_ = (Args*)args;
+ static void inner(void *args) {
+ Args *args_ = (Args *)args;
args_->cpuinfer->enqueue(&MLP::warm_up, args_->mlp);
}
- static std::pair<intptr_t, intptr_t> cpuinfer_interface(MLP& mlp) {
- Args* args = new Args{nullptr, &mlp};
+ static std::pair<intptr_t, intptr_t> cpuinfer_interface(MLP &mlp) {
+ Args *args = new Args{nullptr, &mlp};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
class ForwardBindings {
- public:
+ public:
struct Args {
- CPUInfer* cpuinfer;
- MLP* mlp;
+ CPUInfer *cpuinfer;
+ MLP *mlp;
int qlen;
- const void* input;
- void* output;
+ const void *input;
+ void *output;
};
- static void inner(void* args) {
- Args* args_ = (Args*)args;
- args_->cpuinfer->enqueue(&MLP::forward, args_->mlp, args_->qlen, args_->input, args_->output);
+ static void inner(void *args) {
+ Args *args_ = (Args *)args;
+ args_->cpuinfer->enqueue(&MLP::forward, args_->mlp, args_->qlen,
+ args_->input, args_->output);
}
- static std::pair<intptr_t, intptr_t> cpuinfer_interface(MLP& mlp, int qlen, intptr_t input, intptr_t output) {
- Args* args = new Args{nullptr, &mlp, qlen, (const void*)input, (void*)output};
+ static std::pair<intptr_t, intptr_t>
+ cpuinfer_interface(MLP &mlp, int qlen, intptr_t input,
+ intptr_t output) {
+ Args *args = new Args{nullptr, &mlp, qlen, (const void *)input,
+ (void *)output};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
};
class MOEBindings {
- public:
+ public:
class WarmUpBindinds {
- public:
+ public:
struct Args {
- CPUInfer* cpuinfer;
- MOE* moe;
+ CPUInfer *cpuinfer;
+ MOE *moe;
};
- static void inner(void* args) {
- Args* args_ = (Args*)args;
+ static void inner(void *args) {
+ Args *args_ = (Args *)args;
args_->cpuinfer->enqueue(&MOE::warm_up, args_->moe);
}
- static std::pair<intptr_t, intptr_t> cpuinfer_interface(MOE& moe) {
- Args* args = new Args{nullptr, &moe};
+ static std::pair<intptr_t, intptr_t> cpuinfer_interface(MOE &moe) {
+ Args *args = new Args{nullptr, &moe};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
class ForwardBindings {
- public:
+ public:
struct Args {
- CPUInfer* cpuinfer;
- MOE* moe;
+ CPUInfer *cpuinfer;
+ MOE *moe;
int qlen;
int k;
- const uint64_t* expert_ids;
- const float* weights;
- const void* input;
- void* output;
+ const uint64_t *expert_ids;
+ const float *weights;
+ const void *input;
+ void *output;
};
- static void inner(void* args) {
- Args* args_ = (Args*)args;
- args_->cpuinfer->enqueue(&MOE::forward, args_->moe, args_->qlen, args_->k, args_->expert_ids, args_->weights, args_->input, args_->output);
+ static void inner(void *args) {
+ Args *args_ = (Args *)args;
+ args_->cpuinfer->enqueue(
+ &MOE::forward, args_->moe, args_->qlen, args_->k,
+ args_->expert_ids, args_->weights, args_->input, args_->output);
}
- static std::pair<intptr_t, intptr_t> cpuinfer_interface(MOE& moe, int qlen, int k, intptr_t expert_ids, intptr_t weights, intptr_t input, intptr_t output) {
- Args* args = new Args{nullptr, &moe, qlen, k, (const uint64_t*)expert_ids, (const float*)weights, (const void*)input, (void*)output};
+ static std::pair<intptr_t, intptr_t>
+ cpuinfer_interface(MOE &moe, int qlen, int k, intptr_t expert_ids,
+ intptr_t weights, intptr_t input, intptr_t output) {
+ Args *args = new Args{nullptr,
+ &moe,
+ qlen,
+ k,
+ (const uint64_t *)expert_ids,
+ (const float *)weights,
+ (const void *)input,
+ (void *)output};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
@@ -149,8 +569,12 @@ PYBIND11_MODULE(cpuinfer_ext, m) {
auto linear_module = m.def_submodule("linear");
py::class_(linear_module, "LinearConfig")
- .def(py::init([](int hidden_size, int intermediate_size, int stride, int group_max_len, intptr_t proj, int proj_type, int hidden_type) {
- return LinearConfig(hidden_size, intermediate_size, stride, group_max_len, (void*)proj, (ggml_type)proj_type, (ggml_type)hidden_type);
+ .def(py::init([](int hidden_size, int intermediate_size, int stride,
+ int group_max_len, intptr_t proj, int proj_type,
+ int hidden_type) {
+ return LinearConfig(hidden_size, intermediate_size, stride,
+ group_max_len, (void *)proj,
+ (ggml_type)proj_type, (ggml_type)hidden_type);
}));
py::class_(linear_module, "Linear")
.def(py::init())
@@ -159,8 +583,15 @@ PYBIND11_MODULE(cpuinfer_ext, m) {
auto mlp_module = m.def_submodule("mlp");
py::class_(mlp_module, "MLPConfig")
- .def(py::init([](int hidden_size, int intermediate_size, int stride, int group_max_len, intptr_t gate_proj, intptr_t up_proj, intptr_t down_proj, int gate_type, int up_type, int down_type, int hidden_type) {
- return MLPConfig(hidden_size, intermediate_size, stride, group_max_len, (void*)gate_proj, (void*)up_proj, (void*)down_proj, (ggml_type)gate_type, (ggml_type)up_type, (ggml_type)down_type, (ggml_type)hidden_type);
+ .def(py::init([](int hidden_size, int intermediate_size, int stride,
+ int group_max_len, intptr_t gate_proj,
+ intptr_t up_proj, intptr_t down_proj, int gate_type,
+ int up_type, int down_type, int hidden_type) {
+ return MLPConfig(hidden_size, intermediate_size, stride,
+ group_max_len, (void *)gate_proj, (void *)up_proj,
+ (void *)down_proj, (ggml_type)gate_type,
+ (ggml_type)up_type, (ggml_type)down_type,
+ (ggml_type)hidden_type);
}));
py::class_(mlp_module, "MLP")
.def(py::init())
@@ -169,11 +600,84 @@ PYBIND11_MODULE(cpuinfer_ext, m) {
auto moe_module = m.def_submodule("moe");
py::class_(moe_module, "MOEConfig")
- .def(py::init([](int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int stride, int group_min_len, int group_max_len, intptr_t gate_proj, intptr_t up_proj, intptr_t down_proj, int gate_type, int up_type, int down_type, int hidden_type) {
- return MOEConfig(expert_num, routed_expert_num, hidden_size, intermediate_size, stride, group_min_len, group_max_len, (void*)gate_proj, (void*)up_proj, (void*)down_proj, (ggml_type)gate_type, (ggml_type)up_type, (ggml_type)down_type, (ggml_type)hidden_type);
+ .def(py::init([](int expert_num, int routed_expert_num, int hidden_size,
+ int intermediate_size, int stride, int group_min_len,
+ int group_max_len, intptr_t gate_proj,
+ intptr_t up_proj, intptr_t down_proj, int gate_type,
+ int up_type, int down_type, int hidden_type) {
+ return MOEConfig(expert_num, routed_expert_num, hidden_size,
+ intermediate_size, stride, group_min_len,
+ group_max_len, (void *)gate_proj, (void *)up_proj,
+ (void *)down_proj, (ggml_type)gate_type,
+ (ggml_type)up_type, (ggml_type)down_type,
+ (ggml_type)hidden_type);
}));
py::class_(moe_module, "MOE")
.def(py::init())
.def("warm_up", &MOEBindings::WarmUpBindinds::cpuinfer_interface)
.def("forward", &MOEBindings::ForwardBindings::cpuinfer_interface);
+
+ auto kvcache_module = m.def_submodule("kvcache");
+
+ py::enum_(kvcache_module, "AnchorType")
+ .value("FIXED", AnchorType::FIXED_ANCHOR)
+ .value("DYNAMIC", AnchorType::DYNAMIC)
+ .value("QUEST", AnchorType::QUEST)
+ .value("BLOCK_MAX", AnchorType::BLOCK_MAX)
+ .value("BLOCK_MEAN", AnchorType::BLOCK_MEAN);
+ py::enum_(kvcache_module, "ggml_type")
+ .value("FP16", ggml_type::GGML_TYPE_F16)
+ .value("FP32", ggml_type::GGML_TYPE_F32)
+ .value("Q4_0", ggml_type::GGML_TYPE_Q4_0)
+ .value("Q8_0", ggml_type::GGML_TYPE_Q8_0);
+ py::enum_(kvcache_module, "RetrievalType")
+ .value("LAYER", RetrievalType::LAYER)
+ .value("KVHEAD", RetrievalType::KVHEAD)
+ .value("QHEAD", RetrievalType::QHEAD);
+
+ py::class_(kvcache_module, "KVCacheConfig")
+ .def(py::init())
+ .def_readwrite("layer_num", &KVCacheConfig::layer_num)
+ .def_readwrite("kv_head_num", &KVCacheConfig::kv_head_num)
+ .def_readwrite("q_head_num", &KVCacheConfig::q_head_num)
+ .def_readwrite("head_dim", &KVCacheConfig::head_dim)
+ .def_readwrite("block_len", &KVCacheConfig::block_len)
+ .def_readwrite("anchor_num", &KVCacheConfig::anchor_num)
+ .def_readwrite("anchor_type", &KVCacheConfig::anchor_type)
+ .def_readwrite("kv_type", &KVCacheConfig::kv_type)
+ .def_readwrite("retrieval_type", &KVCacheConfig::retrieval_type)
+ .def_readwrite("layer_step", &KVCacheConfig::layer_step)
+ .def_readwrite("token_step", &KVCacheConfig::token_step)
+ .def_readwrite("layer_offset", &KVCacheConfig::layer_offset)
+ .def_readwrite("max_block_num", &KVCacheConfig::max_block_num)
+ .def_readwrite("max_batch_size", &KVCacheConfig::max_batch_size)
+ .def_readwrite("max_thread_num", &KVCacheConfig::max_thread_num);
+ py::class_(kvcache_module, "KVCache")
+ .def(py::init())
+ .def("get_cache_total_len", &KVCache::get_cache_total_len)
+ .def("update_cache_total_len",
+ [](KVCache &kvcache, int cache_total_len) {
+ kvcache.update_cache_total_len(cache_total_len);
+ })
+ .def("attn", &KVCacheBindings::AttnBindings::cpuinfer_interface)
+ .def(
+ "get_all_kvcache_one_layer",
+ &KVCacheBindings::GetAllKVCacheOneLayerBindings::cpuinfer_interface)
+ .def("get_and_update_kvcache_fp16",
+ &KVCacheBindings::GetAndUpdateKVCacheFp16Bindings::
+ cpuinfer_interface)
+ .def("get_kvcache_fp16",
+ &KVCacheBindings::GetKVCacheFp16Bindings::cpuinfer_interface)
+ .def("update_kvcache_fp16",
+ &KVCacheBindings::UpdateKVCacheFp16Bindings::cpuinfer_interface)
+ .def("update_importance",
+ &KVCacheBindings::UpdateImportanceBindings::cpuinfer_interface)
+ .def("attn_with_kvcache",
+ &KVCacheBindings::AttnWithKVCacheBindings::cpuinfer_interface)
+ .def("clear_importance_all_layers",
+ &KVCacheBindings::ClearImportanceAllLayersBindings::
+ cpuinfer_interface)
+ .def("calc_anchor_all_layers",
+ &KVCacheBindings::CalcAnchorAllLayersBindinds::cpuinfer_interface);
}
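Each cpuinfer_interface above heap-allocates an Args struct and returns a (function pointer, args pointer) pair as two intptr_t values; on the Python side that pair is what CPUInfer.submit() enqueues and CPUInfer.sync() waits on, while simple accessors such as get_cache_total_len are bound as ordinary methods. A sketch of that calling pattern, with the configuration values copied from test_attention.py above (it assumes the extension is built and importable as cpuinfer_ext):

import cpuinfer_ext
import torch

kvc = cpuinfer_ext.kvcache
cpu_infer = cpuinfer_ext.CPUInfer(2)

config = kvc.KVCacheConfig(
    10, 8, 32, 128, 128, 1,      # layer_num, kv_head_num, q_head_num, head_dim, block_len, anchor_num
    kvc.AnchorType.DYNAMIC, kvc.ggml_type.FP16, kvc.RetrievalType.LAYER,
    1, 1, 0,                     # layer_step, token_step, layer_offset
    512, 1, 2,                   # max_block_num, max_batch_size, max_thread_num
)
cache = kvc.KVCache(config)
print(cache.get_cache_total_len())  # direct method call, no task queue involved

block_table = torch.arange(512, dtype=torch.int32).view(1, -1).contiguous()
seqlens_zero = torch.zeros(1, dtype=torch.int32)
k = torch.randn(1, 8192, 8, 128, dtype=torch.float16).contiguous()
v = torch.randn(1, 8192, 8, 128, dtype=torch.float16).contiguous()

# update_kvcache_fp16 only builds the (callback, args) handle; submit enqueues it
cpu_infer.submit(cache.update_kvcache_fp16(
    k.data_ptr(), v.data_ptr(), 0,
    block_table.data_ptr(), 1, 512, seqlens_zero.data_ptr(), 8192))
cpu_infer.sync()                 # block until the queued call finishes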
diff --git a/ktransformers/ktransformers_ext/operators/kvcache/kvcache.h b/ktransformers/ktransformers_ext/operators/kvcache/kvcache.h
new file mode 100644
index 0000000..ac91778
--- /dev/null
+++ b/ktransformers/ktransformers_ext/operators/kvcache/kvcache.h
@@ -0,0 +1,727 @@
+/**
+ * @Description :
+ * @Author : Jianwei Dong
+ * @Date : 2024-08-26 22:47:06
+ * @Version : 1.0.0
+ * @LastEditors : Jianwei Dong
+ * @LastEditTime : 2024-08-26 22:47:06
+ * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
+ **/
+
+#ifndef CPUINFER_OPERATOR_KVCACHE_H
+#define CPUINFER_OPERATOR_KVCACHE_H
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "../../cpu_backend/backend.h"
+#include "llama.cpp/ggml-common.h"
+#include "llama.cpp/ggml-impl.h"
+#include "llama.cpp/ggml-quants.h"
+#include "llama.cpp/ggml.h"
+#include "llamafile/sgemm.h"
+
+#define CHUNK_SIZE 32
+
+/**
+ * @brief Converts a ggml_type enum value to its corresponding string
+ * representation.
+ *
+ * This function provides a human-readable string representation for a given
+ * ggml_type enum value. The string can be used for logging, debugging, or
+ * displaying information in a user interface.
+ *
+ * @param type The ggml_type enum value to convert.
+ * @return A string representation of the enum value.
+ */
+std::string ggml_type_to_string(ggml_type type);
+
+/**
+ * @enum AnchorType
+ * @brief Defines the types of anchors used in attention mechanisms.
+ *
+ * This enum specifies different types of anchors that can be used in attention
+ * mechanisms, such as fixed anchors, dynamic anchors, or special anchors like
+ * QUEST, BLOCK_MEAN, or BLOCK_MAX.
+ */
+enum AnchorType {
+ FIXED_ANCHOR, /**< A fixed anchor that does not change. */
+ DYNAMIC, /**< A dynamic anchor that can change over time. */
+ QUEST, /**< A special anchor type used for QUEST (Query and Embedding Space
+ Transformation). */
+ BLOCK_MEAN, /**< An anchor based on the mean of a block of data. */
+ BLOCK_MAX /**< An anchor based on the maximum value within a block of data.
+ */
+};
+
+/**
+ * @brief Converts an AnchorType enum value to its corresponding string
+ * representation.
+ *
+ * This function provides a human-readable string representation for a given
+ * AnchorType enum value. The string can be used for logging, debugging, or
+ * displaying information in a user interface.
+ *
+ * @param anchor_type The AnchorType enum value to convert.
+ * @return A string representation of the enum value.
+ */
+std::string AnchorTypeToString(AnchorType anchor_type);
+
+/**
+ * @enum RetrievalType
+ * @brief Defines the types of retrieval strategies in attention mechanisms.
+ *
+ * This enum specifies different retrieval strategies that can be used in
+ * attention mechanisms, such as layer-level retrieval, key-value head-level
+ * retrieval, or query head-level retrieval.
+ */
+enum RetrievalType {
+ LAYER, /**< Retrieval at the layer level. */
+ KVHEAD, /**< Retrieval at the key-value head level. */
+ QHEAD /**< Retrieval at the query head level. */
+};
+
+/**
+ * @brief Converts a RetrievalType enum value to its corresponding string
+ * representation.
+ *
+ * This function provides a human-readable string representation for a given
+ * RetrievalType enum value. The string can be used for logging, debugging, or
+ * displaying information in a user interface.
+ *
+ * @param retrieval_type The RetrievalType enum value to convert.
+ * @return A string representation of the enum value.
+ */
+std::string RetrievalTypeToString(RetrievalType retrieval_type);
+
+/**
+ * @struct KVCacheConfig
+ * @brief Configuration structure for Key-Value (KV) Cache.
+ *
+ * This structure holds configuration parameters for setting up and managing
+ * a Key-Value (KV) Cache used in various attention mechanisms. It includes
+ * parameters such as the number of layers, the number of heads, the dimension
+ * of each head, block length, anchor information, and memory-related settings.
+ */
+struct KVCacheConfig {
+ int layer_num; /**< Number of layers in the model. */
+ int kv_head_num; /**< Number of heads in the KV Cache. */
+ int q_head_num; /**< Number of heads in the query. */
+ int head_dim; /**< Dimension of each head. */
+ int block_len; /**< Length of each block in the cache. */
+ int anchor_num; /**< Number of anchors used in attention. */
+
+ ggml_type kv_type; /**< Data type of the KV Cache (e.g., fp16, q8_0). */
+
+ // Controls the pre-allocated memory size
+ int max_block_num; /**< Maximum number of blocks that can be allocated. */
+ int max_batch_size; /**< Maximum batch size that can be processed. */
+ int max_thread_num; /**< Maximum number of threads that can be used. */
+
+ AnchorType
+ anchor_type; /**< Type of anchors used in the attention mechanism. */
+ RetrievalType
+ retrieval_type; /**< Type of retrieval strategy used in the cache. */
+
+ int layer_step; /**< Step size between layers. */
+ int token_step; /**< Step size between tokens. */
+ int layer_offset; /**< Offset value for layers. */
+
+ /**
+ * @brief Default constructor for KVCacheConfig.
+ *
+ * Initializes the configuration with default values. This constructor
+ * does not initialize any member variables explicitly.
+ */
+ KVCacheConfig() = default;
+
+ /**
+ * @brief Parameterized constructor for KVCacheConfig.
+ *
+ * This constructor initializes the configuration with specific values
+ * for all member variables.
+ *
+ * @param layer_num The number of layers in the model.
+ * @param kv_head_num The number of heads in the KV Cache.
+ * @param q_head_num The number of heads in the query.
+ * @param head_dim The dimension of each head.
+ * @param block_len The length of each block in the cache.
+ * @param anchor_num The number of anchors used in attention.
+ * @param anchor_type The type of anchors used in the attention mechanism.
+ * @param kv_type The data type of the KV Cache (e.g., fp16, q8_0).
+ * @param retrieval_type The type of retrieval strategy used in the cache.
+ * @param layer_step The step size between layers.
+ * @param token_step The step size between tokens.
+ * @param layer_offset The offset value for layers.
+ * @param max_block_num The maximum number of blocks that can be allocated.
+ * @param max_batch_size The maximum batch size that can be processed.
+ * @param max_thread_num The maximum number of threads that can be used.
+ */
+ KVCacheConfig(int layer_num, int kv_head_num, int q_head_num, int head_dim,
+ int block_len, int anchor_num, AnchorType anchor_type,
+ ggml_type kv_type, RetrievalType retrieval_type,
+ int layer_step, int token_step, int layer_offset,
+ int max_block_num, int max_batch_size, int max_thread_num);
+};
+
+/**
+ * @class KVCache
+ * @brief Manages the Key-Value (KV) Cache used in attention mechanisms.
+ *
+ * The KVCache class provides functionality for managing the Key-Value Cache,
+ * including resizing the cache, retrieving configuration parameters, and
+ * updating internal states. This class is typically used in transformer models
+ * to store and manage past key and value states for efficient attention
+ * computations.
+ */
+class KVCache {
+ public:
+ /**
+ * @brief Constructs a KVCache object with the given configuration.
+ *
+ * Initializes the KVCache with the specified configuration parameters,
+ * such as the number of layers, heads, head dimensions, and other
+ * relevant settings.
+ *
+ * @param config The configuration object containing initialization
+ * parameters.
+ */
+ KVCache(KVCacheConfig config);
+
+ /**
+ * @brief Resizes the number of threads used by the cache.
+ *
+ * This function adjusts the number of threads that the cache can utilize.
+ * It allows dynamic reconfiguration of the parallel processing capabilities
+ * based on the current workload or system resources.
+ *
+ * @param thread_num The new number of threads to use.
+ */
+ void ThreadResize(int thread_num);
+
+ /**
+ * @brief Resizes the batch size managed by the cache.
+ *
+ * This function adjusts the batch size that the cache can handle. It
+ * is useful when the input batch size changes dynamically, allowing
+ * the cache to be reconfigured accordingly.
+ *
+ * @param batch_size The new batch size.
+ */
+ void BatchResize(int batch_size);
+
+ /**
+ * @brief Resizes the number of blocks managed by the cache.
+ *
+ * This function adjusts the number of blocks that the cache can manage.
+ * It allows dynamic reconfiguration of the block structure based on the
+ * current sequence length or other factors.
+ *
+ * @param block_num The new number of blocks.
+ */
+ void BlockResize(int block_num);
+
+ /**
+ * @brief Gets the number of layers in the cache.
+ *
+ * @return The number of layers configured in the cache.
+ */
+ int get_layer_num() { return config_.layer_num; }
+
+ /**
+ * @brief Gets the number of KV heads in the cache.
+ *
+ * @return The number of KV heads configured in the cache.
+ */
+ int get_kv_head_num() { return config_.kv_head_num; }
+
+ /**
+ * @brief Gets the number of query heads in the cache.
+ *
+ * @return The number of query heads configured in the cache.
+ */
+ int get_q_head_num() { return config_.q_head_num; }
+
+ /**
+ * @brief Gets the dimension of each head in the cache.
+ *
+ * @return The dimension of each head.
+ */
+ int get_head_dim() { return config_.head_dim; }
+
+ /**
+ * @brief Gets the length of each block in the cache.
+ *
+ * @return The length of each block.
+ */
+ int get_block_len() { return config_.block_len; }
+
+ /**
+ * @brief Gets the number of blocks for a specific layer.
+ *
+ * @param layer_id The ID of the layer for which to retrieve the block
+ * number.
+ * @return The number of blocks in the specified layer.
+ */
+ int get_block_num(int layer_id) { return past_block_num_[layer_id]; }
+
+ /**
+ * @brief Gets the number of anchors in the cache.
+ *
+ * @return The number of anchors configured in the cache.
+ */
+ int get_anchor_num() { return config_.anchor_num; }
+
+ /**
+ * @brief Gets the total length of the cache.
+ *
+ * @return The total length of the cache.
+ */
+ int get_cache_total_len() { return cache_total_len_; }
+
+ /**
+ * @brief Gets the total number of blocks in the cache.
+ *
+ * This function computes and returns the total number of blocks in the
+ * cache based on the total cache length and the block length configuration.
+ *
+ * @return The total number of blocks in the cache.
+ */
+ int get_cache_total_block_num() {
+ return (cache_total_len_ + config_.block_len - 1) / config_.block_len;
+ }
+
+ /**
+ * @brief Updates the total length of the cache.
+ *
+ * This function sets a new total length for the cache, allowing dynamic
+ * adjustment of the cache size during runtime.
+ *
+ * @param cache_total_len The new total length of the cache.
+ */
+ void update_cache_total_len(int cache_total_len) {
+ cache_total_len_ = cache_total_len;
+ }
+ void attn(const ggml_fp16_t *q_in, ggml_fp16_t *output, float *attn_lse,
+ int layer_idx, int generate_token_idx, int q_len, int batch_size,
+ int max_block_num, int *block_table, int *cache_seqlens,
+ int pick_block_num, int init_block_num, int local_block_num,
+ Backend *backend);
+
+ void update_kvcache_one_block_fp16(const ggml_fp16_t *k_in,
+ const ggml_fp16_t *v_in, int layer_id,
+ int block_idx, Backend *backend);
+
+ void get_kvcache_one_block_fp16(ggml_fp16_t *k_in, ggml_fp16_t *v_in,
+ int layer_id, int block_idx,
+ Backend *backend);
+
+ void update_importance_one_block(const ggml_fp16_t *importance,
+ int layer_id, int block_idx,
+ Backend *backend);
+ void get_importance_one_block(ggml_fp16_t *importance, int layer_id,
+ int block_idx, Backend *backend);
+
+ void get_anchor_one_block(ggml_fp16_t *anchor, int layer_id, int block_idx,
+ Backend *backend);
+
+ void update_anchor_one_block(const ggml_fp16_t *anchor, int layer_id,
+ int block_idx, Backend *backend);
+
+ void calc_anchor_all_layers(int *block_table, int *cache_seqlens,
+ int batch_size, int max_block_num,
+ Backend *backend);
+
+ void load_kvcache(std::string tensor_file_path, Backend *backend);
+ void dump_kvcache(int *block_table, int cache_total_len,
+ std::string tensor_file_path, Backend *backend);
+
+ void get_and_update_kvcache_fp16(ggml_fp16_t *k_in, ggml_fp16_t *v_in,
+ int layer_id, int *block_table,
+ int batch_size, int max_block_num,
+ int *cache_seqlens, int q_len,
+ Backend *backend);
+
+ void get_kvcache_fp16(ggml_fp16_t *k_in, ggml_fp16_t *v_in, int layer_id,
+ int *block_table, int batch_size, int max_block_num,
+ int *cache_seqlens, Backend *backend);
+
+ void update_kvcache_fp16(const ggml_fp16_t *k_in, const ggml_fp16_t *v_in,
+ int layer_id, int *block_table, int batch_size,
+ int max_block_num, int *cache_seqlens, int q_len,
+ Backend *backend);
+
+ void update_importance(const ggml_fp16_t *importance, int layer_id,
+ int *block_table, int batch_size, int max_block_num,
+ int *offset, int width, Backend *backend);
+
+ void attn_with_kvcache(const ggml_fp16_t *q_in, const ggml_fp16_t *k_in,
+ const ggml_fp16_t *v_in, ggml_fp16_t *output,
+ float *attn_lse, int layer_idx,
+ int generate_token_idx, int q_len, int batch_size,
+ int max_block_num, int *block_table,
+ int *cache_seqlens, int topk, int local,
+ Backend *backend);
+
+ void clear_importance_all_layers(int *block_table, int *cache_seqlens,
+ int batch_size, int max_block_num,
+ Backend *backend);
+
+ void clear_kvcache_all_layers(int *block_table, int *cache_seqlens,
+ int batch_size, int max_block_num,
+ Backend *backend);
+
+ void get_sincos(ggml_fp16_t *sin, ggml_fp16_t *cos, int seqlen);
+
+ void get_attn_sparsity(const ggml_fp16_t *q_in, float *attn_sparsity,
+ int layer_idx, int generate_token_idx, int q_len,
+ int batch_size, int max_block_num, int *block_table,
+ int *cache_seqlens, int *block_table_origin,
+ int *cache_seqlens_origin, int max_block_num_origin,
+ int topk, int local, Backend *backend);
+
+ void get_all_kvcache_one_layer(int layer_id, ggml_fp16_t *k_in,
+ ggml_fp16_t *v_in, Backend *backend);
+
+ private:
+ // Persistent data
+ KVCacheConfig config_;
+ int n_gqa_; // q_head_num / kv_head_num
+ int cache_total_len_; // Number of tokens in cache
+ std::vector past_block_num_; // [layer_num]
+ std::vector>>>
+ k_cache_q4; // [layer_num, kv_head_num, past_block_num, block_len *
+ // (head_dim / QK_4)]
+ std::vector>>>
+ v_cache_q4; // [layer_num, kv_head_num, past_block_num, head_dim *
+ // (block_len / QK_4)]
+ std::vector>>>
+ k_cache_q8; // [layer_num, kv_head_num, past_block_num, block_len *
+ // (head_dim / QK_8)]
+ std::vector>>>
+ v_cache_q8; // [layer_num, kv_head_num, past_block_num, head_dim *
+ // (block_len / QK_8)]
+
+ std::vector>>>
+ k_cache_fp16_; // [layer_num, kv_head_num, past_block_num, block_len *
+ // head_dim]
+ std::vector>>>
+ v_cache_fp16_; // [layer_num, kv_head_num, past_block_num, head_dim *
+ // block_len]
+
+ std::vector>>>
+ importance_; // [layer_num, past_block_num, block_len,
+ // attention_head_num]
+
+ std::vector
+ anchor_; // [layer_num * past_block_num * anchor_num *
+ // attention_head_num * head_dim]
+
+ // Runtime data
+ int64_t layer_id_;
+ int64_t block_idx_;
+ int *block_table_;
+ uint64_t block_num_;
+ int max_block_num_after_retrieval_;
+
+ // Rotary positional embeddings
+ std::vector> sin_; // [seq_len, head_dim]
+ std::vector> cos_; // [seq_len, head_dim]
+
+ // update/get
+ int seq_len_;
+ uint16_t *k_scales_; // q4_0
+ uint8_t *k_in_; // q4_0
+ uint16_t *v_scales_; // q4_0
+ uint8_t *v_in_; // q4_0
+ uint16_t *k_data_; // fp16
+ uint16_t *v_data_; // fp16
+ uint16_t *importance_data_; // fp16
+ uint16_t *anchor_data_; // fp16
+
+ // sparsity = (sigma(block lse / lse))
+ std::vector>>
+ block_lse_; // [batch_size, max_block_num, q_head_num]
+ std::vector> attn_sparsity_; // [batch_size, q_head_num]
+
+ // attn
+ std::vector>
+ avg_q; // [batch_size, q_head_num * head_dim]
+
+ std::vector>
+ avg_q_fp16; // [batch_size, q_head_num * head_dim]
+ std::vector<
+ std::priority_queue,
+ std::vector>, std::greater<>>>
+ top_similar_block_;
+
+ std::vector> block_similar_;
+ std::vector>> block_similar_kv_head_;
+ std::vector>> block_similar_q_head_;
+
+ std::vector cache_seqlens_; // [batch_size]
+ std::vector selected_blocks_num_history_; // [layer_num // layer_step]
+
+ std::vector>> selected_blocks_history_;
+ // [layer_num // layer_step, batch_size, max_block_num]
+
+ std::vector>>>
+ selected_blocks_history_kvhead_; // [layer_num // layer_step,
+ // batch_size, max_block_num,
+ // kv_head_num]
+
+ std::vector>
+ block_table_before_retrieval_; // [batch_size, max_block_num]
+ std::vector>
+ block_table_after_retrieval_; // [batch_size, pick_block_num]
+
+ std::vector>>
+ block_table_before_retrieval_qhead_; // [batch_size, max_block_num,
+ // q_head_num]
+ std::vector>>
+ block_table_after_retrieval_qhead_; // [batch_size, pick_block_num,
+ // q_head_num]
+
+ std::vector>>
+ block_table_before_retrieval_kvhead_; // [batch_size, max_block_num,
+ // kv_head_num]
+ std::vector>>
+ block_table_after_retrieval_kvhead_; // [batch_size, pick_block_num,
+ // kv_head_num]
+
+ std::vector>>
+ mutex_; // [batch_size, kv_head_num]
+ std::vector>>
+ q_q8_0_; // [batch_size, kv_head_num, n_gqa * head_dim / QK8_0]
+ std::vector>>
+ q_fp32_; // [batch_size, kv_head_num, n_gqa * head_dim]
+
+ std::vector>>
+ output_fp32_; // [batch_size, kv_head_num, n_gqa * head_dim]
+ std::vector>>
+ attn_lse_; // [batch_size, kv_head_num, n_gqa]
+
+ std::vector> thread_cur_head_idx_; // [thread_num]
+
+ std::vector>
+ thread_local_output_q8_0_; // [thread_num, n_gqa * head_dim / QK8_0]
+ std::vector>
+ thread_local_attn_score_; // [thread_num, n_gqa * block_len]
+ std::vector>
+ thread_local_output_fp32_; // [thread_num, n_gqa * head_dim]
+ std::vector>
+ thread_local_attn_lse_; // [thread_num, n_gqa]
+ std::vector>
+ thread_local_cur_output_fp32_; // [thread_num, n_gqa * head_dim]
+ std::vector>
+ thread_local_cur_attn_lse_; // [thread_num, n_gqa]
+ std::vector>
+ thread_local_attn_mask_; // [thread_num, block_len // 8]
+ std::vector>
+ thread_local_draft_; // [thread_num, 2 * n_gqa * block_len + 6 * n_gqa *
+ // head_dim + 2 * block_len * head_dim]
+
+ // tmp space
+ std::vector q_fp32; // [n_gqa * head_dim]
+
+ void quantize_q_(const uint16_t *q_in_data, int batch_size);
+ void attn_initialize_layer_(int batch_size, int layer_idx, int *block_table,
+ int &max_block_num, int *cache_seqlens);
+ void attn_initialize_kvhead_(int batch_size, int layer_idx,
+ int *block_table, int &max_block_num,
+ int *cache_seqlens);
+ void retrieval_kvcache_layer_(const uint16_t *q_in_data, int init_block_num,
+ int local_block_num, int pick_block_num,
+ int q_len, int generate_token_idx,
+ int batch_size, int layer_idx,
+ int *cache_seqlens, int &max_block_num,
+ Backend *backend);
+ void retrieval_kvcache_kvhead_(const uint16_t *q_in_data,
+ int init_block_num, int local_block_num,
+ int pick_block_num, int q_len,
+ int generate_token_idx, int batch_size,
+ int layer_idx, int *cache_seqlens,
+ int &max_block_num, Backend *backend);
+
+ void calculate_block_similarity_layer_(
+ const uint16_t *q_in_data, int batch_size, int layer_idx, int q_len,
+ int max_block_num, int *cache_seqlens, int init_block_num,
+ int local_block_num, int pick_block_num, Backend *backend);
+ void calculate_block_similarity_kvhead_(
+ const uint16_t *q_in_data, int batch_size, int layer_idx, int q_len,
+ int max_block_num, int *cache_seqlens, int init_block_num,
+ int local_block_num, int pick_block_num, Backend *backend);
+
+ void select_block_layer_(int batch_size, int layer_idx, int max_block_num,
+ int init_block_num, int local_block_num,
+ int pick_block_num);
+ void select_block_kvhead_(int batch_size, int layer_idx, int max_block_num,
+ int init_block_num, int local_block_num,
+ int pick_block_num);
+
+ void calculate_sparsity_layer_(const uint16_t *q_in_data,
+ float *attn_sparsity, int batch_size,
+ int max_block_num, int *block_table,
+ int *cache_seqlens, Backend *backend);
+ void calculate_sparsity_kvhead_(const uint16_t *q_in_data,
+ float *attn_sparsity, int batch_size,
+ int max_block_num, int *block_table,
+ int *cache_seqlens, Backend *backend);
+
+ void attention_kvhead_(const uint16_t *q_in_data, ggml_fp16_t *output,
+ float *attn_lse, int batch_size, Backend *backend);
+ void attention_layer_(const uint16_t *q_in_data, ggml_fp16_t *output,
+ float *attn_lse, int batch_size, Backend *backend);
+
+ /**
+ * @brief Computes attention with KV cache for one block.
+ *
+ * This function performs attention computation for one block using KV
+ * cache. The function supports different data types for Q, K, and V caches,
+ * and provides options for quantization. The function does not perform any
+ * dynamic memory allocation internally, so all necessary buffers must be
+ * pre-allocated externally.
+ *
+ * @param head_dim The dimension of the head.
+ * @param bsz The batch size.
+ * @param q_type The data type of Q (GGML data type). Only supports fp16 and
+ * q8_0.
+ * @param q Pointer to the Q tensor [bsz, head_dim]. The quantization is
+ * always applied along the head_dim dimension. The size must be
+ * bsz * head_dim/32 * qtype_size. If head_dim % 32 != 0, an error
+ * will be raised.
+ * @param past_kv_len The length of the past KV cache.
+ * @param past_kv_offset The offset in the past KV cache.
+ * @param is_full_attn Boolean flag indicating whether to use full attention
+ * (true for full 1 mask).
+ * @param attn_mask Pointer to the attention mask [bsz, past_kv_len]. If
+ * is_full_attn = false, a bit matrix is passed to
+ * represent the mask.
+ * @param k_type The data type of K cache (GGML data type). Only supports
+ * fp16, q4_0, and q8_0.
+ * @param k_quant_type Quantization type for K cache. 0 for per_token, 1 for
+ * per_channel. Other values will raise an error.
+ * @param k_cache Pointer to the K cache tensor [seq_len, head_dim]. If
+ * quant_type == 0, head_dim % 32 must be 0. If quant_type ==
+ * 1, seq_len % 32 must be 0.
+ * @param num_k_anchor The number of K anchors. If num_k_anchor == 0, it
+ * means no anchor is present.
+ * @param k_cache_anchors Pointer to the K cache anchors [num_k_anchor,
+ * head_dim]. The k_anchor_type must be fp16.
+ * @param k_cache_anchor_pos Pointer to the K cache anchor positions. Each
+ * token is associated with the nearest previous anchor position.
+ * @param v_type The data type of V cache (GGML data type).
+ * @param v_quant_type Quantization type for V cache.
+ * @param v_cache Pointer to the V cache tensor [head_dim, seq_len].
+ * @param num_v_anchor The number of V anchors.
+ * @param v_cache_anchors Pointer to the V cache anchors.
+ * @param v_cache_anchor_pos Pointer to the V cache anchor positions.
+ * @param attn_score Pre-allocated buffer for attention scores [bsz,
+ * past_kv_len].
+ * @param output Output tensor [bsz, head_dim] with the same type as q_type.
+ * @param lse Pre-allocated buffer [bsz] for the log-sum-exp of the
+ * attention scores.
+ * @param draft Pre-allocated temporary buffer. The buffer size should be
+ * enough to hold (2 * bsz * past_kv_len + 6 * bsz * head_dim + 2 *
+ * past_kv_len * head_dim + past_kv_len * head_dim / 32) bytes.
+ * @param rotary_angle Pointer to the rotary angle tensor.
+ * @param rotary_cos Pointer to the cosine values for rotary embedding.
+ * @param rotary_sin Pointer to the sine values for rotary embedding.
+ */
+ void attn_with_kvcache_one_block_(
+ int head_dim, int bsz,
+ ggml_type q_type, // GGML data type of `Q`, only supports fp16 and q8_0
+ // [bsz, head_dim]
+ // Quantization is always on the head_dim dimension (per_token). If
+ // head_dim % 32 != 0, an error will be raised. The size must be bsz *
+ // head_dim/32 * qtype_size.
+ const void *q,
+
+ int past_kv_len, int past_kv_offset,
+ bool is_full_attn, // true indicates a full 1 mask
+ // If is_full_attn = false, a bit matrix representing the mask is
+ // passed. [bsz, past_kv_len]
+ const uint8_t *attn_mask,
+
+ ggml_type k_type, // GGML data type of `K Cache`, only supports fp16,
+ // q4_0, q8_0
+ int k_quant_type, // 0 for per_token, 1 for per_channel, others raise an
+ // error
+ // [seq_len, head_dim]
+ // If quant_type == 0, head_dim % 32 must be 0.
+ // If quant_type == 1, seq_len % 32 must be 0.
+ const void *k_cache,
+
+ // k_anchor_type must be fp16
+ int num_k_anchor, // num_k_anchor == 0 indicates no anchor
+ // [num_k_anchor, head_dim]
+ const void *k_cache_anchors,
+ // Each token is associated with the nearest previous position's anchor,
+ // with the same distance.
+ const int *k_cache_anchor_pos,
+
+ // v_cache similar to k_cache
+ ggml_type v_type, int v_quant_type,
+ // [head_dim, seq_len]
+ const void *v_cache, int num_v_anchor, const void *v_cache_anchors,
+ const int *v_cache_anchor_pos,
+
+ // Pre-allocated buffer for intermediate calculations [bsz,
+ // past_kv_len]. No malloc is performed inside this function.
+ float *attn_score,
+
+ // Output: [bsz, head_dim], with the same type as q_type
+ void *output,
+ // [bsz]
+ float *lse,
+
+ // Pre-allocated temporary buffer with sufficient size:
+ // (2 * bsz * past_kv_len + 6 * bsz * head_dim + 2 * past_kv_len *
+ // head_dim + past_kv_len * head_dim / 32) bytes.
+ void *draft,
+
+ // Apply rotary embedding online
+ const int *rotary_angle, const void *rotary_cos, const void *rotary_sin
+ // rotary_cos=None,
+ // rotary_sin=None,
+ // cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
+ // cache_batch_idx: Optional[torch.Tensor] = None,
+ // rotary_interleaved=True,
+
+ // // Not supported for now
+ // window_size=(-1, -1), # -1 means infinite context window
+ // alibi_slopes=None,
+ );
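+
+ // Illustrative sketch for the declaration above (added for clarity, not
+ // part of the original API docs): the `draft` scratch-buffer size for an
+ // assumed call with bsz = 16, past_kv_len = 128, head_dim = 128 is
+ //   2*16*128 + 6*16*128 + 2*128*128 + 128*128/32
+ //   = 4096 + 12288 + 32768 + 512 = 49664 bytes.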
+};
+
+/**
+ * @brief Scales a float32 vector by a given scalar value.
+ *
+ * This function multiplies each element of the input vector `y` by a scalar
+ * `v`. It uses platform-specific optimizations if available, such as Apple's
+ * Accelerate framework or SIMD instructions. If no specific optimization is
+ * available, the function falls back to a simple scalar multiplication loop.
+ *
+ * @param n The number of elements in the vector `y`.
+ * @param y The input vector to be scaled. The result will be stored in the same
+ * vector.
+ * @param v The scalar value by which to scale the vector.
+ */
+void ggml_vec_scale_f32(const int n, float *y, const float v);
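+
+// Illustrative usage of ggml_vec_scale_f32 (an assumed example added for
+// clarity; not from the original header):
+//   float y[4] = {1.0f, 2.0f, 3.0f, 4.0f};
+//   ggml_vec_scale_f32(4, y, 0.5f); // y becomes {0.5f, 1.0f, 1.5f, 2.0f}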
+#endif
\ No newline at end of file
diff --git a/ktransformers/ktransformers_ext/operators/kvcache/kvcache_attn.cpp b/ktransformers/ktransformers_ext/operators/kvcache/kvcache_attn.cpp
new file mode 100644
index 0000000..c59cb94
--- /dev/null
+++ b/ktransformers/ktransformers_ext/operators/kvcache/kvcache_attn.cpp
@@ -0,0 +1,2533 @@
+/**
+ * @Description :
+ * @Author : Jianwei Dong
+ * @Date : 2024-08-26 22:47:06
+ * @Version : 1.0.0
+ * @LastEditors : Jianwei Dong
+ * @LastEditTime : 2024-08-26 22:47:06
+ * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
+ **/
+
+#include "kvcache.h"
+
+void KVCache::attention_kvhead_(const uint16_t *q_in_data, ggml_fp16_t *output,
+ float *attn_lse, int batch_size,
+ Backend *backend) {
+ // Timer start
+ auto start = std::chrono::high_resolution_clock::now();
+ seq_len_ = config_.block_len;
+
+ backend->do_work_stealing_job(
+ batch_size * config_.kv_head_num * max_block_num_after_retrieval_,
+ [&](int thread_id) {
+ thread_cur_head_idx_[thread_id].first = -1;
+ thread_cur_head_idx_[thread_id].second = -1;
+ },
+ [&](int task_id) {
+ int batch_id = task_id / (config_.kv_head_num *
+ max_block_num_after_retrieval_);
+ int head_id = (task_id % (config_.kv_head_num *
+ max_block_num_after_retrieval_)) /
+ max_block_num_after_retrieval_;
+ int block_id = task_id % max_block_num_after_retrieval_;
+ int thread_id = Backend::thread_local_id;
+
+ // If the block is out of the sequence length, skip it.
+ if (cache_seqlens_[batch_id] / config_.block_len < block_id) {
+ return;
+ }
+ int block_idx =
+ block_table_after_retrieval_kvhead_[batch_id][block_id]
+ [head_id];
+ if (cache_seqlens_[batch_id] / config_.block_len == block_id) {
+ int seq_len = cache_seqlens_[batch_id] % config_.block_len;
+ if (seq_len == 0)
+ return;
+
+ // Prepare the attention mask for the last block.
+ int full_blocks = seq_len / 8;
+ int remaining_bits = seq_len % 8;
+ // Fill full blocks with 1s
+ for (int i = 0; i < full_blocks; ++i) {
+ thread_local_attn_mask_[thread_id][i] = 0xFF;
+ }
+ // Fill the remaining bits in the next block
+ if (remaining_bits > 0 && full_blocks < seq_len_ / 8) {
+ thread_local_attn_mask_[thread_id][full_blocks] =
+ (1 << remaining_bits) - 1;
+ } else {
+ thread_local_attn_mask_[thread_id][full_blocks] = 0;
+ }
+
+ for (int i = full_blocks + 1; i < seq_len_ / 8; ++i) {
+ thread_local_attn_mask_[thread_id][i] = 0;
+ }
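+ // Illustrative example (comment added for clarity, assumed values): with
+ // block_len = 32 and seq_len = 19, full_blocks = 2 and remaining_bits = 3,
+ // so the four mask bytes become 0xFF, 0xFF, 0x07, 0x00 -- the first 19
+ // key positions are visible and the rest of the block is masked out.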
+ if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
+ attn_with_kvcache_one_block_(
+ config_.head_dim,
+ config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,
+ (void *)&q_in_data[batch_id * config_.kv_head_num *
+ n_gqa_ * config_.head_dim +
+ head_id * n_gqa_ * config_.head_dim],
+ seq_len_, 0, false,
+ thread_local_attn_mask_[thread_id].data(),
+ GGML_TYPE_F16, 0,
+ k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr, GGML_TYPE_F16, 1,
+ v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr,
+ thread_local_attn_score_[thread_id].data(),
+ thread_local_output_fp32_[thread_id].data(),
+ thread_local_attn_lse_[thread_id].data(),
+ thread_local_draft_[thread_id].data(), nullptr,
+ cos_.data(), sin_.data());
+ } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
+ attn_with_kvcache_one_block_(
+ config_.head_dim,
+ config_.q_head_num / config_.kv_head_num,
+ GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),
+ seq_len_, 0, false,
+ thread_local_attn_mask_[thread_id].data(),
+ GGML_TYPE_Q4_0, 0,
+ k_cache_q4[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr, GGML_TYPE_Q4_0, 1,
+ v_cache_q4[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr,
+ thread_local_attn_score_[thread_id].data(),
+ thread_local_output_q8_0_[thread_id].data(),
+ thread_local_attn_lse_[thread_id].data(),
+ thread_local_draft_[thread_id].data(), nullptr,
+ cos_.data(), sin_.data());
+ dequantize_row_q8_0(
+ thread_local_output_q8_0_[thread_id].data(),
+ thread_local_output_fp32_[thread_id].data(),
+ n_gqa_ * config_.head_dim);
+ } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
+ attn_with_kvcache_one_block_(
+ config_.head_dim,
+ config_.q_head_num / config_.kv_head_num,
+ GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),
+ seq_len_, 0, false,
+ thread_local_attn_mask_[thread_id].data(),
+ GGML_TYPE_Q8_0, 0,
+ k_cache_q8[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr, GGML_TYPE_Q8_0, 1,
+ v_cache_q8[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr,
+ thread_local_attn_score_[thread_id].data(),
+ thread_local_output_q8_0_[thread_id].data(),
+ thread_local_attn_lse_[thread_id].data(),
+ thread_local_draft_[thread_id].data(), nullptr,
+ cos_.data(), sin_.data());
+ dequantize_row_q8_0(
+ thread_local_output_q8_0_[thread_id].data(),
+ thread_local_output_fp32_[thread_id].data(),
+ n_gqa_ * config_.head_dim);
+ }
+ } else {
+ if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
+ attn_with_kvcache_one_block_(
+ config_.head_dim,
+ config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,
+ (void *)&q_in_data[batch_id * config_.kv_head_num *
+ n_gqa_ * config_.head_dim +
+ head_id * n_gqa_ * config_.head_dim],
+ seq_len_, 0, true, nullptr, GGML_TYPE_F16, 0,
+ k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr, GGML_TYPE_F16, 1,
+ v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr,
+ thread_local_attn_score_[thread_id].data(),
+ thread_local_output_fp32_[thread_id].data(),
+ thread_local_attn_lse_[thread_id].data(),
+ thread_local_draft_[thread_id].data(), nullptr,
+ cos_.data(), sin_.data());
+
+ } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
+ attn_with_kvcache_one_block_(
+ config_.head_dim,
+ config_.q_head_num / config_.kv_head_num,
+ GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),
+ seq_len_, 0, true, nullptr, GGML_TYPE_Q4_0, 0,
+ k_cache_q4[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr, GGML_TYPE_Q4_0, 1,
+ v_cache_q4[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr,
+ thread_local_attn_score_[thread_id].data(),
+ thread_local_output_q8_0_[thread_id].data(),
+ thread_local_attn_lse_[thread_id].data(),
+ thread_local_draft_[thread_id].data(), nullptr,
+ cos_.data(), sin_.data());
+ dequantize_row_q8_0(
+ thread_local_output_q8_0_[thread_id].data(),
+ thread_local_output_fp32_[thread_id].data(),
+ n_gqa_ * config_.head_dim);
+ } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
+ attn_with_kvcache_one_block_(
+ config_.head_dim,
+ config_.q_head_num / config_.kv_head_num,
+ GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),
+ seq_len_, 0, true, nullptr, GGML_TYPE_Q8_0, 0,
+ k_cache_q8[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr, GGML_TYPE_Q8_0, 1,
+ v_cache_q8[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr,
+ thread_local_attn_score_[thread_id].data(),
+ thread_local_output_q8_0_[thread_id].data(),
+ thread_local_attn_lse_[thread_id].data(),
+ thread_local_draft_[thread_id].data(), nullptr,
+ cos_.data(), sin_.data());
+ dequantize_row_q8_0(
+ thread_local_output_q8_0_[thread_id].data(),
+ thread_local_output_fp32_[thread_id].data(),
+ n_gqa_ * config_.head_dim);
+ }
+ }
+ int cur_batch_idx = thread_cur_head_idx_[thread_id].first;
+ int cur_head_id = thread_cur_head_idx_[thread_id].second;
+ if (batch_id == cur_batch_idx && head_id == cur_head_id) {
+ for (int i = 0; i < n_gqa_; i++) {
+ float new_attn_lse =
+ thread_local_cur_attn_lse_[thread_id][i] +
+ std::log(
+ 1.0 +
+ std::exp(thread_local_attn_lse_[thread_id][i] -
+ thread_local_cur_attn_lse_[thread_id][i]));
+ ggml_vec_scale_f32(
+ config_.head_dim,
+ thread_local_cur_output_fp32_[thread_id].data() +
+ i * config_.head_dim,
+ std::exp(thread_local_cur_attn_lse_[thread_id][i] -
+ new_attn_lse));
+ ggml_vec_scale_f32(
+ config_.head_dim,
+ thread_local_output_fp32_[thread_id].data() +
+ i * config_.head_dim,
+ std::exp(thread_local_attn_lse_[thread_id][i] -
+ new_attn_lse));
+ for (int j = 0; j < config_.head_dim; j++) {
+ thread_local_cur_output_fp32_[thread_id]
+ [i * config_.head_dim +
+ j] +=
+ thread_local_output_fp32_[thread_id]
+ [i * config_.head_dim + j];
+ }
+ thread_local_cur_attn_lse_[thread_id][i] = new_attn_lse;
+ }
+ } else {
+ if (cur_batch_idx != -1) {
+ mutex_[cur_batch_idx][cur_head_id]->lock();
+ for (int i = 0; i < n_gqa_; i++) {
+ if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) <
+ 1e-6) {
+ attn_lse_[cur_batch_idx][cur_head_id][i] =
+ thread_local_cur_attn_lse_[thread_id][i];
+ for (int j = 0; j < config_.head_dim; j++) {
+ output_fp32_[cur_batch_idx][cur_head_id]
+ [i * config_.head_dim + j] =
+ thread_local_cur_output_fp32_
+ [thread_id]
+ [i * config_.head_dim + j];
+ }
+ continue;
+ }
+ float new_attn_lse =
+ attn_lse_[cur_batch_idx][cur_head_id][i] +
+ std::log(
+ 1.0 +
+ std::exp(
+ thread_local_cur_attn_lse_[thread_id][i] -
+ attn_lse_[cur_batch_idx][cur_head_id][i]));
+ ggml_vec_scale_f32(
+ config_.head_dim,
+ output_fp32_[cur_batch_idx][cur_head_id].data() +
+ i * config_.head_dim,
+ std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] -
+ new_attn_lse));
+ ggml_vec_scale_f32(
+ config_.head_dim,
+ thread_local_cur_output_fp32_[thread_id].data() +
+ i * config_.head_dim,
+ std::exp(thread_local_cur_attn_lse_[thread_id][i] -
+ new_attn_lse));
+ for (int j = 0; j < config_.head_dim; j++) {
+ output_fp32_[cur_batch_idx][cur_head_id]
+ [i * config_.head_dim + j] +=
+ thread_local_cur_output_fp32_
+ [thread_id][i * config_.head_dim + j];
+ }
+ attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;
+ }
+ mutex_[cur_batch_idx][cur_head_id]->unlock();
+ }
+ thread_cur_head_idx_[thread_id].first = batch_id;
+ thread_cur_head_idx_[thread_id].second = head_id;
+ for (int i = 0; i < n_gqa_; i++) {
+ thread_local_cur_attn_lse_[thread_id][i] =
+ thread_local_attn_lse_[thread_id][i];
+ for (int j = 0; j < config_.head_dim; j++) {
+ thread_local_cur_output_fp32_
+ [thread_id][i * config_.head_dim + j] =
+ thread_local_output_fp32_[thread_id]
+ [i * config_.head_dim +
+ j];
+ }
+ }
+ }
+ },
+ // Merge the results of the remaining blocks.
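+ // Sketch of the log-sum-exp merge used here and in the per-task
+ // accumulation above (comment added for clarity): for two partial
+ // results (o_a, lse_a) and (o_b, lse_b) of the same head,
+ //   lse_new = lse_a + log(1 + exp(lse_b - lse_a))
+ //   o_new = exp(lse_a - lse_new) * o_a + exp(lse_b - lse_new) * o_b
+ // which is exactly the ggml_vec_scale_f32 + accumulate pattern below.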
+ [&](int thread_id) {
+ int cur_batch_idx = thread_cur_head_idx_[thread_id].first;
+ int cur_head_id = thread_cur_head_idx_[thread_id].second;
+ if (cur_head_id != -1) {
+ mutex_[cur_batch_idx][cur_head_id]->lock();
+ for (int i = 0; i < n_gqa_; i++) {
+ float new_attn_lse;
+ if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) <
+ 1e-6) {
+ attn_lse_[cur_batch_idx][cur_head_id][i] =
+ thread_local_cur_attn_lse_[thread_id][i];
+ for (int j = 0; j < config_.head_dim; j++) {
+ output_fp32_[cur_batch_idx][cur_head_id]
+ [i * config_.head_dim + j] =
+ thread_local_cur_output_fp32_
+ [thread_id]
+ [i * config_.head_dim + j];
+ }
+ continue;
+ }
+ new_attn_lse =
+ attn_lse_[cur_batch_idx][cur_head_id][i] +
+ std::log(
+ 1.0 +
+ std::exp(thread_local_cur_attn_lse_[thread_id][i] -
+ attn_lse_[cur_batch_idx][cur_head_id][i]));
+ ggml_vec_scale_f32(
+ config_.head_dim,
+ output_fp32_[cur_batch_idx][cur_head_id].data() +
+ i * config_.head_dim,
+ std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] -
+ new_attn_lse));
+ ggml_vec_scale_f32(
+ config_.head_dim,
+ thread_local_cur_output_fp32_[thread_id].data() +
+ i * config_.head_dim,
+ std::exp(thread_local_cur_attn_lse_[thread_id][i] -
+ new_attn_lse));
+ for (int j = 0; j < config_.head_dim; j++) {
+ output_fp32_[cur_batch_idx][cur_head_id]
+ [i * config_.head_dim + j] +=
+ thread_local_cur_output_fp32_[thread_id]
+ [i * config_.head_dim +
+ j];
+ }
+ attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;
+ }
+ mutex_[cur_batch_idx][cur_head_id]->unlock();
+ }
+ });
+ // move the results to output and attn_lse
+ uint16_t *output_data = reinterpret_cast<uint16_t *>(output);
+ float *attn_lse_data = attn_lse;
+ for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
+ for (int i = 0; i < config_.kv_head_num; i++) {
+ for (int j = 0; j < n_gqa_ * config_.head_dim; j++) {
+ output_data[batch_idx * config_.kv_head_num * n_gqa_ *
+ config_.head_dim +
+ i * n_gqa_ * config_.head_dim + j] =
+ GGML_FP32_TO_FP16(output_fp32_[batch_idx][i][j]);
+ }
+ for (int j = 0; j < n_gqa_; j++) {
+ attn_lse_data[batch_idx * config_.kv_head_num * n_gqa_ +
+ i * n_gqa_ + j] = attn_lse_[batch_idx][i][j];
+ }
+ }
+ }
+
+ // Timer end
+ auto end = std::chrono::high_resolution_clock::now();
+ std::chrono::duration<double> diff = end - start;
+ // printf("layer %d time of computing attention: %f s\n", layer_idx,
+ // diff.count());
+}
+
+void KVCache::attention_layer_(const uint16_t *q_in_data, ggml_fp16_t *output,
+ float *attn_lse, int batch_size,
+ Backend *backend) {
+ // Timer start
+ auto start = std::chrono::high_resolution_clock::now();
+ seq_len_ = config_.block_len;
+ backend->do_work_stealing_job(
+ batch_size * config_.kv_head_num * max_block_num_after_retrieval_,
+ [&](int thread_id) {
+ thread_cur_head_idx_[thread_id].first = -1;
+ thread_cur_head_idx_[thread_id].second = -1;
+ },
+ [&](int task_id) {
+ int batch_id = task_id / (config_.kv_head_num *
+ max_block_num_after_retrieval_);
+ int head_id = (task_id % (config_.kv_head_num *
+ max_block_num_after_retrieval_)) /
+ max_block_num_after_retrieval_;
+ int block_id = task_id % max_block_num_after_retrieval_;
+ int thread_id = Backend::thread_local_id;
+ // If the block is out of the sequence length, skip it.
+ if (cache_seqlens_[batch_id] / config_.block_len < block_id) {
+ return;
+ }
+ int block_idx = block_table_after_retrieval_[batch_id][block_id];
+ if (cache_seqlens_[batch_id] / config_.block_len == block_id) {
+ int seq_len = cache_seqlens_[batch_id] % config_.block_len;
+ if (seq_len == 0)
+ return;
+
+ // Prepare the attention mask for the last block.
+ int full_blocks = seq_len / 8;
+ int remaining_bits = seq_len % 8;
+
+ // Fill full blocks with 1s
+ for (int i = 0; i < full_blocks; ++i) {
+ thread_local_attn_mask_[thread_id][i] = 0xFF;
+ }
+ // Fill the remaining bits in the next block
+ if (remaining_bits > 0 && full_blocks < seq_len_ / 8) {
+ thread_local_attn_mask_[thread_id][full_blocks] =
+ (1 << remaining_bits) - 1;
+ } else {
+ thread_local_attn_mask_[thread_id][full_blocks] = 0;
+ }
+
+ for (int i = full_blocks + 1; i < seq_len_ / 8; ++i) {
+ thread_local_attn_mask_[thread_id][i] = 0;
+ }
+ if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
+ attn_with_kvcache_one_block_(
+ config_.head_dim,
+ config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,
+ (void *)&q_in_data[batch_id * config_.kv_head_num *
+ n_gqa_ * config_.head_dim +
+ head_id * n_gqa_ * config_.head_dim],
+ seq_len_, 0, false,
+ thread_local_attn_mask_[thread_id].data(),
+ GGML_TYPE_F16, 0,
+ k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr, GGML_TYPE_F16, 1,
+ v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr,
+ thread_local_attn_score_[thread_id].data(),
+ thread_local_output_fp32_[thread_id].data(),
+ thread_local_attn_lse_[thread_id].data(),
+ thread_local_draft_[thread_id].data(), nullptr,
+ cos_.data(), sin_.data());
+ } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
+ attn_with_kvcache_one_block_(
+ config_.head_dim,
+ config_.q_head_num / config_.kv_head_num,
+ GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),
+ seq_len_, 0, false,
+ thread_local_attn_mask_[thread_id].data(),
+ GGML_TYPE_Q4_0, 0,
+ k_cache_q4[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr, GGML_TYPE_Q4_0, 1,
+ v_cache_q4[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr,
+ thread_local_attn_score_[thread_id].data(),
+ thread_local_output_q8_0_[thread_id].data(),
+ thread_local_attn_lse_[thread_id].data(),
+ thread_local_draft_[thread_id].data(), nullptr,
+ cos_.data(), sin_.data());
+ dequantize_row_q8_0(
+ thread_local_output_q8_0_[thread_id].data(),
+ thread_local_output_fp32_[thread_id].data(),
+ n_gqa_ * config_.head_dim);
+ } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
+ attn_with_kvcache_one_block_(
+ config_.head_dim,
+ config_.q_head_num / config_.kv_head_num,
+ GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),
+ seq_len_, 0, false,
+ thread_local_attn_mask_[thread_id].data(),
+ GGML_TYPE_Q8_0, 0,
+ k_cache_q8[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr, GGML_TYPE_Q8_0, 1,
+ v_cache_q8[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr,
+ thread_local_attn_score_[thread_id].data(),
+ thread_local_output_q8_0_[thread_id].data(),
+ thread_local_attn_lse_[thread_id].data(),
+ thread_local_draft_[thread_id].data(), nullptr,
+ cos_.data(), sin_.data());
+ dequantize_row_q8_0(
+ thread_local_output_q8_0_[thread_id].data(),
+ thread_local_output_fp32_[thread_id].data(),
+ n_gqa_ * config_.head_dim);
+ }
+ } else {
+ if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
+ attn_with_kvcache_one_block_(
+ config_.head_dim,
+ config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,
+ (void *)&q_in_data[batch_id * config_.kv_head_num *
+ n_gqa_ * config_.head_dim +
+ head_id * n_gqa_ * config_.head_dim],
+ seq_len_, 0, true, nullptr, GGML_TYPE_F16, 0,
+ k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr, GGML_TYPE_F16, 1,
+ v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr,
+ thread_local_attn_score_[thread_id].data(),
+ thread_local_output_fp32_[thread_id].data(),
+ thread_local_attn_lse_[thread_id].data(),
+ thread_local_draft_[thread_id].data(), nullptr,
+ cos_.data(), sin_.data());
+
+ } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
+ attn_with_kvcache_one_block_(
+ config_.head_dim,
+ config_.q_head_num / config_.kv_head_num,
+ GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),
+ seq_len_, 0, true, nullptr, GGML_TYPE_Q4_0, 0,
+ k_cache_q4[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr, GGML_TYPE_Q4_0, 1,
+ v_cache_q4[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr,
+ thread_local_attn_score_[thread_id].data(),
+ thread_local_output_q8_0_[thread_id].data(),
+ thread_local_attn_lse_[thread_id].data(),
+ thread_local_draft_[thread_id].data(), nullptr,
+ cos_.data(), sin_.data());
+ dequantize_row_q8_0(
+ thread_local_output_q8_0_[thread_id].data(),
+ thread_local_output_fp32_[thread_id].data(),
+ n_gqa_ * config_.head_dim);
+ } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
+ attn_with_kvcache_one_block_(
+ config_.head_dim,
+ config_.q_head_num / config_.kv_head_num,
+ GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),
+ seq_len_, 0, true, nullptr, GGML_TYPE_Q8_0, 0,
+ k_cache_q8[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr, GGML_TYPE_Q8_0, 1,
+ v_cache_q8[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr,
+ thread_local_attn_score_[thread_id].data(),
+ thread_local_output_q8_0_[thread_id].data(),
+ thread_local_attn_lse_[thread_id].data(),
+ thread_local_draft_[thread_id].data(), nullptr,
+ cos_.data(), sin_.data());
+ dequantize_row_q8_0(
+ thread_local_output_q8_0_[thread_id].data(),
+ thread_local_output_fp32_[thread_id].data(),
+ n_gqa_ * config_.head_dim);
+ }
+ }
+ int cur_batch_idx = thread_cur_head_idx_[thread_id].first;
+ int cur_head_id = thread_cur_head_idx_[thread_id].second;
+ if (batch_id == cur_batch_idx && head_id == cur_head_id) {
+ for (int i = 0; i < n_gqa_; i++) {
+ float new_attn_lse =
+ thread_local_cur_attn_lse_[thread_id][i] +
+ std::log(
+ 1.0 +
+ std::exp(thread_local_attn_lse_[thread_id][i] -
+ thread_local_cur_attn_lse_[thread_id][i]));
+ ggml_vec_scale_f32(
+ config_.head_dim,
+ thread_local_cur_output_fp32_[thread_id].data() +
+ i * config_.head_dim,
+ std::exp(thread_local_cur_attn_lse_[thread_id][i] -
+ new_attn_lse));
+ ggml_vec_scale_f32(
+ config_.head_dim,
+ thread_local_output_fp32_[thread_id].data() +
+ i * config_.head_dim,
+ std::exp(thread_local_attn_lse_[thread_id][i] -
+ new_attn_lse));
+ for (int j = 0; j < config_.head_dim; j++) {
+ thread_local_cur_output_fp32_[thread_id]
+ [i * config_.head_dim +
+ j] +=
+ thread_local_output_fp32_[thread_id]
+ [i * config_.head_dim + j];
+ }
+ thread_local_cur_attn_lse_[thread_id][i] = new_attn_lse;
+ }
+ } else {
+ if (cur_batch_idx != -1) {
+ mutex_[cur_batch_idx][cur_head_id]->lock();
+ for (int i = 0; i < n_gqa_; i++) {
+ if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) <
+ 1e-6) {
+ attn_lse_[cur_batch_idx][cur_head_id][i] =
+ thread_local_cur_attn_lse_[thread_id][i];
+ for (int j = 0; j < config_.head_dim; j++) {
+ output_fp32_[cur_batch_idx][cur_head_id]
+ [i * config_.head_dim + j] =
+ thread_local_cur_output_fp32_
+ [thread_id]
+ [i * config_.head_dim + j];
+ }
+ continue;
+ }
+ float new_attn_lse =
+ attn_lse_[cur_batch_idx][cur_head_id][i] +
+ std::log(
+ 1.0 +
+ std::exp(
+ thread_local_cur_attn_lse_[thread_id][i] -
+ attn_lse_[cur_batch_idx][cur_head_id][i]));
+ ggml_vec_scale_f32(
+ config_.head_dim,
+ output_fp32_[cur_batch_idx][cur_head_id].data() +
+ i * config_.head_dim,
+ std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] -
+ new_attn_lse));
+ ggml_vec_scale_f32(
+ config_.head_dim,
+ thread_local_cur_output_fp32_[thread_id].data() +
+ i * config_.head_dim,
+ std::exp(thread_local_cur_attn_lse_[thread_id][i] -
+ new_attn_lse));
+ for (int j = 0; j < config_.head_dim; j++) {
+ output_fp32_[cur_batch_idx][cur_head_id]
+ [i * config_.head_dim + j] +=
+ thread_local_cur_output_fp32_
+ [thread_id][i * config_.head_dim + j];
+ }
+ attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;
+ }
+ mutex_[cur_batch_idx][cur_head_id]->unlock();
+ }
+ thread_cur_head_idx_[thread_id].first = batch_id;
+ thread_cur_head_idx_[thread_id].second = head_id;
+ for (int i = 0; i < n_gqa_; i++) {
+ thread_local_cur_attn_lse_[thread_id][i] =
+ thread_local_attn_lse_[thread_id][i];
+ for (int j = 0; j < config_.head_dim; j++) {
+ thread_local_cur_output_fp32_
+ [thread_id][i * config_.head_dim + j] =
+ thread_local_output_fp32_[thread_id]
+ [i * config_.head_dim +
+ j];
+ }
+ }
+ }
+ },
+ // Merge the results of the remaining blocks.
+ [&](int thread_id) {
+ int cur_batch_idx = thread_cur_head_idx_[thread_id].first;
+ int cur_head_id = thread_cur_head_idx_[thread_id].second;
+ if (cur_head_id != -1) {
+ mutex_[cur_batch_idx][cur_head_id]->lock();
+ for (int i = 0; i < n_gqa_; i++) {
+ float new_attn_lse;
+ if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) <
+ 1e-6) {
+ attn_lse_[cur_batch_idx][cur_head_id][i] =
+ thread_local_cur_attn_lse_[thread_id][i];
+ for (int j = 0; j < config_.head_dim; j++) {
+ output_fp32_[cur_batch_idx][cur_head_id]
+ [i * config_.head_dim + j] =
+ thread_local_cur_output_fp32_
+ [thread_id]
+ [i * config_.head_dim + j];
+ }
+ continue;
+ }
+ new_attn_lse =
+ attn_lse_[cur_batch_idx][cur_head_id][i] +
+ std::log(
+ 1.0 +
+ std::exp(thread_local_cur_attn_lse_[thread_id][i] -
+ attn_lse_[cur_batch_idx][cur_head_id][i]));
+ ggml_vec_scale_f32(
+ config_.head_dim,
+ output_fp32_[cur_batch_idx][cur_head_id].data() +
+ i * config_.head_dim,
+ std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] -
+ new_attn_lse));
+ ggml_vec_scale_f32(
+ config_.head_dim,
+ thread_local_cur_output_fp32_[thread_id].data() +
+ i * config_.head_dim,
+ std::exp(thread_local_cur_attn_lse_[thread_id][i] -
+ new_attn_lse));
+ for (int j = 0; j < config_.head_dim; j++) {
+ output_fp32_[cur_batch_idx][cur_head_id]
+ [i * config_.head_dim + j] +=
+ thread_local_cur_output_fp32_[thread_id]
+ [i * config_.head_dim +
+ j];
+ }
+ attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;
+ }
+ mutex_[cur_batch_idx][cur_head_id]->unlock();
+ }
+ });
+
+ // move the results to output and attn_lse
+ uint16_t *output_data = reinterpret_cast<uint16_t *>(output);
+ float *attn_lse_data = attn_lse;
+ for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
+ for (int i = 0; i < config_.kv_head_num; i++) {
+ for (int j = 0; j < n_gqa_ * config_.head_dim; j++) {
+ output_data[batch_idx * config_.kv_head_num * n_gqa_ *
+ config_.head_dim +
+ i * n_gqa_ * config_.head_dim + j] =
+ GGML_FP32_TO_FP16(output_fp32_[batch_idx][i][j]);
+ }
+ for (int j = 0; j < n_gqa_; j++) {
+ attn_lse_data[batch_idx * config_.kv_head_num * n_gqa_ +
+ i * n_gqa_ + j] = attn_lse_[batch_idx][i][j];
+ }
+ }
+ }
+ // Timer end
+ auto end = std::chrono::high_resolution_clock::now();
+ std::chrono::duration<double> diff = end - start;
+ // printf("layer %d time of computing attention: %f s\n", layer_id_,
+ // diff.count());
+}
+
+void KVCache::attn(const ggml_fp16_t *q_in, ggml_fp16_t *output,
+ float *attn_lse, int layer_idx, int generate_token_idx,
+ int q_len, int batch_size, int max_block_num,
+ int *block_table, int *cache_seqlens, int pick_block_num,
+ int init_block_num, int local_block_num, Backend *backend) {
+
+ // Timer start
+ auto start = std::chrono::high_resolution_clock::now();
+ layer_id_ = layer_idx;
+ batch_size = batch_size * q_len;
+
+ const uint16_t *q_in_data = const_cast<const uint16_t *>(q_in);
+
+ quantize_q_(q_in_data, batch_size);
+ if (config_.retrieval_type == RetrievalType::LAYER) {
+ attn_initialize_layer_(batch_size, layer_idx, block_table,
+ max_block_num, cache_seqlens);
+ retrieval_kvcache_layer_(q_in_data, init_block_num, local_block_num,
+ pick_block_num, q_len, generate_token_idx,
+ batch_size, layer_idx, cache_seqlens,
+ max_block_num, backend);
+ attention_layer_(q_in_data, output, attn_lse, batch_size, backend);
+ } else if (config_.retrieval_type == RetrievalType::KVHEAD) {
+ attn_initialize_kvhead_(batch_size, layer_idx, block_table,
+ max_block_num, cache_seqlens);
+ retrieval_kvcache_kvhead_(q_in_data, init_block_num, local_block_num,
+ pick_block_num, q_len, generate_token_idx,
+ batch_size, layer_idx, cache_seqlens,
+ max_block_num, backend);
+ attention_kvhead_(q_in_data, output, attn_lse, batch_size, backend);
+ }
+
+ // Timer end
+ auto end = std::chrono::high_resolution_clock::now();
+ std::chrono::duration<double> diff = end - start;
+ // printf("layer %d time of computing attention: %f s\n", layer_idx,
+ // diff.count());
+}
+
+void KVCache::attn_with_kvcache(
+ const ggml_fp16_t *q_in, const ggml_fp16_t *k_in, const ggml_fp16_t *v_in,
+ ggml_fp16_t *output, float *attn_lse, int layer_idx, int generate_token_idx,
+ int q_len, int batch_size, int max_block_num, int *block_table,
+ int *cache_seqlens, int topk, int local, Backend *backend) {
+ // printf("attn_with_kvcache start\n");
+ assert(q_len == 1);
+ // Timer start
+ auto start = std::chrono::high_resolution_clock::now();
+
+ layer_id_ = layer_idx;
+
+ update_kvcache_fp16(k_in, v_in, layer_idx, block_table, batch_size,
+ max_block_num, cache_seqlens, q_len, backend);
+ // printf("update finished.\n");
+
+ // Note: the caller-provided cache_seqlens buffer is modified in place.
+ for (int i = 0; i < batch_size; i++) {
+ cache_seqlens[i] += q_len;
+ }
+ int init_block_num = 1;
+ if (config_.block_len <= 32) {
+ init_block_num = 64 / config_.block_len;
+ }
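+ // For example (illustrative): block_len = 16 gives init_block_num = 4, so
+ // roughly the first 64 tokens are always kept regardless of retrieval.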
+
+ attn(q_in, output, attn_lse, layer_idx, generate_token_idx, q_len,
+ batch_size, max_block_num, block_table, cache_seqlens, topk,
+ init_block_num, local, backend);
+
+ // Timer end
+ auto end = std::chrono::high_resolution_clock::now();
+ std::chrono::duration<double> diff = end - start;
+ // printf("layer %d time of computing attention with kvcache: %f s\n",
+ // layer_idx, diff.count());
+}
+
+void KVCache::quantize_q_(const uint16_t *q_in_data, int batch_size) {
+ // Timer start
+ auto start = std::chrono::high_resolution_clock::now();
+ for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
+ if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
+ // quantize q
+ for (int i = 0; i < config_.kv_head_num; i++) {
+ for (int j = 0; j < n_gqa_ * config_.head_dim; j++) {
+ q_fp32_[batch_idx][i][j] = GGML_FP16_TO_FP32(
+ q_in_data[batch_idx * config_.kv_head_num * n_gqa_ *
+ config_.head_dim +
+ i * n_gqa_ * config_.head_dim + j]);
+ }
+ }
+ } else {
+ // quantize q
+ for (int i = 0; i < config_.kv_head_num; i++) {
+ for (int j = 0; j < n_gqa_ * config_.head_dim; j++) {
+ q_fp32[j] = GGML_FP16_TO_FP32(
+ q_in_data[batch_idx * config_.kv_head_num * n_gqa_ *
+ config_.head_dim +
+ i * n_gqa_ * config_.head_dim + j]);
+ }
+ quantize_row_q8_0(q_fp32.data(), q_q8_0_[batch_idx][i].data(),
+ n_gqa_ * config_.head_dim);
+ }
+ }
+ }
+ // Timer end
+ auto end = std::chrono::high_resolution_clock::now();
+ // printf("time of quantizing q: %f s\n",
+ // std::chrono::duration<double>(end - start).count());
+}
+void KVCache::attn_initialize_layer_(int batch_size, int layer_idx,
+ int *block_table, int &max_block_num,
+ int *cache_seqlens) {
+ // Timer start
+ auto start = std::chrono::high_resolution_clock::now();
+ for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
+ // initialize output_fp32_ and attn_lse_
+ for (int i = 0; i < config_.kv_head_num; i++) {
+ for (int j = 0; j < n_gqa_ * config_.head_dim; j++) {
+ output_fp32_[batch_idx][i][j] = 0;
+ }
+ for (int j = 0; j < n_gqa_; j++) {
+ attn_lse_[batch_idx][i][j] = 0;
+ }
+ }
+ // clear top_similar_block_
+
+ while (!top_similar_block_[batch_idx].empty())
+ top_similar_block_[batch_idx].pop();
+ }
+
+ // get block_table_before_retrieval_ and cache_seqlens_
+ if (block_table == nullptr) {
+ max_block_num = past_block_num_[layer_idx];
+ for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
+ if (cache_total_len_ != 0)
+ cache_seqlens_[batch_idx] = cache_total_len_;
+ else
+ cache_seqlens_[batch_idx] = max_block_num * config_.block_len;
+ for (int i = 0; i < max_block_num; i++) {
+ block_table_before_retrieval_[batch_idx][i] = i;
+ block_similar_[batch_idx][i] = 0;
+ }
+ }
+ } else {
+ for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
+ cache_seqlens_[batch_idx] = cache_seqlens[batch_idx];
+ for (int i = 0; i < max_block_num; i++) {
+ block_table_before_retrieval_[batch_idx][i] =
+ block_table[batch_idx * max_block_num + i];
+ block_similar_[batch_idx][i] = 0;
+ }
+ }
+ }
+ // Timer end
+ auto end = std::chrono::high_resolution_clock::now();
+ // printf("layer %d time of initializing attention: %f s\n", layer_idx,
+ // std::chrono::duration<double>(end - start).count());
+}
+
+void KVCache::calculate_block_similarity_layer_(
+ const uint16_t *q_in_data, int batch_size, int layer_idx, int q_len,
+ int max_block_num, int *cache_seqlens, int init_block_num,
+ int local_block_num, int pick_block_num, Backend *backend) {
+ // Timer start
+ auto start = std::chrono::high_resolution_clock::now();
+
+ if (batch_size == 1 &&
+ config_.anchor_num == 1) { // TODO: improve batch_size > 1
+ for (int batch_id = 0; batch_id < batch_size; batch_id++) {
+ if (q_len == 1) {
+ for (int j = 0; j < config_.head_dim * config_.q_head_num;
+ j++) {
+ avg_q[batch_id][j] = GGML_FP16_TO_FP32(
+ q_in_data[batch_id * q_len * config_.q_head_num *
+ config_.head_dim +
+ j]);
+ avg_q_fp16[batch_id][j] =
+ q_in_data[batch_id * q_len * config_.q_head_num *
+ config_.head_dim +
+ j];
+ }
+ } else {
+ for (int j = 0; j < config_.head_dim * config_.q_head_num;
+ j++) {
+ avg_q[batch_id][j] = 0;
+ }
+ for (int i = 0; i < q_len; i++) {
+ for (int j = 0; j < config_.head_dim; j++) {
+ avg_q[batch_id][j] += GGML_FP16_TO_FP32(
+ q_in_data[batch_id * q_len * config_.q_head_num *
+ config_.head_dim +
+ i * config_.q_head_num *
+ config_.head_dim +
+ j]);
+ }
+ }
+ for (int j = 0; j < config_.head_dim * config_.q_head_num;
+ j++) {
+ avg_q[batch_id][j] /= q_len;
+ avg_q_fp16[batch_id][j] =
+ GGML_FP32_TO_FP16(avg_q[batch_id][j]);
+ }
+ }
+ int seq_len = cache_seqlens_[batch_id];
+ int block_num = (seq_len / config_.block_len) - local_block_num -
+ init_block_num;
+ if (block_num <= 0) {
+ continue;
+ }
+ bool is_seq = true;
+ for (int i = init_block_num + 1;
+ i < (seq_len / config_.block_len) - local_block_num; i++) {
+ if (block_table_before_retrieval_[batch_id][i] !=
+ block_table_before_retrieval_[batch_id][i - 1] + 1) {
+ is_seq = false;
+ break;
+ }
+ }
+ if (is_seq) {
+ int nth = backend->get_thread_num();
+ backend->do_work_stealing_job(
+ nth, nullptr,
+ [&](int task_id) {
+ int ith = task_id;
+ bool ok = llamafile_sgemm(
+ block_num, 1, config_.q_head_num * config_.head_dim,
+ anchor_.data() +
+ (layer_idx * config_.max_block_num +
+ block_table_before_retrieval_
+ [batch_id][init_block_num]) *
+ config_.anchor_num * config_.q_head_num *
+ config_.head_dim,
+ config_.q_head_num * config_.head_dim,
+ avg_q_fp16[batch_id].data(),
+ config_.q_head_num * config_.head_dim,
+ block_similar_[batch_id].data() + init_block_num,
+ block_num, ith, nth, GGML_TASK_TYPE_COMPUTE,
+ GGML_TYPE_F16, GGML_TYPE_F16, GGML_TYPE_F32,
+ GGML_PREC_DEFAULT);
+ if (!ok) {
+ printf("llamafile_sgemm failed\n");
+ }
+ },
+ nullptr);
+ } else {
+ backend->do_work_stealing_job(
+ block_num, nullptr,
+ [&](int task_id) {
+ int block_id = task_id + init_block_num;
+ int block_idx =
+ block_table_before_retrieval_[batch_id][block_id];
+ bool ok = llamafile_sgemm(
+ 1, 1, config_.q_head_num * config_.head_dim,
+ anchor_.data() +
+ (layer_idx * config_.max_block_num +
+ block_table_before_retrieval_[batch_id]
+ [block_idx]) *
+ config_.anchor_num * config_.q_head_num *
+ config_.head_dim,
+ config_.q_head_num * config_.head_dim,
+ avg_q_fp16[batch_id].data(),
+ config_.q_head_num * config_.head_dim,
+ block_similar_[batch_id].data() + block_id, 1, 0, 1,
+ GGML_TASK_TYPE_COMPUTE, GGML_TYPE_F16,
+ GGML_TYPE_F16, GGML_TYPE_F32, GGML_PREC_DEFAULT);
+ if (!ok) {
+ printf("llamafile_sgemm failed\n");
+ }
+ },
+ nullptr);
+ }
+ }
+ } else {
+ backend->do_work_stealing_job(
+ batch_size * max_block_num, nullptr,
+ [&](int task_id) {
+ int batch_id = task_id / max_block_num;
+ int block_id = task_id % max_block_num;
+ int seq_len = cache_seqlens_[batch_id];
+
+ if (block_id < init_block_num ||
+ block_id >=
+ (seq_len / config_.block_len) - local_block_num) {
+ return;
+ }
+
+ int block_idx =
+ block_table_before_retrieval_[batch_id][block_id];
+ float sim = 0;
+
+ for (int head_id = 0; head_id < config_.q_head_num; head_id++) {
+ for (int i = 0; i < config_.head_dim; i++) {
+ float q_i = 0,
+ qa_i = std::numeric_limits<float>::lowest();
+ for (int q_id = 0; q_id < q_len; q_id++) {
+ q_i += GGML_FP16_TO_FP32(
+ q_in_data[batch_id * q_len *
+ config_.q_head_num *
+ config_.head_dim +
+ q_id * config_.q_head_num *
+ config_.head_dim +
+ head_id * config_.head_dim + i]);
+ }
+ q_i /= q_len;
+ for (int anchor_id = 0; anchor_id < config_.anchor_num;
+ anchor_id++) {
+ qa_i = std::max(
+ qa_i,
+ GGML_FP16_TO_FP32(
+ anchor_[(long long)layer_idx *
+ config_.max_block_num *
+ config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ block_idx * config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ anchor_id * config_.q_head_num *
+ config_.head_dim +
+ head_id * config_.head_dim + i]) *
+ q_i);
+ }
+ sim += qa_i;
+ }
+ }
+ block_similar_[batch_id][block_id] = sim;
+ },
+ nullptr);
+ }
+ // Timer end
+ auto end = std::chrono::high_resolution_clock::now();
+ std::chrono::duration<double> diff = end - start;
+ // printf("layer %d time of calculating similarity: %f s\n", layer_idx,
+ // diff.count());
+}
+
+void KVCache::select_block_layer_(int batch_size, int layer_idx,
+ int max_block_num, int init_block_num,
+ int local_block_num, int pick_block_num) {
+ // Timer start
+ auto start = std::chrono::high_resolution_clock::now();
+
+ for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
+
+ if (cache_seqlens_[batch_idx] / config_.block_len <=
+ init_block_num + pick_block_num + local_block_num) {
+ block_table_after_retrieval_[batch_idx].swap(
+ block_table_before_retrieval_[batch_idx]);
+ selected_blocks_num_history_[(layer_idx - config_.layer_offset) /
+ config_.layer_step] = 0;
+ continue;
+ }
+
+ for (int block_id = init_block_num;
+ block_id <
+ (cache_seqlens_[batch_idx] / config_.block_len) - local_block_num;
+ block_id++) {
+ top_similar_block_[batch_idx].push(std::make_pair(
+ block_similar_[batch_idx][block_id],
+ block_table_before_retrieval_[batch_idx][block_id]));
+ if (top_similar_block_[batch_idx].size() > pick_block_num) {
+ top_similar_block_[batch_idx].pop();
+ }
+ }
+
+ int i = 0;
+ for (; i < init_block_num; i++) {
+ block_table_after_retrieval_[batch_idx][i] =
+ block_table_before_retrieval_[batch_idx][i];
+ }
+ while (!top_similar_block_[batch_idx].empty()) {
+ block_table_after_retrieval_[batch_idx][i] =
+ top_similar_block_[batch_idx].top().second;
+ top_similar_block_[batch_idx].pop();
+ i++;
+ }
+ for (; i < init_block_num + pick_block_num + local_block_num; i++) {
+ block_table_after_retrieval_[batch_idx][i] =
+ block_table_before_retrieval_[batch_idx]
+ [(cache_seqlens_[batch_idx] /
+ config_.block_len) -
+ local_block_num + i -
+ init_block_num - pick_block_num];
+ }
+ if (cache_seqlens_[batch_idx] % config_.block_len != 0) {
+ block_table_after_retrieval_[batch_idx][i] =
+ block_table_before_retrieval_[batch_idx][(
+ cache_seqlens_[batch_idx] / config_.block_len)];
+ cache_seqlens_[batch_idx] =
+ (cache_seqlens_[batch_idx] % config_.block_len) +
+ i * config_.block_len;
+ i++;
+ } else {
+ cache_seqlens_[batch_idx] =
+ (cache_seqlens_[batch_idx] % config_.block_len) +
+ i * config_.block_len;
+ }
+ for (int j = 0; j < i; j++) {
+ selected_blocks_history_[(layer_idx - config_.layer_offset) /
+ config_.layer_step][batch_idx][j] =
+ block_table_after_retrieval_[batch_idx][j];
+ }
+ selected_blocks_num_history_[(layer_idx - config_.layer_offset) /
+ config_.layer_step] = i;
+ }
+
+ // Timer end
+ auto end = std::chrono::high_resolution_clock::now();
+ std::chrono::duration<double> diff = end - start;
+ // printf("layer %d time of selecting blocks: %f s\n", layer_idx,
+ // diff.count());
+}
+
+// Retrieve the KV cache: keep the first init_block_num blocks at the
+// beginning, the top pick_block_num most similar blocks, and the last
+// local_block_num blocks. Each task computes the similarity of one block
+// with the query and pushes the block into a priority queue; finally the
+// required blocks are written into block_table_after_retrieval_.
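+// Illustrative layout (assumed parameters, added for clarity): with
+// init_block_num = 1, pick_block_num = 2, local_block_num = 1 and physical
+// blocks [0, 1, 2, 3, 4, 5, 6, 7(partial)], the retrieved table may end up
+// as [0, 3, 5, 6, 7]: the first block, the two most similar middle blocks,
+// the last full block, and the trailing partial block.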
+void KVCache::retrieval_kvcache_layer_(const uint16_t *q_in_data,
+ int init_block_num, int local_block_num,
+ int pick_block_num, int q_len,
+ int generate_token_idx, int batch_size,
+ int layer_idx, int *cache_seqlens,
+ int &max_block_num, Backend *backend) {
+ // Timer start
+ auto start = std::chrono::high_resolution_clock::now();
+ max_block_num_after_retrieval_ = 0;
+ if (pick_block_num != -1 &&
+ (generate_token_idx % config_.token_step != 0 ||
+ (layer_idx % config_.layer_step != config_.layer_offset))) {
+
+ if (selected_blocks_num_history_[(layer_idx - config_.layer_offset) /
+ config_.layer_step] == 0) {
+ max_block_num_after_retrieval_ = max_block_num;
+ block_table_after_retrieval_.swap(block_table_before_retrieval_);
+ } else {
+ max_block_num_after_retrieval_ = selected_blocks_num_history_
+ [(layer_idx - config_.layer_offset) / config_.layer_step];
+ for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
+ for (int i = 0; i < max_block_num_after_retrieval_; i++) {
+ block_table_after_retrieval_[batch_idx][i] =
+ selected_blocks_history_[(layer_idx -
+ config_.layer_offset) /
+ config_.layer_step][batch_idx]
+ [i];
+ }
+
+ if (cache_seqlens[batch_idx] % config_.block_len == 1) {
+ selected_blocks_num_history_[(layer_idx -
+ config_.layer_offset) /
+ config_.layer_step] += 1;
+ int x =
+ selected_blocks_num_history_[(layer_idx -
+ config_.layer_offset) /
+ config_.layer_step];
+ int last_block_idx =
+ block_table_before_retrieval_[batch_idx]
+ [cache_seqlens[batch_idx] /
+ config_.block_len];
+ selected_blocks_history_[(layer_idx -
+ config_.layer_offset) /
+ config_.layer_step][batch_idx]
+ [x - 1] = last_block_idx;
+ block_table_after_retrieval_[batch_idx][x - 1] =
+ last_block_idx;
+ }
+ cache_seqlens_[batch_idx] =
+ (cache_seqlens_[batch_idx] % config_.block_len) +
+ selected_blocks_num_history_[(layer_idx -
+ config_.layer_offset) /
+ config_.layer_step] *
+ config_.block_len -
+ config_.block_len;
+ }
+ }
+ } else if (pick_block_num != -1) {
+ max_block_num_after_retrieval_ =
+ std::min(max_block_num,
+ init_block_num + pick_block_num + local_block_num + 1);
+ calculate_block_similarity_layer_(q_in_data, batch_size, layer_idx,
+ q_len, max_block_num, cache_seqlens,
+ init_block_num, local_block_num,
+ pick_block_num, backend);
+ select_block_layer_(batch_size, layer_idx, max_block_num,
+ init_block_num, local_block_num, pick_block_num);
+ } else {
+ selected_blocks_num_history_[(layer_idx - config_.layer_offset) /
+ config_.layer_step] = 0;
+ max_block_num_after_retrieval_ = max_block_num;
+ block_table_after_retrieval_.swap(block_table_before_retrieval_);
+ }
+
+ // Timer end
+ auto end = std::chrono::high_resolution_clock::now();
+ // printf("layer %d time of retrieval kvcache: %f s\n", layer_idx,
+ // std::chrono::duration<double>(end - start).count());
+}
+void KVCache::calculate_sparsity_layer_(const uint16_t *q_in_data,
+ float *attn_sparsity, int batch_size,
+ int max_block_num, int *block_table,
+ int *cache_seqlens, Backend *backend
+
+) {
+ // Timer start
+ auto start = std::chrono::high_resolution_clock::now();
+ seq_len_ = config_.block_len;
+ backend->do_work_stealing_job(
+ batch_size * config_.kv_head_num * max_block_num,
+ [&](int thread_id) {
+ thread_cur_head_idx_[thread_id].first = -1;
+ thread_cur_head_idx_[thread_id].second = -1;
+ },
+ [&](int task_id) {
+ int batch_id = task_id / (config_.kv_head_num * max_block_num);
+ int head_id = (task_id % (config_.kv_head_num * max_block_num)) /
+ max_block_num;
+ int block_id = task_id % max_block_num;
+ int thread_id = Backend::thread_local_id;
+ // If the block is out of the sequence length, skip it.
+ if (cache_seqlens[batch_id] / config_.block_len < block_id) {
+ return;
+ }
+ int block_idx = block_table[batch_id * max_block_num + block_id];
+ if (cache_seqlens_[batch_id] / config_.block_len == block_id) {
+ int seq_len = cache_seqlens_[batch_id] % config_.block_len;
+ if (seq_len == 0)
+ return;
+
+ // Prepare the attention mask for the last block.
+ int full_blocks = seq_len / 8;
+ int remaining_bits = seq_len % 8;
+ // Fill full blocks with 1s
+ for (int i = 0; i < full_blocks; ++i) {
+ thread_local_attn_mask_[thread_id][i] = 0xFF;
+ }
+ // Fill the remaining bits in the next block
+ if (remaining_bits > 0 && full_blocks < seq_len_ / 8) {
+ thread_local_attn_mask_[thread_id][full_blocks] =
+ (1 << remaining_bits) - 1;
+ } else {
+ thread_local_attn_mask_[thread_id][full_blocks] = 0;
+ }
+
+ for (int i = full_blocks + 1; i < seq_len_ / 8; ++i) {
+ thread_local_attn_mask_[thread_id][i] = 0;
+ }
+ if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
+ attn_with_kvcache_one_block_(
+ config_.head_dim,
+ config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,
+ (void *)&q_in_data[batch_id * config_.kv_head_num *
+ n_gqa_ * config_.head_dim +
+ head_id * n_gqa_ * config_.head_dim],
+ seq_len_, 0, false,
+ thread_local_attn_mask_[thread_id].data(),
+ GGML_TYPE_F16, 0,
+ k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr, GGML_TYPE_F16, 1,
+ v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr,
+ thread_local_attn_score_[thread_id].data(),
+ thread_local_output_fp32_[thread_id].data(),
+ thread_local_attn_lse_[thread_id].data(),
+ thread_local_draft_[thread_id].data(), nullptr,
+ cos_.data(), sin_.data());
+ } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
+ attn_with_kvcache_one_block_(
+ config_.head_dim,
+ config_.q_head_num / config_.kv_head_num,
+ GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),
+ seq_len_, 0, false,
+ thread_local_attn_mask_[thread_id].data(),
+ GGML_TYPE_Q4_0, 0,
+ k_cache_q4[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr, GGML_TYPE_Q4_0, 1,
+ v_cache_q4[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr,
+ thread_local_attn_score_[thread_id].data(),
+ thread_local_output_q8_0_[thread_id].data(),
+ thread_local_attn_lse_[thread_id].data(),
+ thread_local_draft_[thread_id].data(), nullptr,
+ cos_.data(), sin_.data());
+ dequantize_row_q8_0(
+ thread_local_output_q8_0_[thread_id].data(),
+ thread_local_output_fp32_[thread_id].data(),
+ n_gqa_ * config_.head_dim);
+ } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
+ attn_with_kvcache_one_block_(
+ config_.head_dim,
+ config_.q_head_num / config_.kv_head_num,
+ GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),
+ seq_len_, 0, false,
+ thread_local_attn_mask_[thread_id].data(),
+ GGML_TYPE_Q8_0, 0,
+ k_cache_q8[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr, GGML_TYPE_Q8_0, 1,
+ v_cache_q8[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr,
+ thread_local_attn_score_[thread_id].data(),
+ thread_local_output_q8_0_[thread_id].data(),
+ thread_local_attn_lse_[thread_id].data(),
+ thread_local_draft_[thread_id].data(), nullptr,
+ cos_.data(), sin_.data());
+ dequantize_row_q8_0(
+ thread_local_output_q8_0_[thread_id].data(),
+ thread_local_output_fp32_[thread_id].data(),
+ n_gqa_ * config_.head_dim);
+ }
+ } else {
+ if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
+ attn_with_kvcache_one_block_(
+ config_.head_dim,
+ config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,
+ (void *)&q_in_data[batch_id * config_.kv_head_num *
+ n_gqa_ * config_.head_dim +
+ head_id * n_gqa_ * config_.head_dim],
+ seq_len_, 0, true, nullptr, GGML_TYPE_F16, 0,
+ k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr, GGML_TYPE_F16, 1,
+ v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr,
+ thread_local_attn_score_[thread_id].data(),
+ thread_local_output_fp32_[thread_id].data(),
+ thread_local_attn_lse_[thread_id].data(),
+ thread_local_draft_[thread_id].data(), nullptr,
+ cos_.data(), sin_.data());
+
+ } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
+ attn_with_kvcache_one_block_(
+ config_.head_dim,
+ config_.q_head_num / config_.kv_head_num,
+ GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),
+ seq_len_, 0, true, nullptr, GGML_TYPE_Q4_0, 0,
+ k_cache_q4[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr, GGML_TYPE_Q4_0, 1,
+ v_cache_q4[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr,
+ thread_local_attn_score_[thread_id].data(),
+ thread_local_output_q8_0_[thread_id].data(),
+ thread_local_attn_lse_[thread_id].data(),
+ thread_local_draft_[thread_id].data(), nullptr,
+ cos_.data(), sin_.data());
+ dequantize_row_q8_0(
+ thread_local_output_q8_0_[thread_id].data(),
+ thread_local_output_fp32_[thread_id].data(),
+ n_gqa_ * config_.head_dim);
+ } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
+ attn_with_kvcache_one_block_(
+ config_.head_dim,
+ config_.q_head_num / config_.kv_head_num,
+ GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),
+ seq_len_, 0, true, nullptr, GGML_TYPE_Q8_0, 0,
+ k_cache_q8[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr, GGML_TYPE_Q8_0, 1,
+ v_cache_q8[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr,
+ thread_local_attn_score_[thread_id].data(),
+ thread_local_output_q8_0_[thread_id].data(),
+ thread_local_attn_lse_[thread_id].data(),
+ thread_local_draft_[thread_id].data(), nullptr,
+ cos_.data(), sin_.data());
+ dequantize_row_q8_0(
+ thread_local_output_q8_0_[thread_id].data(),
+ thread_local_output_fp32_[thread_id].data(),
+ n_gqa_ * config_.head_dim);
+ }
+ }
+ for (int i = 0; i < n_gqa_; i++) {
+ block_lse_[batch_id][block_idx][head_id * n_gqa_ + i] =
+ thread_local_attn_lse_[thread_id][i];
+ }
+ int cur_batch_idx = thread_cur_head_idx_[thread_id].first;
+ int cur_head_id = thread_cur_head_idx_[thread_id].second;
+ if (batch_id == cur_batch_idx && head_id == cur_head_id) {
+ for (int i = 0; i < n_gqa_; i++) {
+ float new_attn_lse =
+ thread_local_cur_attn_lse_[thread_id][i] +
+ std::log(
+ 1.0 +
+ std::exp(thread_local_attn_lse_[thread_id][i] -
+ thread_local_cur_attn_lse_[thread_id][i]));
+ ggml_vec_scale_f32(
+ config_.head_dim,
+ thread_local_cur_output_fp32_[thread_id].data() +
+ i * config_.head_dim,
+ std::exp(thread_local_cur_attn_lse_[thread_id][i] -
+ new_attn_lse));
+ ggml_vec_scale_f32(
+ config_.head_dim,
+ thread_local_output_fp32_[thread_id].data() +
+ i * config_.head_dim,
+ std::exp(thread_local_attn_lse_[thread_id][i] -
+ new_attn_lse));
+ for (int j = 0; j < config_.head_dim; j++) {
+ thread_local_cur_output_fp32_[thread_id]
+ [i * config_.head_dim +
+ j] +=
+ thread_local_output_fp32_[thread_id]
+ [i * config_.head_dim + j];
+ }
+ thread_local_cur_attn_lse_[thread_id][i] = new_attn_lse;
+ }
+ } else {
+ if (cur_batch_idx != -1) {
+ mutex_[cur_batch_idx][cur_head_id]->lock();
+ for (int i = 0; i < n_gqa_; i++) {
+ if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) <
+ 1e-6) {
+ attn_lse_[cur_batch_idx][cur_head_id][i] =
+ thread_local_cur_attn_lse_[thread_id][i];
+ for (int j = 0; j < config_.head_dim; j++) {
+ output_fp32_[cur_batch_idx][cur_head_id]
+ [i * config_.head_dim + j] =
+ thread_local_cur_output_fp32_
+ [thread_id]
+ [i * config_.head_dim + j];
+ }
+ continue;
+ }
+ float new_attn_lse =
+ attn_lse_[cur_batch_idx][cur_head_id][i] +
+ std::log(
+ 1.0 +
+ std::exp(
+ thread_local_cur_attn_lse_[thread_id][i] -
+ attn_lse_[cur_batch_idx][cur_head_id][i]));
+ ggml_vec_scale_f32(
+ config_.head_dim,
+ output_fp32_[cur_batch_idx][cur_head_id].data() +
+ i * config_.head_dim,
+ std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] -
+ new_attn_lse));
+ ggml_vec_scale_f32(
+ config_.head_dim,
+ thread_local_cur_output_fp32_[thread_id].data() +
+ i * config_.head_dim,
+ std::exp(thread_local_cur_attn_lse_[thread_id][i] -
+ new_attn_lse));
+ for (int j = 0; j < config_.head_dim; j++) {
+ output_fp32_[cur_batch_idx][cur_head_id]
+ [i * config_.head_dim + j] +=
+ thread_local_cur_output_fp32_
+ [thread_id][i * config_.head_dim + j];
+ }
+ attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;
+ }
+ mutex_[cur_batch_idx][cur_head_id]->unlock();
+ }
+ thread_cur_head_idx_[thread_id].first = batch_id;
+ thread_cur_head_idx_[thread_id].second = head_id;
+ for (int i = 0; i < n_gqa_; i++) {
+ thread_local_cur_attn_lse_[thread_id][i] =
+ thread_local_attn_lse_[thread_id][i];
+ for (int j = 0; j < config_.head_dim; j++) {
+ thread_local_cur_output_fp32_
+ [thread_id][i * config_.head_dim + j] =
+ thread_local_output_fp32_[thread_id]
+ [i * config_.head_dim +
+ j];
+ }
+ }
+ }
+ },
+ // Merge the results of the remaining blocks.
+ [&](int thread_id) {
+ int cur_batch_idx = thread_cur_head_idx_[thread_id].first;
+ int cur_head_id = thread_cur_head_idx_[thread_id].second;
+ if (cur_head_id != -1) {
+ mutex_[cur_batch_idx][cur_head_id]->lock();
+ for (int i = 0; i < n_gqa_; i++) {
+ float new_attn_lse;
+ if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) <
+ 1e-6) {
+ attn_lse_[cur_batch_idx][cur_head_id][i] =
+ thread_local_cur_attn_lse_[thread_id][i];
+ for (int j = 0; j < config_.head_dim; j++) {
+ output_fp32_[cur_batch_idx][cur_head_id]
+ [i * config_.head_dim + j] =
+ thread_local_cur_output_fp32_
+ [thread_id]
+ [i * config_.head_dim + j];
+ }
+ continue;
+ }
+ new_attn_lse =
+ attn_lse_[cur_batch_idx][cur_head_id][i] +
+ std::log(
+ 1.0 +
+ std::exp(thread_local_cur_attn_lse_[thread_id][i] -
+ attn_lse_[cur_batch_idx][cur_head_id][i]));
+ ggml_vec_scale_f32(
+ config_.head_dim,
+ output_fp32_[cur_batch_idx][cur_head_id].data() +
+ i * config_.head_dim,
+ std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] -
+ new_attn_lse));
+ ggml_vec_scale_f32(
+ config_.head_dim,
+ thread_local_cur_output_fp32_[thread_id].data() +
+ i * config_.head_dim,
+ std::exp(thread_local_cur_attn_lse_[thread_id][i] -
+ new_attn_lse));
+ for (int j = 0; j < config_.head_dim; j++) {
+ output_fp32_[cur_batch_idx][cur_head_id]
+ [i * config_.head_dim + j] +=
+ thread_local_cur_output_fp32_[thread_id]
+ [i * config_.head_dim +
+ j];
+ }
+ attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;
+ }
+ mutex_[cur_batch_idx][cur_head_id]->unlock();
+ }
+ });
+
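+ // The sparsity metric below follows from the definition of lse (comment
+ // added for clarity): exp(block_lse - total_lse) is the fraction of the
+ // softmax attention mass that one block contributes to one query head, so
+ // summing it over the retained blocks gives the share of attention that
+ // the retrieved blocks cover.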
+ for (int i = 0; i < batch_size; i++) {
+ for (int j = 0; j < max_block_num_after_retrieval_; j++) {
+ int block_idx = block_table_after_retrieval_[i][j];
+ for (int k = 0; k < config_.q_head_num; k++) {
+ attn_sparsity[i * config_.q_head_num + k] +=
+ std::exp(block_lse_[i][block_idx][k] -
+ attn_lse_[i][k / n_gqa_][k % n_gqa_]);
+ }
+ }
+ }
+
+ // Timer end
+ auto end = std::chrono::high_resolution_clock::now();
+ std::chrono::duration<double> diff = end - start;
+ // printf("layer %d time of calculating sparsity: %f s\n", layer_id_,
+ // diff.count());
+}
+
+void KVCache::attn_initialize_kvhead_(int batch_size, int layer_idx,
+ int *block_table, int &max_block_num,
+ int *cache_seqlens) {
+ // Timer start
+ auto start = std::chrono::high_resolution_clock::now();
+ for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
+ // initialize output_fp32_ and attn_lse_
+ for (int i = 0; i < config_.kv_head_num; i++) {
+ for (int j = 0; j < n_gqa_ * config_.head_dim; j++) {
+ output_fp32_[batch_idx][i][j] = 0;
+ }
+ for (int j = 0; j < n_gqa_; j++) {
+ attn_lse_[batch_idx][i][j] = 0;
+ }
+ }
+
+ // clear top_similar_block_
+ while (!top_similar_block_[batch_idx].empty())
+ top_similar_block_[batch_idx].pop();
+ }
+
+ for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
+ cache_seqlens_[batch_idx] = cache_seqlens[batch_idx];
+ for (int i = 0; i < max_block_num; i++) {
+ for (int j = 0; j < config_.kv_head_num; j++) {
+ block_table_before_retrieval_kvhead_[batch_idx][i][j] =
+ block_table[batch_idx * max_block_num + i];
+ block_similar_kv_head_[batch_idx][i][j] = 0;
+ }
+ }
+ }
+
+ // Timer end
+ auto end = std::chrono::high_resolution_clock::now();
+ // printf("layer %d time of initializing attn: %f s\n", layer_idx,
+ // std::chrono::duration<double>(end - start).count());
+}
+void KVCache::retrieval_kvcache_kvhead_(const uint16_t *q_in_data,
+ int init_block_num, int local_block_num,
+ int pick_block_num, int q_len,
+ int generate_token_idx, int batch_size,
+ int layer_idx, int *cache_seqlens,
+ int &max_block_num, Backend *backend) {
+ // Timer start
+ auto start = std::chrono::high_resolution_clock::now();
+ max_block_num_after_retrieval_ = 0;
+ if (pick_block_num != -1 &&
+ (generate_token_idx % config_.token_step != 0 ||
+ (layer_idx % config_.layer_step != config_.layer_offset))) {
+
+ if (selected_blocks_num_history_[(layer_idx - config_.layer_offset) /
+ config_.layer_step] == 0) {
+ max_block_num_after_retrieval_ = max_block_num;
+ for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
+ for (int i = 0; i < max_block_num; i++) {
+ for (int j = 0; j < config_.kv_head_num; j++) {
+ block_table_after_retrieval_kvhead_[batch_idx][i][j] =
+ block_table_before_retrieval_kvhead_[batch_idx][i]
+ [j];
+ }
+ }
+ }
+ } else {
+
+ max_block_num_after_retrieval_ = selected_blocks_num_history_
+ [(layer_idx - config_.layer_offset) / config_.layer_step];
+
+ for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
+ for (int i = 0; i < max_block_num_after_retrieval_; i++) {
+ for (int j = 0; j < config_.kv_head_num; j++) {
+ block_table_after_retrieval_kvhead_[batch_idx][i][j] =
+ selected_blocks_history_kvhead_
+ [(layer_idx - config_.layer_offset) /
+ config_.layer_step][batch_idx][i][j];
+ }
+ }
+
+ if (cache_seqlens[batch_idx] % config_.block_len == 1) {
+ selected_blocks_num_history_[(layer_idx -
+ config_.layer_offset) /
+ config_.layer_step] += 1;
+ int x =
+ selected_blocks_num_history_[(layer_idx -
+ config_.layer_offset) /
+ config_.layer_step];
+ for (int i = 0; i < config_.kv_head_num; i++) {
+ int last_block_idx =
+ block_table_before_retrieval_kvhead_
+ [batch_idx][cache_seqlens[batch_idx] /
+ config_.block_len][i];
+ selected_blocks_history_kvhead_[(layer_idx -
+ config_.layer_offset) /
+ config_.layer_step]
+ [batch_idx][x - 1][i] =
+ last_block_idx;
+ block_table_after_retrieval_kvhead_[batch_idx][x - 1]
+ [i] = last_block_idx;
+ }
+ }
+ cache_seqlens_[batch_idx] = std::min(
+ cache_seqlens_[batch_idx],
+ (cache_seqlens_[batch_idx] % config_.block_len) +
+ (init_block_num + pick_block_num + local_block_num) *
+ config_.block_len);
+ }
+ }
+ } else if (pick_block_num != -1) {
+ max_block_num_after_retrieval_ =
+ std::min(max_block_num,
+ init_block_num + pick_block_num + local_block_num + 1);
+ calculate_block_similarity_kvhead_(q_in_data, batch_size, layer_idx,
+ q_len, max_block_num, cache_seqlens,
+ init_block_num, local_block_num,
+ pick_block_num, backend);
+ select_block_kvhead_(batch_size, layer_idx, max_block_num,
+ init_block_num, local_block_num, pick_block_num);
+ } else {
+ selected_blocks_num_history_[(layer_idx - config_.layer_offset) /
+ config_.layer_step] = 0;
+ max_block_num_after_retrieval_ = max_block_num;
+ for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
+ for (int i = 0; i < max_block_num; i++) {
+ for (int j = 0; j < config_.kv_head_num; j++) {
+ block_table_after_retrieval_kvhead_[batch_idx][i][j] =
+ block_table_before_retrieval_kvhead_[batch_idx][i][j];
+ }
+ }
+ }
+ }
+
+ // Timer end
+ auto end = std::chrono::high_resolution_clock::now();
+ // printf("layer %d time of retrieval kvcache: %f s\n", layer_idx,
+ // std::chrono::duration<double>(end - start).count());
+}
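+
+// Illustrative sketch of the gating above: retrieval is recomputed only on
+// "anchor" steps, and every other (token, layer) pair reuses the selection
+// cached in selected_blocks_history_kvhead_. A minimal stand-alone model of
+// that condition (rerun_retrieval is a hypothetical helper name):
+//
+//   bool rerun_retrieval(int generate_token_idx, int layer_idx,
+//                        int token_step, int layer_step, int layer_offset) {
+//       return generate_token_idx % token_step == 0 &&
+//              layer_idx % layer_step == layer_offset;
+//   }
+//
+//   // e.g. with token_step = 4, layer_step = 8, layer_offset = 0, retrieval
+//   // runs on layers 0, 8, 16, ... of every 4th decoded token, and the other
+//   // layers reuse history slot (layer_idx - layer_offset) / layer_step.
+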
+void KVCache::calculate_sparsity_kvhead_(const uint16_t *q_in_data,
+ float *attn_sparsity, int batch_size,
+ int max_block_num, int *block_table,
+ int *cache_seqlens, Backend *backend) {
+ // Timer start
+ auto start = std::chrono::high_resolution_clock::now();
+ seq_len_ = config_.block_len;
+ backend->do_work_stealing_job(
+ batch_size * config_.kv_head_num * max_block_num,
+ [&](int thread_id) {
+ thread_cur_head_idx_[thread_id].first = -1;
+ thread_cur_head_idx_[thread_id].second = -1;
+ },
+ [&](int task_id) {
+ int batch_id = task_id / (config_.kv_head_num * max_block_num);
+ int head_id = (task_id % (config_.kv_head_num * max_block_num)) /
+ max_block_num;
+ int block_id = task_id % max_block_num;
+ int thread_id = Backend::thread_local_id;
+ // If the block is out of the sequence length, skip it.
+ if (cache_seqlens[batch_id] / config_.block_len < block_id) {
+ return;
+ }
+ int block_idx = block_table[batch_id * max_block_num + block_id];
+ if (cache_seqlens_[batch_id] / config_.block_len == block_id) {
+ int seq_len = cache_seqlens_[batch_id] % config_.block_len;
+ if (seq_len == 0)
+ return;
+
+ // Prepare the attention mask for the last block.
+ int full_blocks = seq_len / 8;
+ int remaining_bits = seq_len % 8;
+
+ // Fill full blocks with 1s
+ for (int i = 0; i < full_blocks; ++i) {
+ thread_local_attn_mask_[thread_id][i] = 0xFF;
+ }
+ // Fill the remaining bits in the next block
+ if (remaining_bits > 0 && full_blocks < seq_len_ / 8) {
+ thread_local_attn_mask_[thread_id][full_blocks] =
+ (1 << remaining_bits) - 1;
+ } else {
+ thread_local_attn_mask_[thread_id][full_blocks] = 0;
+ }
+
+ for (int i = full_blocks + 1; i < seq_len_ / 8; ++i) {
+ thread_local_attn_mask_[thread_id][i] = 0;
+ }
+ if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
+ attn_with_kvcache_one_block_(
+ config_.head_dim,
+ config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,
+ (void *)&q_in_data[batch_id * config_.kv_head_num *
+ n_gqa_ * config_.head_dim +
+ head_id * n_gqa_ * config_.head_dim],
+ seq_len_, 0, false,
+ thread_local_attn_mask_[thread_id].data(),
+ GGML_TYPE_F16, 0,
+ k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr, GGML_TYPE_F16, 1,
+ v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr,
+ thread_local_attn_score_[thread_id].data(),
+ thread_local_output_fp32_[thread_id].data(),
+ thread_local_attn_lse_[thread_id].data(),
+ thread_local_draft_[thread_id].data(), nullptr,
+ cos_.data(), sin_.data());
+ } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
+ attn_with_kvcache_one_block_(
+ config_.head_dim,
+ config_.q_head_num / config_.kv_head_num,
+ GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),
+ seq_len_, 0, false,
+ thread_local_attn_mask_[thread_id].data(),
+ GGML_TYPE_Q4_0, 0,
+ k_cache_q4[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr, GGML_TYPE_Q4_0, 1,
+ v_cache_q4[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr,
+ thread_local_attn_score_[thread_id].data(),
+ thread_local_output_q8_0_[thread_id].data(),
+ thread_local_attn_lse_[thread_id].data(),
+ thread_local_draft_[thread_id].data(), nullptr,
+ cos_.data(), sin_.data());
+ dequantize_row_q8_0(
+ thread_local_output_q8_0_[thread_id].data(),
+ thread_local_output_fp32_[thread_id].data(),
+ n_gqa_ * config_.head_dim);
+ } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
+ attn_with_kvcache_one_block_(
+ config_.head_dim,
+ config_.q_head_num / config_.kv_head_num,
+ GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),
+ seq_len_, 0, false,
+ thread_local_attn_mask_[thread_id].data(),
+ GGML_TYPE_Q8_0, 0,
+ k_cache_q8[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr, GGML_TYPE_Q8_0, 1,
+ v_cache_q8[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr,
+ thread_local_attn_score_[thread_id].data(),
+ thread_local_output_q8_0_[thread_id].data(),
+ thread_local_attn_lse_[thread_id].data(),
+ thread_local_draft_[thread_id].data(), nullptr,
+ cos_.data(), sin_.data());
+ dequantize_row_q8_0(
+ thread_local_output_q8_0_[thread_id].data(),
+ thread_local_output_fp32_[thread_id].data(),
+ n_gqa_ * config_.head_dim);
+ }
+ } else {
+ if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
+ attn_with_kvcache_one_block_(
+ config_.head_dim,
+ config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,
+ (void *)&q_in_data[batch_id * config_.kv_head_num *
+ n_gqa_ * config_.head_dim +
+ head_id * n_gqa_ * config_.head_dim],
+ seq_len_, 0, true, nullptr, GGML_TYPE_F16, 0,
+ k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr, GGML_TYPE_F16, 1,
+ v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr,
+ thread_local_attn_score_[thread_id].data(),
+ thread_local_output_fp32_[thread_id].data(),
+ thread_local_attn_lse_[thread_id].data(),
+ thread_local_draft_[thread_id].data(), nullptr,
+ cos_.data(), sin_.data());
+
+ } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
+ attn_with_kvcache_one_block_(
+ config_.head_dim,
+ config_.q_head_num / config_.kv_head_num,
+ GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),
+ seq_len_, 0, true, nullptr, GGML_TYPE_Q4_0, 0,
+ k_cache_q4[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr, GGML_TYPE_Q4_0, 1,
+ v_cache_q4[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr,
+ thread_local_attn_score_[thread_id].data(),
+ thread_local_output_q8_0_[thread_id].data(),
+ thread_local_attn_lse_[thread_id].data(),
+ thread_local_draft_[thread_id].data(), nullptr,
+ cos_.data(), sin_.data());
+ dequantize_row_q8_0(
+ thread_local_output_q8_0_[thread_id].data(),
+ thread_local_output_fp32_[thread_id].data(),
+ n_gqa_ * config_.head_dim);
+ } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
+ attn_with_kvcache_one_block_(
+ config_.head_dim,
+ config_.q_head_num / config_.kv_head_num,
+ GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),
+ seq_len_, 0, true, nullptr, GGML_TYPE_Q8_0, 0,
+ k_cache_q8[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr, GGML_TYPE_Q8_0, 1,
+ v_cache_q8[layer_id_][head_id][block_idx].data(), 0,
+ nullptr, nullptr,
+ thread_local_attn_score_[thread_id].data(),
+ thread_local_output_q8_0_[thread_id].data(),
+ thread_local_attn_lse_[thread_id].data(),
+ thread_local_draft_[thread_id].data(), nullptr,
+ cos_.data(), sin_.data());
+ dequantize_row_q8_0(
+ thread_local_output_q8_0_[thread_id].data(),
+ thread_local_output_fp32_[thread_id].data(),
+ n_gqa_ * config_.head_dim);
+ }
+ }
+ for (int i = 0; i < n_gqa_; i++) {
+ block_lse_[batch_id][block_idx][head_id * n_gqa_ + i] =
+ thread_local_attn_lse_[thread_id][i];
+ }
+ int cur_batch_idx = thread_cur_head_idx_[thread_id].first;
+ int cur_head_id = thread_cur_head_idx_[thread_id].second;
+ if (batch_id == cur_batch_idx && head_id == cur_head_id) {
+ for (int i = 0; i < n_gqa_; i++) {
+ float new_attn_lse =
+ thread_local_cur_attn_lse_[thread_id][i] +
+ std::log(
+ 1.0 +
+ std::exp(thread_local_attn_lse_[thread_id][i] -
+ thread_local_cur_attn_lse_[thread_id][i]));
+ ggml_vec_scale_f32(
+ config_.head_dim,
+ thread_local_cur_output_fp32_[thread_id].data() +
+ i * config_.head_dim,
+ std::exp(thread_local_cur_attn_lse_[thread_id][i] -
+ new_attn_lse));
+ ggml_vec_scale_f32(
+ config_.head_dim,
+ thread_local_output_fp32_[thread_id].data() +
+ i * config_.head_dim,
+ std::exp(thread_local_attn_lse_[thread_id][i] -
+ new_attn_lse));
+ for (int j = 0; j < config_.head_dim; j++) {
+ thread_local_cur_output_fp32_[thread_id]
+ [i * config_.head_dim +
+ j] +=
+ thread_local_output_fp32_[thread_id]
+ [i * config_.head_dim + j];
+ }
+ thread_local_cur_attn_lse_[thread_id][i] = new_attn_lse;
+ }
+ } else {
+ if (cur_batch_idx != -1) {
+ mutex_[cur_batch_idx][cur_head_id]->lock();
+ for (int i = 0; i < n_gqa_; i++) {
+ if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) <
+ 1e-6) {
+ attn_lse_[cur_batch_idx][cur_head_id][i] =
+ thread_local_cur_attn_lse_[thread_id][i];
+ for (int j = 0; j < config_.head_dim; j++) {
+ output_fp32_[cur_batch_idx][cur_head_id]
+ [i * config_.head_dim + j] =
+ thread_local_cur_output_fp32_
+ [thread_id]
+ [i * config_.head_dim + j];
+ }
+ continue;
+ }
+ float new_attn_lse =
+ attn_lse_[cur_batch_idx][cur_head_id][i] +
+ std::log(
+ 1.0 +
+ std::exp(
+ thread_local_cur_attn_lse_[thread_id][i] -
+ attn_lse_[cur_batch_idx][cur_head_id][i]));
+ ggml_vec_scale_f32(
+ config_.head_dim,
+ output_fp32_[cur_batch_idx][cur_head_id].data() +
+ i * config_.head_dim,
+ std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] -
+ new_attn_lse));
+ ggml_vec_scale_f32(
+ config_.head_dim,
+ thread_local_cur_output_fp32_[thread_id].data() +
+ i * config_.head_dim,
+ std::exp(thread_local_cur_attn_lse_[thread_id][i] -
+ new_attn_lse));
+ for (int j = 0; j < config_.head_dim; j++) {
+ output_fp32_[cur_batch_idx][cur_head_id]
+ [i * config_.head_dim + j] +=
+ thread_local_cur_output_fp32_
+ [thread_id][i * config_.head_dim + j];
+ }
+ attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;
+ }
+ mutex_[cur_batch_idx][cur_head_id]->unlock();
+ }
+ thread_cur_head_idx_[thread_id].first = batch_id;
+ thread_cur_head_idx_[thread_id].second = head_id;
+ for (int i = 0; i < n_gqa_; i++) {
+ thread_local_cur_attn_lse_[thread_id][i] =
+ thread_local_attn_lse_[thread_id][i];
+ for (int j = 0; j < config_.head_dim; j++) {
+ thread_local_cur_output_fp32_
+ [thread_id][i * config_.head_dim + j] =
+ thread_local_output_fp32_[thread_id]
+ [i * config_.head_dim +
+ j];
+ }
+ }
+ }
+ },
+ // Merge the results of the remaining blocks.
+ [&](int thread_id) {
+ int cur_batch_idx = thread_cur_head_idx_[thread_id].first;
+ int cur_head_id = thread_cur_head_idx_[thread_id].second;
+ if (cur_head_id != -1) {
+ mutex_[cur_batch_idx][cur_head_id]->lock();
+ for (int i = 0; i < n_gqa_; i++) {
+ float new_attn_lse;
+ if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) <
+ 1e-6) {
+ attn_lse_[cur_batch_idx][cur_head_id][i] =
+ thread_local_cur_attn_lse_[thread_id][i];
+ for (int j = 0; j < config_.head_dim; j++) {
+ output_fp32_[cur_batch_idx][cur_head_id]
+ [i * config_.head_dim + j] =
+ thread_local_cur_output_fp32_
+ [thread_id]
+ [i * config_.head_dim + j];
+ }
+ continue;
+ }
+ new_attn_lse =
+ attn_lse_[cur_batch_idx][cur_head_id][i] +
+ std::log(
+ 1.0 +
+ std::exp(thread_local_cur_attn_lse_[thread_id][i] -
+ attn_lse_[cur_batch_idx][cur_head_id][i]));
+ ggml_vec_scale_f32(
+ config_.head_dim,
+ output_fp32_[cur_batch_idx][cur_head_id].data() +
+ i * config_.head_dim,
+ std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] -
+ new_attn_lse));
+ ggml_vec_scale_f32(
+ config_.head_dim,
+ thread_local_cur_output_fp32_[thread_id].data() +
+ i * config_.head_dim,
+ std::exp(thread_local_cur_attn_lse_[thread_id][i] -
+ new_attn_lse));
+ for (int j = 0; j < config_.head_dim; j++) {
+ output_fp32_[cur_batch_idx][cur_head_id]
+ [i * config_.head_dim + j] +=
+ thread_local_cur_output_fp32_[thread_id]
+ [i * config_.head_dim +
+ j];
+ }
+ attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;
+ }
+ mutex_[cur_batch_idx][cur_head_id]->unlock();
+ }
+ });
+
+ for (int i = 0; i < batch_size; i++) {
+ for (int j = 0; j < max_block_num_after_retrieval_; j++) {
+ for (int k = 0; k < config_.q_head_num; k++) {
+ int block_idx =
+ block_table_after_retrieval_kvhead_[i][j][k / n_gqa_];
+ attn_sparsity[i * config_.q_head_num + k] +=
+ std::exp(block_lse_[i][block_idx][k] -
+ attn_lse_[i][k / n_gqa_][k % n_gqa_]);
+ }
+ }
+ }
+
+ // Timer end
+ auto end = std::chrono::high_resolution_clock::now();
+ std::chrono::duration<double> diff = end - start;
+ // printf("layer %d time of calculating sparsity: %f s\n", layer_id_,
+ // diff.count());
+}
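+
+// Illustrative sketch of the merge used above: partial attention results are
+// combined with a log-sum-exp update so the softmax normalization stays exact
+// across blocks. merge_partial_attn is a hypothetical helper written only to
+// show the formula:
+//
+//   // o_a/lse_a are updated in place with the contribution of o_b/lse_b.
+//   void merge_partial_attn(float *o_a, float &lse_a,
+//                           const float *o_b, float lse_b, int head_dim) {
+//       float new_lse = lse_a + std::log(1.0f + std::exp(lse_b - lse_a));
+//       float w_a = std::exp(lse_a - new_lse);
+//       float w_b = std::exp(lse_b - new_lse);
+//       for (int d = 0; d < head_dim; d++)
+//           o_a[d] = o_a[d] * w_a + o_b[d] * w_b;
+//       lse_a = new_lse;
+//   }
+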
+void KVCache::calculate_block_similarity_kvhead_(
+ const uint16_t *q_in_data, int batch_size, int layer_idx, int q_len,
+ int max_block_num, int *cache_seqlens, int init_block_num,
+ int local_block_num, int pick_block_num, Backend *backend) {
+ // Timer start
+ auto start = std::chrono::high_resolution_clock::now();
+ backend->do_work_stealing_job(
+ batch_size * max_block_num, nullptr,
+ [&](int task_id) {
+ int batch_id = task_id / max_block_num;
+ int block_id = task_id % max_block_num;
+ int seq_len = cache_seqlens_[batch_id];
+
+ if (block_id < init_block_num ||
+ block_id >= (seq_len / config_.block_len) - local_block_num) {
+ return;
+ }
+ int block_idx =
+ block_table_before_retrieval_kvhead_[batch_id][block_id][0];
+
+ for (int head_id = 0; head_id < config_.q_head_num; head_id++) {
+ for (int i = 0; i < config_.head_dim; i++) {
+ float q_i = 0, qa_i = std::numeric_limits<float>::lowest();
+ for (int q_id = 0; q_id < q_len; q_id++) {
+ q_i += GGML_FP16_TO_FP32(
+ q_in_data[batch_id * q_len * config_.q_head_num *
+ config_.head_dim +
+ q_id * config_.q_head_num *
+ config_.head_dim +
+ head_id * config_.head_dim + i]);
+ }
+ q_i /= q_len;
+ for (int anchor_id = 0; anchor_id < config_.anchor_num;
+ anchor_id++) {
+ qa_i = std::max(
+ qa_i,
+ GGML_FP16_TO_FP32(
+ anchor_[layer_idx * config_.max_block_num *
+ config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ block_idx * config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ anchor_id * config_.q_head_num *
+ config_.head_dim +
+ head_id * config_.head_dim + i]) *
+ q_i);
+ }
+ block_similar_kv_head_[batch_id][block_id]
+ [head_id / n_gqa_] += qa_i;
+ }
+ }
+ },
+ nullptr);
+
+ // Timer end
+ auto end = std::chrono::high_resolution_clock::now();
+ std::chrono::duration<double> diff = end - start;
+ // printf("layer %d time of calculating similarity: %f s\n", layer_idx,
+ // diff.count());
+}
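+
+// Illustrative sketch of the score accumulated above: for each block and each
+// q-head, the query is mean-pooled over q_len, each dimension is matched
+// against the best anchor, and the per-dimension maxima are summed; scores of
+// the q-heads in one GQA group land on their shared kv head. block_score is a
+// hypothetical helper over flat arrays:
+//
+//   float block_score(const float *mean_q,    // [head_dim]
+//                     const float *anchors,   // [anchor_num * head_dim]
+//                     int head_dim, int anchor_num) {
+//       float score = 0.0f;
+//       for (int i = 0; i < head_dim; i++) {
+//           float best = std::numeric_limits<float>::lowest();
+//           for (int a = 0; a < anchor_num; a++)
+//               best = std::max(best, anchors[a * head_dim + i] * mean_q[i]);
+//           score += best;
+//       }
+//       return score;
+//   }
+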
+void KVCache::select_block_kvhead_(int batch_size, int layer_idx,
+ int max_block_num, int init_block_num,
+ int local_block_num, int pick_block_num) {
+ // Timer start
+ auto start = std::chrono::high_resolution_clock::now();
+
+ for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
+ int cache_len_after_retrieval = 0;
+ if (cache_seqlens_[batch_idx] / config_.block_len <=
+ init_block_num + pick_block_num + local_block_num) {
+ selected_blocks_num_history_[(layer_idx - config_.layer_offset) /
+ config_.layer_step] = 0;
+ for (int i = 0; i < max_block_num; i++) {
+ for (int j = 0; j < config_.kv_head_num; j++) {
+ block_table_after_retrieval_kvhead_[batch_idx][i][j] =
+ block_table_before_retrieval_kvhead_[batch_idx][i][j];
+ }
+ }
+ continue;
+ }
+ for (int head_id = 0; head_id < config_.kv_head_num; head_id++) {
+
+ for (int block_id = init_block_num;
+ block_id < (cache_seqlens_[batch_idx] / config_.block_len) -
+ local_block_num;
+ block_id++) {
+
+ top_similar_block_[batch_idx].push(std::make_pair(
+ block_similar_kv_head_[batch_idx][block_id][head_id],
+ block_table_before_retrieval_kvhead_[batch_idx][block_id]
+ [head_id]));
+ if (top_similar_block_[batch_idx].size() > pick_block_num) {
+ top_similar_block_[batch_idx].pop();
+ }
+ }
+
+ int i = 0;
+ for (; i < init_block_num; i++) {
+ block_table_after_retrieval_kvhead_[batch_idx][i][head_id] =
+ block_table_before_retrieval_kvhead_[batch_idx][i][head_id];
+ }
+ while (!top_similar_block_[batch_idx].empty()) {
+ block_table_after_retrieval_kvhead_[batch_idx][i][head_id] =
+ top_similar_block_[batch_idx].top().second;
+ top_similar_block_[batch_idx].pop();
+ i++;
+ }
+ for (; i < init_block_num + pick_block_num + local_block_num; i++) {
+ block_table_after_retrieval_kvhead_[batch_idx][i][head_id] =
+ block_table_before_retrieval_kvhead_
+ [batch_idx]
+ [(cache_seqlens_[batch_idx] / config_.block_len) -
+ local_block_num + i - init_block_num - pick_block_num]
+ [head_id];
+ }
+ if (cache_seqlens_[batch_idx] % config_.block_len != 0) {
+ block_table_after_retrieval_kvhead_[batch_idx][i][head_id] =
+ block_table_before_retrieval_kvhead_[batch_idx][(
+ cache_seqlens_[batch_idx] / config_.block_len)]
+ [head_id];
+ cache_len_after_retrieval =
+ (cache_seqlens_[batch_idx] % config_.block_len) +
+ i * config_.block_len;
+ i++;
+ } else {
+ cache_len_after_retrieval =
+ (cache_seqlens_[batch_idx] % config_.block_len) +
+ i * config_.block_len;
+ }
+ for (int j = 0; j < i; j++) {
+ selected_blocks_history_kvhead_
+ [(layer_idx - config_.layer_offset) / config_.layer_step]
+ [batch_idx][j][head_id] =
+ block_table_after_retrieval_kvhead_[batch_idx][j]
+ [head_id];
+ }
+ }
+ cache_seqlens_[batch_idx] = cache_len_after_retrieval;
+ selected_blocks_num_history_[(layer_idx - config_.layer_offset) /
+ config_.layer_step] =
+ (cache_len_after_retrieval + config_.block_len - 1) /
+ config_.block_len;
+ }
+
+ // Timer end
+ auto end = std::chrono::high_resolution_clock::now();
+ std::chrono::duration<double> diff = end - start;
+ // printf("layer %d time of selecting block: %f s\n", layer_idx,
+ // diff.count());
+}
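+
+// Illustrative sketch of the selection above: the pick_block_num most similar
+// middle blocks are kept with a size-bounded min-heap, so the lowest score is
+// evicted whenever the heap overflows. candidates is a hypothetical list of
+// (similarity, block_idx) pairs:
+//
+//   std::priority_queue<std::pair<float, int>,
+//                       std::vector<std::pair<float, int>>,
+//                       std::greater<std::pair<float, int>>> top;
+//   for (const auto &cand : candidates) {
+//       top.push(cand);
+//       if ((int)top.size() > pick_block_num)
+//           top.pop(); // drops the current minimum
+//   }
+//   // top now holds the pick_block_num highest-similarity blocks.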
+
+void KVCache::get_attn_sparsity(const ggml_fp16_t *q_in, float *attn_sparsity,
+ int layer_idx, int generate_token_idx,
+ int q_len, int batch_size, int max_block_num,
+ int *block_table, int *cache_seqlens,
+ int *block_table_origin,
+ int *cache_seqlens_origin,
+ int max_block_num_origin, int topk, int local,
+ Backend *backend) {
+ // Timer start
+ auto start = std::chrono::high_resolution_clock::now();
+ layer_id_ = layer_idx;
+ int thread_num = backend->get_thread_num();
+ batch_size = 1;
+
+ const uint16_t *q_in_data = const_cast<const uint16_t *>(q_in);
+
+ quantize_q_(q_in_data, batch_size);
+ if (config_.retrieval_type == RetrievalType::LAYER) {
+ attn_initialize_layer_(batch_size, layer_idx, block_table,
+ max_block_num, cache_seqlens);
+ retrieval_kvcache_layer_(q_in_data, 1, local, topk, q_len,
+ generate_token_idx, batch_size, layer_idx,
+ cache_seqlens, max_block_num, backend);
+ calculate_sparsity_layer_(q_in_data, attn_sparsity, batch_size,
+ max_block_num_origin, block_table_origin,
+ cache_seqlens_origin, backend);
+ } else if (config_.retrieval_type == RetrievalType::KVHEAD) {
+ attn_initialize_kvhead_(batch_size, layer_idx, block_table,
+ max_block_num, cache_seqlens);
+ retrieval_kvcache_kvhead_(q_in_data, 1, local, topk, q_len,
+ generate_token_idx, batch_size, layer_idx,
+ cache_seqlens, max_block_num, backend);
+ calculate_sparsity_kvhead_(q_in_data, attn_sparsity, batch_size,
+ max_block_num_origin, block_table_origin,
+ cache_seqlens_origin, backend);
+ }
+}
+
+void KVCache::attn_with_kvcache_one_block_(
+ int head_dim, int bsz,
+ ggml_type q_type, // GGML data type of `Q`, only supports fp16 and q8_0
+ // [bsz, head_dim]
+ // Quantization is always on the head_dim dimension (per_token). If
+ // head_dim % 32 != 0, an error will be raised. The size must be bsz *
+ // head_dim/32 * qtype_size.
+ const void *q,
+
+ int past_kv_len, int past_kv_offset,
+ bool is_full_attn, // true indicates an all-ones mask
+ // If is_full_attn = false, a bit matrix representing the mask is
+ // passed. [bsz, past_kv_len]
+ const uint8_t *attn_mask,
+
+ ggml_type k_type, // GGML data type of `K Cache`, only supports fp16,
+ // q4_0, q8_0
+ int k_quant_type, // 0 for per_token, 1 for per_channel, others raise an
+ // error
+ // [seq_len, head_dim]
+ // If quant_type == 0, head_dim % 32 must be 0.
+ // If quant_type == 1, seq_len % 32 must be 0.
+ const void *k_cache,
+
+ // k_anchor_type must be fp16
+ int num_k_anchor, // num_k_anchor == 0 indicates no anchor
+ // [num_k_anchor, head_dim]
+ const void *k_cache_anchors,
+ // Each token is associated with the nearest previous position's anchor,
+ // with the same distance.
+ const int *k_cache_anchor_pos,
+
+ // v_cache similar to k_cache
+ ggml_type v_type, int v_quant_type,
+ // [head_dim, seq_len]
+ const void *v_cache, int num_v_anchor, const void *v_cache_anchors,
+ const int *v_cache_anchor_pos,
+
+ // Pre-allocated buffer for intermediate calculations [bsz,
+ // past_kv_len]. No malloc is performed inside this function.
+ float *attn_score,
+
+ // Output: [bsz, head_dim], with the same type as q_type
+ void *output,
+ // [bsz]
+ float *lse,
+
+ // Pre-allocated temporary buffer with sufficient size:
+ // (2 * bsz * past_kv_len + 6 * bsz * head_dim + 2 * past_kv_len *
+ // head_dim + past_kv_len * head_dim / 32) bytes.
+ void *draft,
+
+ // Apply rotary embedding online
+ const int *rotary_angle, const void *rotary_cos, const void *rotary_sin
+ // rotary_cos=None,
+ // rotary_sin=None,
+ // cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
+ // cache_batch_idx: Optional[torch.Tensor] = None,
+ // rotary_interleaved=True,
+
+ // // Not supported for now
+ // window_size=(-1, -1), # -1 means infinite context window
+ // alibi_slopes=None,
+) {
+ assert(head_dim % 32 == 0);
+ assert(k_quant_type == 0);
+ assert(v_quant_type == 1);
+ assert(q_type == GGML_TYPE_F16 || q_type == GGML_TYPE_Q8_0);
+ if (q_type == GGML_TYPE_F16) {
+ assert(k_type == GGML_TYPE_F16);
+ assert(v_type == GGML_TYPE_F16);
+
+ // attn = q * k + q * k_anchor
+ // TODO: anchor
+ assert(num_k_anchor == 0);
+
+ if (rotary_angle != nullptr) {
+ ggml_fp16_t *k_cache_with_rope_fp16 =
+ (reinterpret_cast<ggml_fp16_t *>(draft) +
+ sizeof(block_q8_0) * bsz * past_kv_len / QK8_0 +
+ sizeof(float) * bsz * head_dim);
+ // dequant k_cache and apply rope
+ // k_rope(i) = k(i) * cos(i) - k(i+l) * sin(i)
+ // k_rope(i+l) = k(i+l) * cos(i+l) + k(i) * sin(i)
+
+ // k(i)cos(i) -> k_rope(i)
+ // k(i)sin(i+l) -> k_rope(i+l)
+
+ // k(i)cos(i) -> k_rope(i)
+ // -k(i)sin(i-l) -> k_rope(i-l)
+
+ std::vector<float> block_fp32(32);
+ for (int k = 0; k < past_kv_len; k++) {
+ int angle = rotary_angle[k];
+ for (int l = 0; l < head_dim / 32; l++) {
+ for (int m = 0; m < 32; m++) {
+ float x = GGML_FP16_TO_FP32((
+ (ggml_fp16_t *)k_cache)[k * head_dim + l * 32 + m]);
+ float sin_val = GGML_FP16_TO_FP32(
+ ((ggml_fp16_t *)
+ rotary_sin)[angle * head_dim + l * 32 + m]);
+ float cos_val = GGML_FP16_TO_FP32(
+ ((ggml_fp16_t *)
+ rotary_cos)[angle * head_dim + l * 32 + m]);
+
+ if (l * 32 + m < head_dim / 2) {
+ k_cache_with_rope_fp16[k * head_dim + l * 32 + m] =
+ GGML_FP32_TO_FP16(x * cos_val);
+ k_cache_with_rope_fp16[k * head_dim + l * 32 + m +
+ head_dim / 2] =
+ GGML_FP32_TO_FP16(-x * sin_val);
+ } else {
+ k_cache_with_rope_fp16[k * head_dim + l * 32 + m] =
+ GGML_FP32_TO_FP16(
+ GGML_FP16_TO_FP32(
+ k_cache_with_rope_fp16[k * head_dim +
+ l * 32 + m]) +
+ x * sin_val);
+ k_cache_with_rope_fp16[k * head_dim + l * 32 + m -
+ head_dim / 2] =
+ GGML_FP32_TO_FP16(
+ GGML_FP16_TO_FP32(
+ k_cache_with_rope_fp16[k * head_dim +
+ l * 32 + m -
+ head_dim / 2]) -
+ x * cos_val);
+ }
+ }
+ }
+ }
+
+ llamafile_sgemm(past_kv_len, bsz, head_dim,
+ (ggml_fp16_t *)k_cache_with_rope_fp16, head_dim,
+ (ggml_fp16_t *)q, head_dim, attn_score, past_kv_len,
+ 0, 1, GGML_TASK_TYPE_COMPUTE, k_type, GGML_TYPE_F16,
+ GGML_TYPE_F32, GGML_PREC_DEFAULT);
+ } else {
+ bool ok = llamafile_sgemm(
+ past_kv_len, bsz, head_dim, (ggml_fp16_t *)k_cache, head_dim,
+ (ggml_fp16_t *)q, head_dim, attn_score, past_kv_len, 0, 1,
+ GGML_TASK_TYPE_COMPUTE, k_type, GGML_TYPE_F16, GGML_TYPE_F32,
+ GGML_PREC_DEFAULT);
+
+ if (!ok) {
+ printf("llamafile_sgemm failed\n");
+ }
+ }
+ // attn = attn * scale
+ float scale_factor = 1.0 / std::sqrt(float(head_dim));
+ ggml_vec_scale_f32(bsz * past_kv_len, attn_score, scale_factor);
+
+ // attn = attn & mask
+ if (!is_full_attn) {
+ for (int i = 0; i < bsz; i++) {
+ for (int j = 0; j < past_kv_len; j++) {
+ int index = i * past_kv_len + j;
+ if (!(attn_mask[j / 8] & (1 << (j % 8)))) {
+ attn_score[index] =
+ std::numeric_limits<float>::lowest();
+ }
+ }
+ }
+ }
+
+ // attn = softmax(attn)
+ for (int i = 0; i < bsz; i++) {
+ float sum_exp = 0;
+ for (int j = 0; j < past_kv_len; j++) {
+ attn_score[i * past_kv_len + j] =
+ std::exp(attn_score[i * past_kv_len + j]);
+ sum_exp += attn_score[i * past_kv_len + j];
+ }
+ for (int j = 0; j < past_kv_len; j++) {
+ attn_score[i * past_kv_len + j] /= sum_exp;
+ }
+ if (lse != nullptr) {
+ lse[i] = std::log(sum_exp);
+ }
+ }
+
+ // output = attn * v + attn * v_anchor
+ // std::vector<float> sum(bsz * head_dim);
+ float *sum = reinterpret_cast<float *>(reinterpret_cast<uint8_t *>(draft) +
+ sizeof(block_q8_0) * bsz *
+ past_kv_len / QK8_0);
+
+ // float* attn_score_fp16(bsz, past_kv_len)
+ ggml_fp16_t *attn_score_fp16 = (reinterpret_cast<ggml_fp16_t *>(
+ reinterpret_cast<uint8_t *>(draft) +
+ sizeof(block_q8_0) * bsz * past_kv_len / QK8_0 +
+ sizeof(float) * bsz * head_dim));
+
+ for (int i = 0; i < bsz * past_kv_len; i++) {
+ attn_score_fp16[i] = GGML_FP32_TO_FP16(attn_score[i]);
+ }
+
+ // TODO: anchor
+ assert(num_v_anchor == 0);
+ bool ok = llamafile_sgemm(
+ head_dim, bsz, past_kv_len, (ggml_fp16_t *)v_cache, past_kv_len,
+ (ggml_fp16_t *)attn_score_fp16, past_kv_len, sum, head_dim, 0, 1,
+ GGML_TASK_TYPE_COMPUTE, v_type, GGML_TYPE_F16, GGML_TYPE_F32,
+ GGML_PREC_DEFAULT);
+ if (!ok) {
+ printf("llamafile_sgemm failed\n");
+ }
+
+ // copy to output
+ for (int i = 0; i < bsz; i++) {
+ for (int j = 0; j < head_dim; j++) {
+ ((float *)output)[i * head_dim + j] = sum[i * head_dim + j];
+ }
+ }
+ } else {
+ assert(k_type == GGML_TYPE_Q4_0 || k_type == GGML_TYPE_Q8_0);
+ assert(v_type == GGML_TYPE_Q4_0 || v_type == GGML_TYPE_Q8_0);
+
+ // attn = q * k + q * k_anchor
+ // TODO: anchor
+ assert(num_k_anchor == 0);
+
+ if (rotary_angle != nullptr) {
+ ggml_fp16_t *k_cache_with_rope_fp16 =
+ (reinterpret_cast<ggml_fp16_t *>(draft) +
+ sizeof(block_q8_0) * bsz * past_kv_len / QK8_0 +
+ sizeof(float) * bsz * head_dim);
+ block_q4_0 *k_cache_with_rope_q4 =
+ (reinterpret_cast<block_q4_0 *>(draft) +
+ sizeof(block_q8_0) * bsz * past_kv_len / QK8_0 +
+ sizeof(float) * bsz * head_dim) +
+ sizeof(ggml_fp16_t) * bsz * head_dim;
+ // dequant k_cache and apply rope
+ // k_rope(i) = k(i) * cos(i) - k(i+l) * sin(i)
+ // k_rope(i+l) = k(i+l) * cos(i+l) + k(i) * sin(i)
+
+ // k(i)cos(i) -> k_rope(i)
+ // k(i)sin(i+l) -> k_rope(i+l)
+
+ // k(i)cos(i) -> k_rope(i)
+ // -k(i)sin(i-l) -> k_rope(i-l)
+
+ std::vector<float> block_fp32(32);
+ for (int k = 0; k < past_kv_len; k++) {
+ int angle = rotary_angle[k];
+ for (int l = 0; l < head_dim / 32; l++) {
+ block_q4_0 block =
+ ((block_q4_0 *)k_cache)[k * head_dim / 32 + l];
+ dequantize_row_q4_0(&block, block_fp32.data(), 32);
+ for (int m = 0; m < 32; m++) {
+ float sin_val = GGML_FP16_TO_FP32(
+ ((ggml_fp16_t *)
+ rotary_sin)[angle * head_dim + l * 32 + m]);
+ float cos_val = GGML_FP16_TO_FP32(
+ ((ggml_fp16_t *)
+ rotary_cos)[angle * head_dim + l * 32 + m]);
+
+ if (l * 32 + m < head_dim / 2) {
+ k_cache_with_rope_fp16[k * head_dim + l * 32 + m] =
+ GGML_FP32_TO_FP16(block_fp32[m] * cos_val);
+ k_cache_with_rope_fp16[k * head_dim + l * 32 + m +
+ head_dim / 2] =
+ GGML_FP32_TO_FP16(-block_fp32[m] * sin_val);
+ } else {
+ k_cache_with_rope_fp16[k * head_dim + l * 32 + m] +=
+ GGML_FP32_TO_FP16(block_fp32[m] * sin_val);
+ k_cache_with_rope_fp16[k * head_dim + l * 32 + m -
+ head_dim / 2] -=
+ GGML_FP32_TO_FP16(block_fp32[m] * cos_val);
+ }
+ }
+ }
+ }
+ // quantize k_cache_with_rope_fp16
+ for (int k = 0; k < past_kv_len; k++) {
+ for (int l = 0; l < head_dim / 32; l++) {
+ for (int m = 0; m < 32; m++) {
+ block_fp32[m] = GGML_FP16_TO_FP32(
+ k_cache_with_rope_fp16[k * head_dim + l * 32 + m]);
+ }
+ quantize_row_q4_0(
+ block_fp32.data(),
+ &k_cache_with_rope_q4[k * head_dim / 32 + l], 32);
+ }
+ }
+
+ llamafile_sgemm(past_kv_len, bsz, head_dim / 32,
+ (block_q4_0 *)k_cache_with_rope_q4, head_dim / 32,
+ (block_q8_0 *)q, head_dim / 32, attn_score,
+ past_kv_len, 0, 1, GGML_TASK_TYPE_COMPUTE, k_type,
+ GGML_TYPE_Q8_0, GGML_TYPE_F32, GGML_PREC_DEFAULT);
+ } else {
+ llamafile_sgemm(past_kv_len, bsz, head_dim / 32,
+ (block_q4_0 *)k_cache, head_dim / 32,
+ (block_q8_0 *)q, head_dim / 32, attn_score,
+ past_kv_len, 0, 1, GGML_TASK_TYPE_COMPUTE, k_type,
+ GGML_TYPE_Q8_0, GGML_TYPE_F32, GGML_PREC_DEFAULT);
+ }
+
+ // attn = attn * scale
+ float scale_factor = 1.0 / std::sqrt(float(head_dim));
+ ggml_vec_scale_f32(bsz * past_kv_len, attn_score, scale_factor);
+
+ // attn = attn & mask
+ if (!is_full_attn) {
+ for (int i = 0; i < bsz; i++) {
+ for (int j = 0; j < past_kv_len; j++) {
+ int index = i * past_kv_len + j;
+ if (!(attn_mask[j / 8] & (1 << (j % 8)))) {
+ attn_score[index] =
+ std::numeric_limits<float>::lowest();
+ }
+ }
+ }
+ }
+
+ // attn = softmax(attn)
+ for (int i = 0; i < bsz; i++) {
+ float sum_exp = 0;
+ for (int j = 0; j < past_kv_len; j++) {
+ attn_score[i * past_kv_len + j] =
+ std::exp(attn_score[i * past_kv_len + j]);
+ sum_exp += attn_score[i * past_kv_len + j];
+ }
+ for (int j = 0; j < past_kv_len; j++) {
+ attn_score[i * past_kv_len + j] /= sum_exp;
+ }
+ if (lse != nullptr) {
+ lse[i] = std::log(sum_exp);
+ }
+ }
+
+ // output = attn * v + attn * v_anchor
+ // std::vector<block_q8_0> attn_q8_0(bsz * past_kv_len / QK8_0);
+ block_q8_0 *attn_q8_0 = reinterpret_cast<block_q8_0 *>(draft);
+ quantize_row_q8_0(attn_score, attn_q8_0, bsz * past_kv_len);
+ // std::vector<float> sum(bsz * head_dim);
+ float *sum = reinterpret_cast<float *>(reinterpret_cast<uint8_t *>(draft) +
+ sizeof(block_q8_0) * bsz *
+ past_kv_len / QK8_0);
+ // TODO: anchor
+ assert(num_v_anchor == 0);
+ llamafile_sgemm(head_dim, bsz, past_kv_len / 32, (block_q4_0 *)v_cache,
+ past_kv_len / 32, attn_q8_0, past_kv_len / 32, sum,
+ head_dim, 0, 1, GGML_TASK_TYPE_COMPUTE, v_type,
+ GGML_TYPE_Q8_0, GGML_TYPE_F32, GGML_PREC_DEFAULT);
+
+ quantize_row_q8_0(sum, (block_q8_0 *)output, bsz * head_dim);
+ }
+}
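+
+// Illustrative sketch of the rotation described in the RoPE comments above
+// (non-interleaved form, pairing dimension i with i + head_dim/2). rope_rotate
+// is a hypothetical helper for a single token, with cos_t/sin_t indexed per
+// dimension like the rotary_cos/rotary_sin tables:
+//
+//   void rope_rotate(const float *k, const float *cos_t, const float *sin_t,
+//                    float *k_rope, int d) {
+//       for (int i = 0; i < d / 2; i++) {
+//           k_rope[i] = k[i] * cos_t[i] - k[i + d / 2] * sin_t[i];
+//           k_rope[i + d / 2] =
+//               k[i + d / 2] * cos_t[i + d / 2] + k[i] * sin_t[i];
+//       }
+//   }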
diff --git a/ktransformers/ktransformers_ext/operators/kvcache/kvcache_load_dump.cpp b/ktransformers/ktransformers_ext/operators/kvcache/kvcache_load_dump.cpp
new file mode 100644
index 0000000..eadf90f
--- /dev/null
+++ b/ktransformers/ktransformers_ext/operators/kvcache/kvcache_load_dump.cpp
@@ -0,0 +1,123 @@
+/**
+ * @Description :
+ * @Author : Jianwei Dong
+ * @Date : 2024-08-26 22:47:06
+ * @Version : 1.0.0
+ * @LastEditors : Jianwei Dong
+ * @LastEditTime : 2024-08-26 22:47:06
+ * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
+ **/
+
+#include "kvcache.h"
+void KVCache::load_kvcache(std::string tensor_file_path, Backend *backend) {
+ // Timer start
+ auto start = std::chrono::high_resolution_clock::now();
+ std::ifstream ifs_tensor(tensor_file_path, std::ios::binary);
+ if (!ifs_tensor) {
+ throw std::runtime_error("Failed to open tensor file");
+ }
+ ifs_tensor.read(reinterpret_cast<char *>(&cache_total_len_),
+ sizeof(cache_total_len_));
+ int past_block_num =
+ (cache_total_len_ + config_.block_len - 1) / config_.block_len;
+ printf("cache_total_len: %d, past_block_num: %d\n", cache_total_len_,
+ past_block_num);
+ for (int i = 0; i < config_.layer_num; ++i) {
+ past_block_num_[i] = past_block_num;
+ }
+ ifs_tensor.read(reinterpret_cast<char *>(anchor_.data()),
+ anchor_.size() * sizeof(ggml_fp16_t));
+ for (int i = 0; i < config_.layer_num; ++i) {
+ for (int j = 0; j < config_.kv_head_num; ++j) {
+ for (int k = 0; k < past_block_num_[i]; ++k) {
+ if (config_.kv_type == GGML_TYPE_F16) {
+ ifs_tensor.read(
+ reinterpret_cast<char *>(k_cache_fp16_[i][j][k].data()),
+ k_cache_fp16_[i][j][k].size() * sizeof(ggml_fp16_t));
+ ifs_tensor.read(
+ reinterpret_cast<char *>(v_cache_fp16_[i][j][k].data()),
+ v_cache_fp16_[i][j][k].size() * sizeof(ggml_fp16_t));
+ } else if (config_.kv_type == GGML_TYPE_Q4_0) {
+ ifs_tensor.read(
+ reinterpret_cast<char *>(k_cache_q4[i][j][k].data()),
+ k_cache_q4[i][j][k].size() * sizeof(block_q4_0));
+ ifs_tensor.read(
+ reinterpret_cast<char *>(v_cache_q4[i][j][k].data()),
+ v_cache_q4[i][j][k].size() * sizeof(block_q4_0));
+ }
+ }
+ }
+ for (int k = 0; k < past_block_num_[i]; ++k) {
+ for (int l = 0; l < config_.block_len; l++) {
+ ifs_tensor.read(
+ reinterpret_cast<char *>(importance_[i][k][l].data()),
+ importance_[i][k][l].size() * sizeof(ggml_fp16_t));
+ }
+ }
+ }
+ ifs_tensor.close();
+ // Timer end
+ auto end = std::chrono::high_resolution_clock::now();
+ std::chrono::duration<double> diff = end - start;
+ printf("time of load: %f s\n", diff.count());
+}
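+
+// Illustrative sketch of the on-disk layout consumed above (and produced by
+// dump_kvcache below): a single int header followed by the anchor table, then
+// the per-layer payload.
+//
+//   [int   cache_total_len]
+//   [fp16  anchor_[0 .. anchor_.size())]
+//   for each layer:
+//       for each kv head:  for each block: [K block][V block]
+//       for each block:    for each position: [fp16 importance row]
+//
+//   // past_block_num follows from the header by ceiling division
+//   // (block_len stands for config_.block_len):
+//   int past_block_num = (cache_total_len + block_len - 1) / block_len;
+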
+void KVCache::dump_kvcache(int *block_table, int cache_total_len,
+ std::string tensor_file_path, Backend *backend) {
+ // Timer start
+ auto start = std::chrono::high_resolution_clock::now();
+ std::ofstream ofs(tensor_file_path, std::ios::binary);
+ printf("dump_kvcache: %s\n", tensor_file_path.c_str());
+ if (!ofs.is_open()) {
+ std::cerr << "Cannot open file " << tensor_file_path << std::endl;
+ return;
+ }
+ ofs.write(reinterpret_cast<const char *>(&cache_total_len),
+ sizeof(cache_total_len));
+ int past_block_num =
+ (cache_total_len + config_.block_len - 1) / config_.block_len;
+ printf("cache_total_len: %d, past_block_num: %d\n", cache_total_len,
+ past_block_num);
+ ofs.write(reinterpret_cast<const char *>(anchor_.data()),
+ anchor_.size() * sizeof(ggml_fp16_t));
+ for (int i = 0; i < config_.layer_num; ++i) {
+ for (int j = 0; j < config_.kv_head_num; ++j) {
+ for (int k = 0; k < past_block_num; ++k) {
+ int block_idx = block_table[k];
+ if (config_.kv_type == GGML_TYPE_F16) {
+ ofs.write(reinterpret_cast<const char *>(
+ k_cache_fp16_[i][j][block_idx].data()),
+ k_cache_fp16_[i][j][block_idx].size() *
+ sizeof(ggml_fp16_t));
+ ofs.write(reinterpret_cast<const char *>(
+ v_cache_fp16_[i][j][block_idx].data()),
+ v_cache_fp16_[i][j][block_idx].size() *
+ sizeof(ggml_fp16_t));
+
+ } else if (config_.kv_type == GGML_TYPE_Q4_0) {
+ ofs.write(reinterpret_cast<const char *>(
+ k_cache_q4[i][j][block_idx].data()),
+ k_cache_q4[i][j][block_idx].size() *
+ sizeof(block_q4_0));
+ ofs.write(reinterpret_cast<const char *>(
+ v_cache_q4[i][j][block_idx].data()),
+ v_cache_q4[i][j][block_idx].size() *
+ sizeof(block_q4_0));
+ }
+ }
+ }
+ for (int k = 0; k < past_block_num; ++k) {
+ int block_idx = block_table[k];
+ for (int l = 0; l < config_.block_len; l++) {
+ ofs.write(reinterpret_cast<const char *>(
+ importance_[i][block_idx][l].data()),
+ importance_[i][block_idx][l].size() *
+ sizeof(ggml_fp16_t));
+ }
+ }
+ }
+ ofs.close();
+ // Timer end
+ auto end = std::chrono::high_resolution_clock::now();
+ std::chrono::duration<double> diff = end - start;
+ printf("time of dump: %f s\n", diff.count());
+}
\ No newline at end of file
diff --git a/ktransformers/ktransformers_ext/operators/kvcache/kvcache_read_write.cpp b/ktransformers/ktransformers_ext/operators/kvcache/kvcache_read_write.cpp
new file mode 100644
index 0000000..998f1b0
--- /dev/null
+++ b/ktransformers/ktransformers_ext/operators/kvcache/kvcache_read_write.cpp
@@ -0,0 +1,1019 @@
+/**
+ * @Description :
+ * @Author : Jianwei Dong
+ * @Date : 2024-08-26 22:47:06
+ * @Version : 1.0.0
+ * @LastEditors : Jianwei Dong
+ * @LastEditTime : 2024-08-26 22:47:06
+ * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
+ **/
+
+#include "kvcache.h"
+
+void KVCache::get_anchor_one_block(ggml_fp16_t *anchor, int layer_id,
+ int block_idx, Backend *backend) {
+ // Timer start
+ auto start = std::chrono::high_resolution_clock::now();
+
+ layer_id_ = layer_id;
+ block_idx = block_idx;
+ seq_len_ = config_.block_len;
+ anchor_data_ = const_cast<ggml_fp16_t *>(anchor);
+
+ // Timer end
+ auto end = std::chrono::high_resolution_clock::now();
+ std::chrono::duration<double> duration = end - start;
+ printf("layer %d block %d time of reading anchor: %f s\n", layer_id,
+ block_idx, duration.count());
+}
+
+void KVCache::update_anchor_one_block(const ggml_fp16_t *anchor, int layer_id,
+ int block_idx, Backend *backend) {
+ // Timer start
+ auto start = std::chrono::high_resolution_clock::now();
+
+ layer_id_ = layer_id;
+ block_idx = block_idx;
+ seq_len_ = config_.block_len;
+ anchor_data_ = const_cast<ggml_fp16_t *>(anchor);
+
+ // Each task updates the anchor of a certain position
+ // backend->do_work_stealing_job(config_.anchor_num, [&](int task_id) {
+ // int k = task_id % config_.anchor_num;
+ // int head_id = task_id / config_.anchor_num;
+ // memcpy(anchor_[layer_id_][head_id][block_idx].data() +
+ // k * config_.head_dim,
+ // anchor_data_ + k * config_.head_dim,
+ // sizeof(uint16_t) * config_.head_dim);
+ // });
+
+ // Timer end
+ auto end = std::chrono::high_resolution_clock::now();
+ std::chrono::duration<double> duration = end - start;
+ printf("layer %d block %d time of writting anchor: %f s\n", layer_id,
+ block_idx, duration.count());
+}
+
+void KVCache::update_importance_one_block(const ggml_fp16_t *importance,
+ int layer_id, int block_idx,
+ Backend *backend) {
+ // Timer start
+ auto start = std::chrono::high_resolution_clock::now();
+
+ layer_id_ = layer_id;
+ block_idx = block_idx;
+ seq_len_ = config_.block_len;
+ importance_data_ = const_cast<ggml_fp16_t *>(importance);
+
+ // Each task updates the importance of a certain position
+ backend->do_work_stealing_job(
+ config_.block_len, nullptr,
+ [&](int task_id) {
+ int k = task_id;
+ memcpy(importance_[layer_id_][block_idx].data() + k,
+ importance_data_ + k, sizeof(uint16_t));
+ },
+ nullptr);
+
+ // Timer end
+ auto end = std::chrono::high_resolution_clock::now();
+ std::chrono::duration<double> duration = end - start;
+ printf("layer %d block %d time of writting importance: %f s\n", layer_id,
+ block_idx, duration.count());
+}
+
+void KVCache::get_importance_one_block(ggml_fp16_t *importance, int layer_id,
+ int block_idx, Backend *backend) {
+ // Timer start
+ auto start = std::chrono::high_resolution_clock::now();
+
+ layer_id_ = layer_id;
+ block_idx = block_idx;
+ seq_len_ = config_.block_len;
+ importance_data_ = const_cast<ggml_fp16_t *>(importance);
+
+ // Each task updates the importance of a certain position
+ backend->do_work_stealing_job(
+ config_.block_len, nullptr,
+ [&](int task_id) {
+ int k = task_id;
+ memcpy(importance_data_ + k,
+ importance_[layer_id_][block_idx].data() + k,
+ sizeof(uint16_t));
+ },
+ nullptr);
+
+ // Timer end
+ auto end = std::chrono::high_resolution_clock::now();
+ std::chrono::duration<double> duration = end - start;
+ printf("layer %d block %d time of reading importance: %f s\n", layer_id,
+ block_idx, duration.count());
+}
+
+void KVCache::update_kvcache_one_block_fp16(const ggml_fp16_t *k_in,
+ const ggml_fp16_t *v_in,
+ int layer_id, int block_idx,
+ Backend *backend) {
+ // Timer start
+ auto start = std::chrono::high_resolution_clock::now();
+
+ layer_id_ = layer_id;
+ block_idx = block_idx;
+ seq_len_ = config_.block_len;
+ k_data_ = const_cast<ggml_fp16_t *>(k_in);
+ v_data_ = const_cast<ggml_fp16_t *>(v_in);
+
+ int new_block_num = std::max((int)past_block_num_[layer_id], block_idx + 1);
+
+ importance_[layer_id_].resize(new_block_num);
+
+ for (int i = 0; i < config_.kv_head_num; i++) {
+ k_cache_q4[layer_id][i].resize(new_block_num);
+ v_cache_q4[layer_id][i].resize(new_block_num);
+ // anchor_[layer_id][i].resize(new_block_num);
+ }
+
+ for (int i = 0; i < new_block_num; i++) {
+ importance_[layer_id][i].resize(config_.block_len);
+ }
+
+ // Each task updates the k cache or v cache of a certain head
+ backend->do_work_stealing_job(
+ config_.kv_head_num * 2, nullptr,
+ [&](int task_id) {
+ std::vector<float> block_fp32(32);
+ int head_id = task_id / 2;
+ if (task_id & 1) {
+ // fill k_cache_
+ k_cache_q4[layer_id_][head_id][block_idx].resize(
+ config_.block_len * config_.head_dim / 32);
+ for (int k = 0; k < config_.block_len; k++) {
+ for (int l = 0; l < config_.head_dim / 32; l++) {
+ block_q4_0 block;
+ for (int m = 0; m < 32; m++) {
+
+ block_fp32[m] = GGML_FP16_TO_FP32(
+ k_data_[((0 * config_.kv_head_num + head_id) *
+ seq_len_ +
+ 0 * config_.block_len + k) *
+ config_.head_dim +
+ l * 32 + m]);
+ }
+ quantize_row_q4_0(block_fp32.data(), &block, 32);
+ k_cache_q4[layer_id_][head_id][block_idx]
+ [k * config_.head_dim / 32 + l] = block;
+ }
+ }
+ } else {
+ // fill v_cache_
+ v_cache_q4[layer_id_][head_id][block_idx].resize(
+ config_.head_dim * config_.block_len / 32);
+ for (int k = 0; k < config_.block_len / 32; k++) {
+ for (int l = 0; l < config_.head_dim; l++) {
+ block_q4_0 block;
+ for (int m = 0; m < 32; m++) {
+
+ block_fp32[m] = GGML_FP16_TO_FP32(
+ v_data_[((0 * config_.kv_head_num + head_id) *
+ seq_len_ +
+ 0 * config_.block_len + k * 32 + m) *
+ config_.head_dim +
+ l]);
+ }
+ quantize_row_q4_0(block_fp32.data(), &block, 32);
+ v_cache_q4[layer_id_][head_id][block_idx]
+ [l * config_.block_len / 32 + k] = block;
+ }
+ }
+ }
+ },
+ nullptr);
+ past_block_num_[layer_id] = new_block_num;
+
+ // Timer end
+ auto end = std::chrono::high_resolution_clock::now();
+ std::chrono::duration<double> duration = end - start;
+ printf("layer %d block %d time of writting KV Cache: %f s\n", layer_id,
+ block_idx, duration.count());
+ // printf("get_one_block_fp16 duration: %ld\n", duration);
+}
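+
+// Illustrative sketch of the in-block layout filled above: K is stored
+// token-major and quantized in groups of 32 along head_dim, while V is stored
+// transposed (dimension-major) and quantized in groups of 32 along the token
+// axis, which is what the two index expressions encode. k_group/v_group are
+// hypothetical helpers mapping (token k, group l) to the q4_0 group index:
+//
+//   size_t k_group(int k, int l, int head_dim) {
+//       return (size_t)k * (head_dim / 32) + l; // K[token k][dims 32l..32l+31]
+//   }
+//   size_t v_group(int k, int l, int block_len) {
+//       return (size_t)l * (block_len / 32) + k; // V[dim l][tokens 32k..32k+31]
+//   }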
+
+void KVCache::get_kvcache_one_block_fp16(ggml_fp16_t *k_in, ggml_fp16_t *v_in,
+ int layer_id, int block_idx,
+ Backend *backend) {
+ // Timer start
+ auto start = std::chrono::high_resolution_clock::now();
+
+ layer_id_ = layer_id;
+ seq_len_ = config_.block_len;
+ k_data_ = reinterpret_cast<ggml_fp16_t *>(k_in);
+ v_data_ = reinterpret_cast<ggml_fp16_t *>(v_in);
+
+ // printf("layer_id: %d, block_idx: %d\n", layer_id, block_idx);
+ // Each task gets the k cache or v cache of a certain head
+ backend->do_work_stealing_job(
+ config_.kv_head_num * 2, nullptr,
+ [&](int task_id) {
+ std::vector<float> block_fp32(32);
+ int head_id = task_id / 2;
+ if (task_id & 1) {
+ // get k_cache_
+ for (int k = 0; k < config_.block_len; k++) {
+ for (int l = 0; l < config_.head_dim / 32; l++) {
+ block_q4_0 block =
+ k_cache_q4[layer_id_][head_id][block_idx]
+ [k * config_.head_dim / 32 + l];
+ dequantize_row_q4_0(&block, block_fp32.data(), 32);
+ for (int m = 0; m < 32; m++) {
+
+ k_data_[((0 * config_.kv_head_num + head_id) *
+ seq_len_ +
+ 0 * config_.block_len + k) *
+ config_.head_dim +
+ l * 32 + m] =
+ GGML_FP32_TO_FP16(block_fp32[m]);
+ }
+ }
+ }
+ } else {
+ // get v_cache_
+ for (int k = 0; k < config_.block_len / 32; k++) {
+ for (int l = 0; l < config_.head_dim; l++) {
+ block_q4_0 block =
+ v_cache_q4[layer_id_][head_id][block_idx]
+ [l * config_.block_len / 32 + k];
+ dequantize_row_q4_0(&block, block_fp32.data(), 32);
+ for (int m = 0; m < 32; m++) {
+
+ v_data_[((0 * config_.kv_head_num + head_id) *
+ seq_len_ +
+ 0 * config_.block_len + k * 32 + m) *
+ config_.head_dim +
+ l] = GGML_FP32_TO_FP16(block_fp32[m]);
+ }
+ }
+ }
+ }
+ },
+ nullptr);
+
+ // Timer end
+ auto end = std::chrono::high_resolution_clock::now();
+ std::chrono::duration<double> duration = end - start;
+ printf("layer %d block %d time of reading KV Cache: %f s\n", layer_id,
+ block_idx, duration.count());
+ // printf("get_one_block_fp16 duration: %ld\n", duration);
+}
+
+// k_in: (batch_size, seq_len, head_num, head_dim)
+// v_in: (batch_size, seq_len, head_num, head_dim)
+void KVCache::get_and_update_kvcache_fp16(ggml_fp16_t *k_in, ggml_fp16_t *v_in,
+ int layer_id, int *block_table,
+ int batch_size, int max_block_num,
+ int *cache_seqlens, int q_len,
+ Backend *backend) {
+ // Timer start
+ auto start = std::chrono::high_resolution_clock::now();
+
+ layer_id_ = layer_id;
+ k_data_ = const_cast<ggml_fp16_t *>(k_in);
+ v_data_ = const_cast<ggml_fp16_t *>(v_in);
+
+ // Each task updates the k cache and v cache of a certain head
+ backend->do_work_stealing_job(
+ config_.kv_head_num * max_block_num * batch_size, nullptr,
+ [&](int task_id) {
+ // printf("block_idx: %d, task_id: %d\n", block_idx, task_id);
+ std::vector<float> block_fp32(32);
+ int batch_id = task_id / (config_.kv_head_num * max_block_num);
+ int block_id = (task_id / config_.kv_head_num) % max_block_num;
+ int head_id = task_id % config_.kv_head_num;
+ int block_idx = block_table[batch_id * max_block_num + block_id];
+ int seq_len = cache_seqlens[batch_id];
+ int block_l = block_id * config_.block_len;
+ int block_r = block_id * config_.block_len + config_.block_len;
+
+ if (block_l < seq_len) {
+ if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
+ for (int k = 0; k < config_.block_len; k++) {
+ if (block_id * config_.block_len + k >= seq_len)
+ break;
+ for (int l = 0; l < config_.head_dim; l++) {
+ k_data_
+ [batch_id *
+ (max_block_num * config_.block_len *
+ config_.kv_head_num * config_.head_dim) +
+ block_id *
+ (config_.block_len * config_.kv_head_num *
+ config_.head_dim) +
+ k * (config_.kv_head_num * config_.head_dim) +
+ head_id * config_.head_dim + l] =
+ k_cache_fp16_[layer_id_][head_id][block_idx]
+ [k * config_.head_dim + l];
+ v_data_
+ [batch_id *
+ (max_block_num * config_.block_len *
+ config_.kv_head_num * config_.head_dim) +
+ block_id *
+ (config_.block_len * config_.kv_head_num *
+ config_.head_dim) +
+ k * (config_.kv_head_num * config_.head_dim) +
+ head_id * config_.head_dim + l] =
+ v_cache_fp16_[layer_id_][head_id][block_idx]
+ [l * config_.block_len + k];
+ }
+ }
+ } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
+ // get k_cache_
+ for (int k = 0; k < config_.block_len; k++) {
+ if (block_id * config_.block_len + k >= seq_len)
+ break;
+ for (int l = 0; l < config_.head_dim / 32; l++) {
+ block_q4_0 block =
+ k_cache_q4[layer_id_][head_id][block_idx]
+ [k * config_.head_dim / 32 + l];
+ dequantize_row_q4_0(&block, block_fp32.data(), 32);
+ for (int m = 0; m < 32; m++) {
+
+ k_data_[batch_id *
+ (max_block_num * config_.block_len *
+ config_.kv_head_num *
+ config_.head_dim) +
+ block_id * (config_.block_len *
+ config_.kv_head_num *
+ config_.head_dim) +
+ k * (config_.kv_head_num *
+ config_.head_dim) +
+ head_id * config_.head_dim + l * 32 +
+ m] = GGML_FP32_TO_FP16(block_fp32[m]);
+ }
+ }
+ }
+ // get v_cache_
+ for (int k = 0; k < config_.block_len / 32; k++) {
+ for (int l = 0; l < config_.head_dim; l++) {
+ block_q4_0 block =
+ v_cache_q4[layer_id_][head_id][block_idx]
+ [l * config_.block_len / 32 + k];
+ dequantize_row_q4_0(&block, block_fp32.data(), 32);
+ for (int m = 0; m < 32; m++) {
+
+ if (block_id * config_.block_len + k * 32 + m >=
+ seq_len)
+ break;
+ v_data_[batch_id *
+ (max_block_num * config_.block_len *
+ config_.kv_head_num *
+ config_.head_dim) +
+ block_id * (config_.block_len *
+ config_.kv_head_num *
+ config_.head_dim) +
+ (k * 32 + m) * config_.kv_head_num *
+ config_.head_dim +
+ head_id * config_.head_dim + l] =
+ GGML_FP32_TO_FP16(block_fp32[m]);
+ }
+ }
+ }
+ } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
+ // get k_cache_
+ for (int k = 0; k < config_.block_len; k++) {
+ if (block_id * config_.block_len + k >= seq_len)
+ break;
+ for (int l = 0; l < config_.head_dim / 32; l++) {
+ block_q8_0 block =
+ k_cache_q8[layer_id_][head_id][block_idx]
+ [k * config_.head_dim / 32 + l];
+ dequantize_row_q8_0(&block, block_fp32.data(), 32);
+ for (int m = 0; m < 32; m++) {
+
+ k_data_[batch_id *
+ (max_block_num * config_.block_len *
+ config_.kv_head_num *
+ config_.head_dim) +
+ block_id * (config_.block_len *
+ config_.kv_head_num *
+ config_.head_dim) +
+ k * (config_.kv_head_num *
+ config_.head_dim) +
+ head_id * config_.head_dim + l * 32 +
+ m] = GGML_FP32_TO_FP16(block_fp32[m]);
+ }
+ }
+ }
+ // get v_cache_
+ for (int k = 0; k < config_.block_len / 32; k++) {
+ for (int l = 0; l < config_.head_dim; l++) {
+ block_q8_0 block =
+ v_cache_q8[layer_id_][head_id][block_idx]
+ [l * config_.block_len / 32 + k];
+ dequantize_row_q8_0(&block, block_fp32.data(), 32);
+ for (int m = 0; m < 32; m++) {
+
+ if (block_id * config_.block_len + k * 32 + m >=
+ seq_len)
+ break;
+ v_data_[batch_id *
+ (max_block_num * config_.block_len *
+ config_.kv_head_num *
+ config_.head_dim) +
+ block_id * (config_.block_len *
+ config_.kv_head_num *
+ config_.head_dim) +
+ (k * 32 + m) * config_.kv_head_num *
+ config_.head_dim +
+ head_id * config_.head_dim + l] =
+ GGML_FP32_TO_FP16(block_fp32[m]);
+ }
+ }
+ }
+ }
+ }
+ if (block_r > seq_len && block_l < seq_len + q_len) {
+ if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
+ for (int k = 0; k < config_.block_len; k++) {
+ if (block_id * config_.block_len + k >=
+ seq_len + q_len ||
+ block_id * config_.block_len + k < seq_len)
+ continue;
+ for (int l = 0; l < config_.head_dim; l++) {
+ k_cache_fp16_[layer_id_][head_id][block_idx]
+ [k * config_.head_dim + l] = k_data_
+ [batch_id * (max_block_num *
+ config_.block_len *
+ config_.kv_head_num *
+ config_.head_dim) +
+ block_id * (config_.block_len *
+ config_.kv_head_num *
+ config_.head_dim) +
+ k * (config_.kv_head_num *
+ config_.head_dim) +
+ head_id * config_.head_dim + l];
+ v_cache_fp16_[layer_id_][head_id][block_idx]
+ [l * config_.block_len + k] = v_data_
+ [batch_id * (max_block_num *
+ config_.block_len *
+ config_.kv_head_num *
+ config_.head_dim) +
+ block_id * (config_.block_len *
+ config_.kv_head_num *
+ config_.head_dim) +
+ k * (config_.kv_head_num *
+ config_.head_dim) +
+ head_id * config_.head_dim + l];
+ }
+ }
+ } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
+ // fill k_cache_
+ for (int k = 0; k < config_.block_len; k++) {
+ if (block_id * config_.block_len + k >=
+ seq_len + q_len ||
+ block_id * config_.block_len + k < seq_len)
+ continue;
+ for (int l = 0; l < config_.head_dim / 32; l++) {
+ block_q4_0 block;
+ for (int m = 0; m < 32; m++) {
+
+ block_fp32[m] = GGML_FP16_TO_FP32(
+ k_data_[batch_id * (max_block_num *
+ config_.block_len *
+ config_.kv_head_num *
+ config_.head_dim) +
+ block_id * (config_.block_len *
+ config_.kv_head_num *
+ config_.head_dim) +
+ k * (config_.kv_head_num *
+ config_.head_dim) +
+ head_id * config_.head_dim +
+ l * 32 + m]);
+ }
+ quantize_row_q4_0(block_fp32.data(), &block, 32);
+ k_cache_q4[layer_id_][head_id][block_idx]
+ [k * config_.head_dim / 32 + l] = block;
+ }
+ }
+
+ // fill v_cache_
+ for (int k = 0; k < config_.block_len / 32; k++) {
+ for (int l = 0; l < config_.head_dim; l++) {
+ block_q4_0 block;
+ for (int m = 0; m < 32; m++) {
+
+ if (block_id * config_.block_len + k * 32 + m >=
+ seq_len + q_len) {
+ block_fp32[m] = 0;
+ continue;
+ }
+ block_fp32[m] = GGML_FP16_TO_FP32(
+ v_data_[batch_id * (max_block_num *
+ config_.block_len *
+ config_.kv_head_num *
+ config_.head_dim) +
+ block_id * (config_.block_len *
+ config_.kv_head_num *
+ config_.head_dim) +
+ (k * 32 + m) * config_.kv_head_num *
+ config_.head_dim +
+ head_id * config_.head_dim + l]);
+ }
+ quantize_row_q4_0(block_fp32.data(), &block, 32);
+ v_cache_q4[layer_id_][head_id][block_idx]
+ [l * config_.block_len / 32 + k] = block;
+ }
+ }
+ } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
+ // fill k_cache_
+ for (int k = 0; k < config_.block_len; k++) {
+ if (block_id * config_.block_len + k >=
+ seq_len + q_len ||
+ block_id * config_.block_len + k < seq_len)
+ continue;
+ for (int l = 0; l < config_.head_dim / 32; l++) {
+ block_q8_0 block;
+ for (int m = 0; m < 32; m++) {
+
+ block_fp32[m] = GGML_FP16_TO_FP32(
+ k_data_[batch_id * (max_block_num *
+ config_.block_len *
+ config_.kv_head_num *
+ config_.head_dim) +
+ block_id * (config_.block_len *
+ config_.kv_head_num *
+ config_.head_dim) +
+ k * (config_.kv_head_num *
+ config_.head_dim) +
+ head_id * config_.head_dim +
+ l * 32 + m]);
+ }
+ quantize_row_q8_0(block_fp32.data(), &block, 32);
+ k_cache_q8[layer_id_][head_id][block_idx]
+ [k * config_.head_dim / 32 + l] = block;
+ }
+ }
+
+ // fill v_cache_
+ for (int k = 0; k < config_.block_len / 32; k++) {
+ for (int l = 0; l < config_.head_dim; l++) {
+ block_q8_0 block;
+ for (int m = 0; m < 32; m++) {
+
+ if (block_id * config_.block_len + k * 32 + m >=
+ seq_len + q_len) {
+ block_fp32[m] = 0;
+ continue;
+ }
+ block_fp32[m] = GGML_FP16_TO_FP32(
+ v_data_[batch_id * (max_block_num *
+ config_.block_len *
+ config_.kv_head_num *
+ config_.head_dim) +
+ block_id * (config_.block_len *
+ config_.kv_head_num *
+ config_.head_dim) +
+ (k * 32 + m) * config_.kv_head_num *
+ config_.head_dim +
+ head_id * config_.head_dim + l]);
+ }
+ quantize_row_q8_0(block_fp32.data(), &block, 32);
+ v_cache_q8[layer_id_][head_id][block_idx]
+ [l * config_.block_len / 32 + k] = block;
+ }
+ }
+ }
+ }
+ },
+ nullptr);
+
+ // Timer end
+ auto end = std::chrono::high_resolution_clock::now();
+ std::chrono::duration<double> duration = end - start;
+
+ // printf("layer %d time of reading and updating KV Cache: %f s\n",
+ // layer_id,
+ // duration.count());
+}
+
+void KVCache::update_importance(const ggml_fp16_t *importance, int layer_id,
+ int *block_table, int batch_size,
+ int max_block_num, int *offset, int width,
+ Backend *backend) {
+ // Timer start
+ auto start = std::chrono::high_resolution_clock::now();
+
+ layer_id_ = layer_id;
+ importance_data_ = const_cast<ggml_fp16_t *>(importance);
+
+ // Each task updates the importance of a certain position
+ backend->do_work_stealing_job(
+ max_block_num * batch_size, nullptr,
+ [&](int task_id) {
+ int block_id = task_id % max_block_num;
+ int batch_id = task_id / max_block_num;
+ int block_idx = block_table[batch_id * max_block_num + block_id];
+ if (block_id > (offset[batch_id] + width) / config_.block_len) {
+ return;
+ }
+ for (int k = 0; k < config_.block_len; k++) {
+ for (int head_id = 0; head_id < config_.q_head_num; head_id++) {
+ importance_[layer_id_][block_idx][k][head_id] =
+ GGML_FP32_TO_FP16(
+ GGML_FP16_TO_FP32(
+ importance_data_[batch_id * max_block_num *
+ config_.block_len *
+ config_.q_head_num +
+ (block_id * config_.block_len +
+ k) *
+ config_.q_head_num +
+ head_id]) +
+ GGML_FP16_TO_FP32(
+ importance_[layer_id_][block_idx][k][head_id]));
+ }
+ }
+ },
+ nullptr);
+
+ // Timer end
+ auto end = std::chrono::high_resolution_clock::now();
+ std::chrono::duration<double> duration = end - start;
+
+ // printf("layer %d time of updating importance: %f s\n", layer_id,
+ // duration.count());
+}
+
+void KVCache::get_kvcache_fp16(ggml_fp16_t *k_in, ggml_fp16_t *v_in,
+ int layer_id, int *block_table, int batch_size,
+ int max_block_num, int *cache_seqlens,
+ Backend *backend) {
+ // Timer start
+ auto start = std::chrono::high_resolution_clock::now();
+
+ layer_id_ = layer_id;
+ k_data_ = const_cast<ggml_fp16_t *>(k_in);
+ v_data_ = const_cast<ggml_fp16_t *>(v_in);
+
+ // Each task updates the k cache and v cache of a certain head
+ backend->do_work_stealing_job(
+ config_.kv_head_num * max_block_num * batch_size, nullptr,
+ [&](int task_id) {
+ // printf("block_idx: %d, task_id: %d\n", block_idx, task_id);
+ std::vector<float> block_fp32(32);
+ int batch_id = task_id / (config_.kv_head_num * max_block_num);
+ int block_id = (task_id / config_.kv_head_num) % max_block_num;
+ int head_id = task_id % config_.kv_head_num;
+ int block_idx = block_table[batch_id * max_block_num + block_id];
+ int seq_len = cache_seqlens[batch_id];
+ int block_l = block_id * config_.block_len;
+ int block_r = block_id * config_.block_len + config_.block_len;
+
+ if (block_l < seq_len) {
+ if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
+ for (int k = 0; k < config_.block_len; k++) {
+ if (block_id * config_.block_len + k >= seq_len)
+ break;
+ for (int l = 0; l < config_.head_dim; l++) {
+ k_data_
+ [batch_id *
+ (max_block_num * config_.block_len *
+ config_.kv_head_num * config_.head_dim) +
+ block_id *
+ (config_.block_len * config_.kv_head_num *
+ config_.head_dim) +
+ k * (config_.kv_head_num * config_.head_dim) +
+ head_id * config_.head_dim + l] =
+ k_cache_fp16_[layer_id_][head_id][block_idx]
+ [k * config_.head_dim + l];
+ v_data_
+ [batch_id *
+ (max_block_num * config_.block_len *
+ config_.kv_head_num * config_.head_dim) +
+ block_id *
+ (config_.block_len * config_.kv_head_num *
+ config_.head_dim) +
+ k * (config_.kv_head_num * config_.head_dim) +
+ head_id * config_.head_dim + l] =
+ v_cache_fp16_[layer_id_][head_id][block_idx]
+ [l * config_.block_len + k];
+ }
+ }
+ } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
+ // get k_cache_
+ for (int k = 0; k < config_.block_len; k++) {
+ if (block_id * config_.block_len + k >= seq_len)
+ break;
+ for (int l = 0; l < config_.head_dim / 32; l++) {
+ block_q4_0 block =
+ k_cache_q4[layer_id_][head_id][block_idx]
+ [k * config_.head_dim / 32 + l];
+ dequantize_row_q4_0(&block, block_fp32.data(), 32);
+ for (int m = 0; m < 32; m++) {
+
+ k_data_[batch_id *
+ (max_block_num * config_.block_len *
+ config_.kv_head_num *
+ config_.head_dim) +
+ block_id * (config_.block_len *
+ config_.kv_head_num *
+ config_.head_dim) +
+ k * (config_.kv_head_num *
+ config_.head_dim) +
+ head_id * config_.head_dim + l * 32 +
+ m] = GGML_FP32_TO_FP16(block_fp32[m]);
+ }
+ }
+ }
+ // get v_cache_
+ for (int k = 0; k < config_.block_len / 32; k++) {
+ for (int l = 0; l < config_.head_dim; l++) {
+ block_q4_0 block =
+ v_cache_q4[layer_id_][head_id][block_idx]
+ [l * config_.block_len / 32 + k];
+ dequantize_row_q4_0(&block, block_fp32.data(), 32);
+ for (int m = 0; m < 32; m++) {
+
+ if (block_id * config_.block_len + k * 32 + m >=
+ seq_len)
+ break;
+ v_data_[batch_id *
+ (max_block_num * config_.block_len *
+ config_.kv_head_num *
+ config_.head_dim) +
+ block_id * (config_.block_len *
+ config_.kv_head_num *
+ config_.head_dim) +
+ (k * 32 + m) * config_.kv_head_num *
+ config_.head_dim +
+ head_id * config_.head_dim + l] =
+ GGML_FP32_TO_FP16(block_fp32[m]);
+ }
+ }
+ }
+ } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
+ // get k_cache_
+ for (int k = 0; k < config_.block_len; k++) {
+ if (block_id * config_.block_len + k >= seq_len)
+ break;
+ for (int l = 0; l < config_.head_dim / 32; l++) {
+ block_q8_0 block =
+ k_cache_q8[layer_id_][head_id][block_idx]
+ [k * config_.head_dim / 32 + l];
+ dequantize_row_q8_0(&block, block_fp32.data(), 32);
+ for (int m = 0; m < 32; m++) {
+
+ k_data_[batch_id *
+ (max_block_num * config_.block_len *
+ config_.kv_head_num *
+ config_.head_dim) +
+ block_id * (config_.block_len *
+ config_.kv_head_num *
+ config_.head_dim) +
+ k * (config_.kv_head_num *
+ config_.head_dim) +
+ head_id * config_.head_dim + l * 32 +
+ m] = GGML_FP32_TO_FP16(block_fp32[m]);
+ }
+ }
+ }
+ // get v_cache_
+ for (int k = 0; k < config_.block_len / 32; k++) {
+ for (int l = 0; l < config_.head_dim; l++) {
+ block_q8_0 block =
+ v_cache_q8[layer_id_][head_id][block_idx]
+ [l * config_.block_len / 32 + k];
+ dequantize_row_q8_0(&block, block_fp32.data(), 32);
+ for (int m = 0; m < 32; m++) {
+
+ if (block_id * config_.block_len + k * 32 + m >=
+ seq_len)
+ break;
+ v_data_[batch_id *
+ (max_block_num * config_.block_len *
+ config_.kv_head_num *
+ config_.head_dim) +
+ block_id * (config_.block_len *
+ config_.kv_head_num *
+ config_.head_dim) +
+ (k * 32 + m) * config_.kv_head_num *
+ config_.head_dim +
+ head_id * config_.head_dim + l] =
+ GGML_FP32_TO_FP16(block_fp32[m]);
+ }
+ }
+ }
+ }
+ }
+ },
+ nullptr);
+
+ // Timer end
+ auto end = std::chrono::high_resolution_clock::now();
+ std::chrono::duration<double> duration = end - start;
+}
+
+void KVCache::update_kvcache_fp16(const ggml_fp16_t *k_in,
+ const ggml_fp16_t *v_in, int layer_id,
+ int *block_table, int batch_size,
+ int max_block_num, int *cache_seqlens,
+ int q_len, Backend *backend) {
+ // Timer start
+ auto start = std::chrono::high_resolution_clock::now();
+
+ layer_id_ = layer_id;
+ k_data_ = const_cast<ggml_fp16_t *>(k_in);
+ v_data_ = const_cast<ggml_fp16_t *>(v_in);
+ // Each task updates the k cache and v cache of a certain head
+ backend->do_work_stealing_job(
+ batch_size * config_.kv_head_num * q_len, nullptr,
+ [&](int task_id) {
+ int batch_id = task_id / (config_.kv_head_num * q_len);
+ int head_id = task_id / q_len % config_.kv_head_num;
+ int seq_len = cache_seqlens[batch_id] + task_id % q_len;
+ int q_offset = task_id % q_len;
+
+ int block_id = seq_len / config_.block_len;
+ int block_idx = block_table[batch_id * max_block_num + block_id];
+ int pos_in_block = seq_len % config_.block_len;
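+ // seq_len is the absolute position of the token being written (past
+ // length plus its offset within the current query). block_id is its
+ // logical block, block_table maps that to the physical block index,
+ // and pos_in_block is its offset inside that block.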
+
+ if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
+ for (int l = 0; l < config_.head_dim; l++) {
+ k_cache_fp16_[layer_id_][head_id][block_idx]
+ [pos_in_block * config_.head_dim + l] =
+ k_data_[batch_id *
+ (q_len * config_.kv_head_num *
+ config_.head_dim) +
+ q_offset * config_.kv_head_num *
+ config_.head_dim +
+ head_id * config_.head_dim + l];
+ v_cache_fp16_[layer_id_][head_id][block_idx]
+ [l * config_.block_len + pos_in_block] =
+ v_data_[batch_id *
+ (q_len * config_.kv_head_num *
+ config_.head_dim) +
+ q_offset * config_.kv_head_num *
+ config_.head_dim +
+ head_id * config_.head_dim + l];
+ }
+ } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
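+ // Note: unlike the FP16 path above, the quantized (Q4_0 / Q8_0) paths
+ // below index k_data_ / v_data_ without the q_offset term, which
+ // appears to assume q_len == 1 (single-token decode).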
+ std::vector<float> block_fp32(32);
+ // fill k_cache_
+ for (int l = 0; l < config_.head_dim / 32; l++) {
+ block_q4_0 block;
+ for (int m = 0; m < 32; m++) {
+
+ block_fp32[m] = GGML_FP16_TO_FP32(
+ k_data_[batch_id * (q_len * config_.kv_head_num *
+ config_.head_dim) +
+ head_id * config_.head_dim + l * 32 + m]);
+ }
+ quantize_row_q4_0(block_fp32.data(), &block, 32);
+
+ k_cache_q4[layer_id_][head_id][block_idx]
+ [pos_in_block * config_.head_dim / 32 + l] =
+ block;
+ }
+
+ // fill v_cache_
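+ // V is stored transposed ([head_dim][block_len]) and quantized in
+ // groups of 32 along the token dimension, so writing one token is a
+ // read-modify-write: dequantize the 32-token group that contains
+ // pos_in_block, patch a single element, and requantize the group.
+ // The same applies to the Q8_0 branch below.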
+ for (int l = 0; l < config_.head_dim; l++) {
+ block_q4_0 block = v_cache_q4[layer_id_][head_id][block_idx]
+ [l * config_.block_len / 32 +
+ pos_in_block / 32];
+ dequantize_row_q4_0(&block, block_fp32.data(), 32);
+ block_fp32[pos_in_block % 32] = GGML_FP16_TO_FP32(
+ v_data_[batch_id * (q_len * config_.kv_head_num *
+ config_.head_dim) +
+ head_id * config_.head_dim + l]);
+ quantize_row_q4_0(block_fp32.data(), &block, 32);
+ v_cache_q4[layer_id_][head_id][block_idx]
+ [l * config_.block_len / 32 + pos_in_block / 32] =
+ block;
+ }
+ } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
+ std::vector<float> block_fp32(32);
+ // fill k_cache_
+ for (int l = 0; l < config_.head_dim / 32; l++) {
+ block_q8_0 block;
+ for (int m = 0; m < 32; m++) {
+
+ block_fp32[m] = GGML_FP16_TO_FP32(
+ k_data_[batch_id * (q_len * config_.kv_head_num *
+ config_.head_dim) +
+ head_id * config_.head_dim + l * 32 + m]);
+ }
+ quantize_row_q8_0(block_fp32.data(), &block, 32);
+
+ k_cache_q8[layer_id_][head_id][block_idx]
+ [pos_in_block * config_.head_dim / 32 + l] =
+ block;
+ }
+
+ // fill v_cache_
+ for (int l = 0; l < config_.head_dim; l++) {
+ block_q8_0 block = v_cache_q8[layer_id_][head_id][block_idx]
+ [l * config_.block_len / 32 +
+ pos_in_block / 32];
+ dequantize_row_q8_0(&block, block_fp32.data(), 32);
+ block_fp32[pos_in_block % 32] = GGML_FP16_TO_FP32(
+ v_data_[batch_id * (q_len * config_.kv_head_num *
+ config_.head_dim) +
+ head_id * config_.head_dim + l]);
+ quantize_row_q8_0(block_fp32.data(), &block, 32);
+ v_cache_q8[layer_id_][head_id][block_idx]
+ [l * config_.block_len / 32 + pos_in_block / 32] =
+ block;
+ }
+ }
+ },
+ nullptr);
+
+ // Timer end
+ auto end = std::chrono::high_resolution_clock::now();
+ std::chrono::duration<double> duration = end - start;
+ // printf("layer %d time of reading KV Cache: %f s\n", layer_id,
+ // duration.count());
+}
+
+void KVCache::get_all_kvcache_one_layer(int layer_id, ggml_fp16_t *k_in,
+ ggml_fp16_t *v_in, Backend *backend) {
+ // Timer start
+ auto start = std::chrono::high_resolution_clock::now();
+
+ layer_id_ = layer_id;
+ seq_len_ = config_.block_len;
+ block_num_ = get_cache_total_block_num();
+ k_data_ = reinterpret_cast<ggml_fp16_t *>(k_in);
+ v_data_ = reinterpret_cast<ggml_fp16_t *>(v_in);
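+ // Dumps the entire cache of one layer into k_in / v_in, laid out as
+ // [kv_head][token][head_dim] with tokens running over all
+ // cache_total_len_ positions. Only the Q4_0 cache is read here; this
+ // function does not dispatch on config_.kv_type.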
+
+ // Each task gets the k cache or v cache of a certain head
+ backend->do_work_stealing_job(
+ config_.kv_head_num * past_block_num_[layer_id] * 2, nullptr,
+ [&](int task_id) {
+ std::vector<float> block_fp32(32);
+ int head_id = task_id / 2 / past_block_num_[layer_id];
+ int block_idx = task_id / 2 % past_block_num_[layer_id];
+ if (block_idx >= block_num_)
+ return;
+
+ int max_offset = 0;
+ if (task_id & 1) {
+ // get k_cache_
+ for (int k = 0; k < config_.block_len; k++) {
+ if (block_idx * seq_len_ + k >= cache_total_len_)
+ break;
+ for (int l = 0; l < config_.head_dim / 32; l++) {
+ block_q4_0 block =
+ k_cache_q4[layer_id_][head_id][block_idx]
+ [k * config_.head_dim / 32 + l];
+ dequantize_row_q4_0(&block, block_fp32.data(), 32);
+ for (int m = 0; m < 32; m++) {
+
+ k_data_[(head_id * cache_total_len_ +
+ block_idx * config_.block_len + k) *
+ config_.head_dim +
+ l * 32 + m] =
+ GGML_FP32_TO_FP16(block_fp32[m]);
+ max_offset = std::max(
+ max_offset,
+ (int)(head_id * cache_total_len_ +
+ block_idx * config_.block_len + k) *
+ config_.head_dim +
+ l * 32 + m);
+ }
+ }
+ }
+ } else {
+ // get v_cache_
+ for (int k = 0; k < config_.block_len / 32; k++) {
+ for (int l = 0; l < config_.head_dim; l++) {
+ block_q4_0 block =
+ v_cache_q4[layer_id_][head_id][block_idx]
+ [l * config_.block_len / 32 + k];
+ dequantize_row_q4_0(&block, block_fp32.data(), 32);
+ for (int m = 0; m < 32; m++) {
+
+ if (block_idx * seq_len_ + k * 32 + m >=
+ cache_total_len_)
+ break;
+ v_data_[(head_id * cache_total_len_ +
+ block_idx * config_.block_len + k * 32 +
+ m) *
+ config_.head_dim +
+ l] = GGML_FP32_TO_FP16(block_fp32[m]);
+ max_offset =
+ std::max(max_offset,
+ (int)((head_id * cache_total_len_ +
+ block_idx * config_.block_len +
+ k * 32 + m) *
+ config_.head_dim +
+ l));
+ }
+ }
+ }
+ }
+ },
+ nullptr);
+
+ // Timer end
+ auto end = std::chrono::high_resolution_clock::now();
+ std::chrono::duration<double> duration = end - start;
+ // printf("layer %d block num %d time of reading all KV Cache: %f s\n",
+ // layer_id, block_num_, duration.count());
+}
diff --git a/ktransformers/ktransformers_ext/operators/kvcache/kvcache_utils.cpp b/ktransformers/ktransformers_ext/operators/kvcache/kvcache_utils.cpp
new file mode 100644
index 0000000..f1d6f7d
--- /dev/null
+++ b/ktransformers/ktransformers_ext/operators/kvcache/kvcache_utils.cpp
@@ -0,0 +1,1157 @@
+/**
+ * @Description :
+ * @Author : Jianwei Dong
+ * @Date : 2024-08-26 22:47:06
+ * @Version : 1.0.0
+ * @LastEditors : Jianwei Dong
+ * @LastEditTime : 2024-08-26 22:47:06
+ * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
+ **/
+
+#include "kvcache.h"
+
+std::string ggml_type_to_string(ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_F32:
+ return "GGML_TYPE_F32";
+ case GGML_TYPE_F16:
+ return "GGML_TYPE_F16";
+ case GGML_TYPE_Q4_0:
+ return "GGML_TYPE_Q4_0";
+ case GGML_TYPE_Q8_0:
+ return "GGML_TYPE_Q8_0";
+ }
+ return "UNDEFINED";
+}
+std::string AnchorTypeToString(AnchorType type) {
+ switch (type) {
+ case AnchorType::DYNAMIC:
+ return "DYNAMIC";
+ case AnchorType::BLOCK_MEAN:
+ return "BLOCK_MEAN";
+ case AnchorType::BLOCK_MAX:
+ return "BLOCK_MAX";
+ case AnchorType::FIXED_ANCHOR:
+ return "FIXED_ANCHOR";
+ case AnchorType::QUEST:
+ return "QUEST";
+ }
+ return "UNDEFINED";
+}
+std::string RetrievalTypeToString(RetrievalType type) {
+ switch (type) {
+ case RetrievalType::LAYER:
+ return "SHARED";
+ case RetrievalType::KVHEAD:
+ return "SEPARATE";
+ case RetrievalType::QHEAD:
+ return "INDIVIDUAL";
+ }
+ return "UNDEFINED";
+}
+KVCacheConfig::KVCacheConfig(int layer_num, int kv_head_num, int q_head_num,
+ int head_dim, int block_len, int anchor_num,
+ AnchorType anchor_type, ggml_type kv_type,
+ RetrievalType retrieval_type, int layer_step,
+ int token_step, int layer_offset,
+ int max_block_num, int max_batch_size,
+ int max_thread_num)
+ : layer_num(layer_num), kv_head_num(kv_head_num), q_head_num(q_head_num),
+ head_dim(head_dim), block_len(block_len), anchor_num(anchor_num),
+ anchor_type(anchor_type), kv_type(kv_type),
+ retrieval_type(retrieval_type), layer_step(layer_step),
+ token_step(token_step), layer_offset(layer_offset),
+ max_block_num(max_block_num), max_batch_size(max_batch_size),
+ max_thread_num(max_thread_num) {
+ printf(
+ "layer_num: %d, kv_head_num: %d, q_head_num: %d, head_dim: %d, "
+ "block_len: %d, anchor_num: %d, anchor_type: %s, kv_type: %s, "
+ "retrieval_type: %s, layer_step: %d, token_step: %d, layer_offset: %d,"
+ "max_block_num: %d, max_batch_size: %d, max_thread_num: %d\n",
+ layer_num, kv_head_num, q_head_num, head_dim, block_len, anchor_num,
+ AnchorTypeToString(anchor_type).c_str(),
+ ggml_type_to_string(kv_type).c_str(),
+ RetrievalTypeToString(retrieval_type).c_str(), layer_step, token_step,
+ layer_offset, max_block_num, max_batch_size, max_thread_num);
+ assert(q_head_num % kv_head_num == 0);
+}
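+// Example (hypothetical values, for illustration only): a 32-layer model
+// with 8 KV heads, 32 query heads, head_dim 128, 128-token blocks and an
+// FP16 cache could be configured as
+//   KVCacheConfig cfg(32, 8, 32, 128, 128, 1, AnchorType::DYNAMIC,
+//                     GGML_TYPE_F16, RetrievalType::LAYER,
+//                     /*layer_step=*/4, /*token_step=*/100,
+//                     /*layer_offset=*/0, /*max_block_num=*/512,
+//                     /*max_batch_size=*/4, /*max_thread_num=*/32);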
+KVCache::KVCache(KVCacheConfig config) {
+ this->config_ = config;
+
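+ // Number of query heads sharing one KV head (GQA group size).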
+ n_gqa_ = config_.q_head_num / config_.kv_head_num;
+ if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
+ // TODO: Elegant implement
+ k_cache_fp16_.resize(config_.layer_num);
+ v_cache_fp16_.resize(config_.layer_num);
+ selected_blocks_num_history_.resize(config_.layer_num /
+ config_.layer_step);
+ if (config_.retrieval_type == RetrievalType::LAYER) {
+ selected_blocks_history_.resize(config_.layer_num /
+ config_.layer_step);
+ } else if (config_.retrieval_type == RetrievalType::KVHEAD) {
+ selected_blocks_history_kvhead_.resize(config_.layer_num /
+ config_.layer_step);
+ } else if (config_.retrieval_type == RetrievalType::QHEAD) {
+ }
+ } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
+ k_cache_q4.resize(config.layer_num);
+ v_cache_q4.resize(config.layer_num);
+ } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
+ k_cache_q8.resize(config.layer_num);
+ v_cache_q8.resize(config.layer_num);
+ } else {
+ assert(false);
+ }
+ anchor_.resize(config.layer_num * config.max_block_num * config.anchor_num *
+ config.q_head_num * config.head_dim);
+ importance_.resize(config.layer_num);
+ past_block_num_.resize(config.layer_num);
+ for (int i = 0; i < config.layer_num; i++) {
+ past_block_num_[i] = 0;
+ }
+
+ ThreadResize(config.max_thread_num);
+ BatchResize(config.max_batch_size);
+ BlockResize(config.max_block_num);
+ q_fp32.resize(n_gqa_ * config.head_dim);
+}
+
+void KVCache::ThreadResize(int thread_num) {
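+ // Per-thread scratch buffers for block-level attention: q8_0 and fp32
+ // output buffers, attention scores, log-sum-exp (lse) accumulators, a
+ // draft workspace and a bit-packed attention mask, each sized below
+ // for a single cache block.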
+ thread_local_output_q8_0_.resize(thread_num);
+ thread_local_attn_score_.resize(thread_num);
+ thread_local_output_fp32_.resize(thread_num);
+ thread_local_attn_lse_.resize(thread_num);
+ thread_local_cur_output_fp32_.resize(thread_num);
+ thread_local_cur_attn_lse_.resize(thread_num);
+ thread_local_draft_.resize(thread_num);
+ thread_cur_head_idx_.resize(thread_num);
+ thread_local_attn_mask_.resize(thread_num);
+ for (int i = 0; i < thread_num; i++) {
+ thread_local_output_q8_0_[i].resize(n_gqa_ * config_.head_dim / QK8_0);
+ thread_local_attn_score_[i].resize(n_gqa_ * config_.block_len);
+ thread_local_output_fp32_[i].resize(n_gqa_ * config_.head_dim);
+ thread_local_attn_lse_[i].resize(n_gqa_);
+ thread_local_cur_output_fp32_[i].resize(n_gqa_ * config_.head_dim);
+ thread_local_cur_attn_lse_[i].resize(n_gqa_);
+ thread_local_draft_[i].resize(
+ 2 * n_gqa_ * config_.block_len + 6 * n_gqa_ * config_.head_dim +
+ 2 * config_.block_len * config_.head_dim +
+ config_.block_len * config_.head_dim / QK4_0);
+ thread_local_attn_mask_[i].resize(config_.block_len / 8);
+ }
+}
+void KVCache::BatchResize(int batch_size) {
+ mutex_.resize(batch_size);
+ q_q8_0_.resize(batch_size);
+ q_fp32_.resize(batch_size);
+ output_fp32_.resize(batch_size);
+ attn_lse_.resize(batch_size);
+ block_lse_.resize(batch_size);
+ attn_sparsity_.resize(batch_size);
+
+ if (config_.retrieval_type == RetrievalType::LAYER) {
+ block_table_before_retrieval_.resize(batch_size);
+ block_table_after_retrieval_.resize(batch_size);
+
+ for (int i = 0; i < config_.layer_num / config_.layer_step; i++) {
+ selected_blocks_history_[i].resize(batch_size);
+ }
+
+ } else if (config_.retrieval_type == RetrievalType::KVHEAD) {
+ block_table_before_retrieval_kvhead_.resize(batch_size);
+ block_table_after_retrieval_kvhead_.resize(batch_size);
+ for (int i = 0; i < config_.layer_num / config_.layer_step; i++) {
+ selected_blocks_history_kvhead_[i].resize(batch_size);
+ }
+ } else if (config_.retrieval_type == RetrievalType::QHEAD) {
+ block_table_before_retrieval_qhead_.resize(batch_size);
+ block_table_after_retrieval_qhead_.resize(batch_size);
+ }
+ cache_seqlens_.resize(batch_size);
+ if (config_.retrieval_type == RetrievalType::LAYER) {
+ block_similar_.resize(batch_size);
+ } else if (config_.retrieval_type == RetrievalType::KVHEAD) {
+ block_similar_kv_head_.resize(batch_size);
+ } else if (config_.retrieval_type == RetrievalType::QHEAD) {
+ block_similar_q_head_.resize(batch_size);
+ }
+ for (int i = 0; i < batch_size; i++) {
+ top_similar_block_.resize(batch_size);
+
+ mutex_[i].resize(config_.kv_head_num);
+ q_q8_0_[i].resize(config_.kv_head_num);
+ q_fp32_[i].resize(config_.kv_head_num);
+ output_fp32_[i].resize(config_.kv_head_num);
+ attn_lse_[i].resize(config_.kv_head_num);
+
+ for (int j = 0; j < config_.kv_head_num; j++) {
+ if (!mutex_[i][j]) {
+ mutex_[i][j] = std::make_unique<std::mutex>();
+ }
+ q_q8_0_[i][j].resize(n_gqa_ * config_.head_dim / QK8_0);
+ q_fp32_[i][j].resize(n_gqa_ * config_.head_dim);
+ output_fp32_[i][j].resize(n_gqa_ * config_.head_dim);
+ attn_lse_[i][j].resize(n_gqa_);
+ }
+ }
+ avg_q.resize(batch_size);
+ avg_q_fp16.resize(batch_size);
+ for (int i = 0; i < batch_size; i++) {
+ attn_sparsity_[i].resize(config_.q_head_num);
+ avg_q[i].resize(config_.q_head_num * config_.head_dim);
+ avg_q_fp16[i].resize(config_.q_head_num * config_.head_dim);
+ }
+}
+
+void KVCache::BlockResize(int max_block_num) {
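+ // sin_ / cos_ hold one head_dim-sized row per cached position, filled
+ // later by get_sincos(); presumably the rotary position-embedding
+ // (RoPE) tables.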
+ sin_.resize(max_block_num * config_.block_len);
+ cos_.resize(max_block_num * config_.block_len);
+ for (int i = 0; i < max_block_num * config_.block_len; i++) {
+ sin_[i].resize(config_.head_dim);
+ cos_[i].resize(config_.head_dim);
+ }
+
+ for (int i = 0; i < config_.layer_num / config_.layer_step; i++) {
+ for (int j = 0; j < config_.max_batch_size; j++) {
+ if (config_.retrieval_type == RetrievalType::LAYER) {
+ selected_blocks_history_[i][j].resize(max_block_num);
+ } else if (config_.retrieval_type == RetrievalType::KVHEAD) {
+ selected_blocks_history_kvhead_[i][j].resize(max_block_num);
+ for (int k = 0; k < config_.max_block_num; k++) {
+ selected_blocks_history_kvhead_[i][j][k].resize(
+ config_.kv_head_num);
+ }
+ } else if (config_.retrieval_type == RetrievalType::QHEAD) {
+ }
+ }
+ }
+
+ for (int layer_id = 0; layer_id < config_.layer_num; layer_id++) {
+ importance_[layer_id].resize(max_block_num);
+
+ if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
+ // TODO: Elegant implement
+ k_cache_fp16_[layer_id].resize(config_.kv_head_num);
+ v_cache_fp16_[layer_id].resize(config_.kv_head_num);
+
+ for (int i = 0; i < config_.kv_head_num; i++) {
+ k_cache_fp16_[layer_id][i].resize(max_block_num);
+ v_cache_fp16_[layer_id][i].resize(max_block_num);
+
+ for (int j = 0; j < max_block_num; j++) {
+ k_cache_fp16_[layer_id][i][j].resize(config_.block_len *
+ config_.head_dim);
+ v_cache_fp16_[layer_id][i][j].resize(config_.block_len *
+ config_.head_dim);
+ }
+ }
+
+ } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
+ k_cache_q4[layer_id].resize(config_.kv_head_num);
+ v_cache_q4[layer_id].resize(config_.kv_head_num);
+ for (int i = 0; i < config_.kv_head_num; i++) {
+ k_cache_q4[layer_id][i].resize(max_block_num);
+ v_cache_q4[layer_id][i].resize(max_block_num);
+
+ for (int j = 0; j < max_block_num; j++) {
+ k_cache_q4[layer_id][i][j].resize(config_.block_len *
+ config_.head_dim / 32);
+ v_cache_q4[layer_id][i][j].resize(config_.block_len *
+ config_.head_dim / 32);
+ }
+ }
+ } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
+ k_cache_q8[layer_id].resize(config_.kv_head_num);
+ v_cache_q8[layer_id].resize(config_.kv_head_num);
+ for (int i = 0; i < config_.kv_head_num; i++) {
+ k_cache_q8[layer_id][i].resize(max_block_num);
+ v_cache_q8[layer_id][i].resize(max_block_num);
+
+ for (int j = 0; j < max_block_num; j++) {
+ k_cache_q8[layer_id][i][j].resize(config_.block_len *
+ config_.head_dim / 32);
+ v_cache_q8[layer_id][i][j].resize(config_.block_len *
+ config_.head_dim / 32);
+ }
+ }
+ } else {
+ assert(false);
+ }
+ for (int i = 0; i < config_.max_batch_size; i++) {
+ if (config_.retrieval_type == RetrievalType::LAYER) {
+ block_similar_[i].resize(max_block_num);
+ block_table_before_retrieval_[i].resize(max_block_num);
+ block_table_after_retrieval_[i].resize(max_block_num);
+ } else if (config_.retrieval_type == RetrievalType::KVHEAD) {
+ block_similar_kv_head_[i].resize(max_block_num);
+ block_table_before_retrieval_kvhead_[i].resize(max_block_num);
+ block_table_after_retrieval_kvhead_[i].resize(max_block_num);
+ for (int j = 0; j < max_block_num; j++) {
+ block_similar_kv_head_[i][j].resize(config_.kv_head_num);
+ block_table_before_retrieval_kvhead_[i][j].resize(
+ config_.kv_head_num);
+ block_table_after_retrieval_kvhead_[i][j].resize(
+ config_.kv_head_num);
+ }
+ } else if (config_.retrieval_type == RetrievalType::QHEAD) {
+ block_similar_q_head_[i].resize(max_block_num);
+ block_table_before_retrieval_qhead_[i].resize(max_block_num);
+ block_table_after_retrieval_qhead_[i].resize(max_block_num);
+ for (int j = 0; j < max_block_num; j++) {
+ block_similar_q_head_[i][j].resize(config_.q_head_num);
+ block_table_before_retrieval_qhead_[i][j].resize(
+ config_.q_head_num);
+ block_table_after_retrieval_qhead_[i][j].resize(
+ config_.q_head_num);
+ }
+ }
+ block_lse_[i].resize(max_block_num);
+ for (int j = 0; j < max_block_num; j++) {
+ block_lse_[i][j].resize(config_.q_head_num);
+ }
+ }
+
+ for (int i = 0; i < max_block_num; i++) {
+ importance_[layer_id][i].resize(config_.block_len);
+ for (int j = 0; j < config_.block_len; j++) {
+ importance_[layer_id][i][j].resize(config_.q_head_num);
+ }
+ }
+ }
+}
+
+void KVCache::calc_anchor_all_layers(int *block_table, int *cache_seqlens,
+ int batch_size, int max_block_num,
+ Backend *backend) {
+ // Timer start
+ auto start = std::chrono::high_resolution_clock::now();
+
+ // Each task computes the anchors of a certain block
+ seq_len_ = config_.block_len;
+ backend->do_work_stealing_job(
+ config_.layer_num * batch_size * max_block_num, nullptr,
+ [&](int task_id) {
+ int layer_id = task_id / (batch_size * max_block_num);
+ int batch_id = (task_id / max_block_num) % batch_size;
+ int block_id = task_id % max_block_num;
+ // If the block lies beyond the sequence length, skip it. In
+ // particular, the last block of a sequence that is shorter than the
+ // block length should be skipped.
+
+ if (cache_seqlens[batch_id] / config_.block_len < block_id) {
+ return;
+ }
+ int block_idx = block_table[batch_id * max_block_num + block_id];
+
+ std::vector<float> block_fp32(32);
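+ // anchor_ is a flat buffer indexed as
+ // [layer][block][anchor_id][q_head][head_dim]; each branch below
+ // rebuilds the anchors of this block according to config_.anchor_type.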
+ if (config_.anchor_type == AnchorType::DYNAMIC) {
+
+ // clear anchor_
+ for (int anchor_id = 0; anchor_id < 1; anchor_id++) {
+ for (int head_id = 0; head_id < config_.q_head_num;
+ head_id++) {
+ for (int l = 0; l < config_.head_dim; l++) {
+ anchor_[layer_id * config_.max_block_num *
+ config_.anchor_num *
+ config_.q_head_num * config_.head_dim +
+ block_idx * config_.anchor_num *
+ config_.q_head_num * config_.head_dim +
+ anchor_id * config_.q_head_num *
+ config_.head_dim +
+ head_id * config_.head_dim + l] = 0;
+ }
+ }
+ }
+
+ // find top anchor_num importances and their corresponding
+ // positions in the importance_ tensor
+ // TODO: Move top_importances to the class member to avoid
+ // repeated memory allocation
+ std::priority_queue<
+ std::pair<float, std::pair<int, int>>,
+ std::vector<std::pair<float, std::pair<int, int>>>,
+ std::greater<>>
+ top_importances;
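+ // std::greater makes this a min-heap, so popping whenever the size
+ // exceeds anchor_num leaves the anchor_num positions with the highest
+ // importance for this head.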
+ for (int head_id = 0; head_id < config_.q_head_num; head_id++) {
+ for (int k = 0; k < seq_len_; k++) {
+ top_importances.push(std::make_pair(
+ GGML_FP16_TO_FP32(
+ importance_[layer_id][block_idx][k][head_id]),
+ std::make_pair(block_idx, k)));
+ // TODO: change to config_ item
+ if (top_importances.size() > config_.anchor_num) {
+ top_importances.pop();
+ }
+ }
+
+ // fill anchor_
+
+ for (int l = 0; l < config_.head_dim; l++) {
+ anchor_[layer_id * config_.max_block_num *
+ config_.anchor_num * config_.q_head_num *
+ config_.head_dim +
+ block_idx * config_.anchor_num *
+ config_.q_head_num * config_.head_dim +
+ 0 * config_.q_head_num * config_.head_dim +
+ head_id * config_.head_dim + l] = 0;
+ }
+ for (int k = 0; k < config_.anchor_num; k++) {
+ int top_indice = top_importances.top().second.second;
+ int top_block_idx = top_importances.top().second.first;
+
+ if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
+
+ for (int l = 0; l < config_.head_dim; l++) {
+ anchor_[layer_id * config_.max_block_num *
+ config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ top_block_idx * config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ 0 * config_.q_head_num *
+ config_.head_dim +
+ head_id * config_.head_dim + l] =
+ GGML_FP32_TO_FP16(
+ GGML_FP16_TO_FP32(
+ anchor_[layer_id *
+ config_.max_block_num *
+ config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ top_block_idx *
+ config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ 0 * config_.q_head_num *
+ config_.head_dim +
+ head_id * config_.head_dim +
+ l]) +
+ GGML_FP16_TO_FP32(
+ k_cache_fp16_[layer_id]
+ [head_id / n_gqa_]
+ [top_block_idx]
+ [top_indice *
+ config_.head_dim +
+ l]));
+ }
+
+ } else if (config_.kv_type ==
+ ggml_type::GGML_TYPE_Q4_0) {
+ for (int l = 0; l < config_.head_dim / 32; l++) {
+ block_q4_0 block = k_cache_q4
+ [layer_id][head_id / n_gqa_][top_block_idx]
+ [top_indice * config_.head_dim / 32 + l];
+ dequantize_row_q4_0(&block, block_fp32.data(),
+ 32);
+ for (int m = 0; m < 32; m++) {
+ anchor_[layer_id * config_.max_block_num *
+ config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ top_block_idx * config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ 0 * config_.q_head_num *
+ config_.head_dim +
+ head_id * config_.head_dim +
+ l * 32 + m] =
+ GGML_FP32_TO_FP16(
+ block_fp32[m] / 4 +
+ GGML_FP16_TO_FP32(
+ anchor_[layer_id *
+ config_
+ .max_block_num *
+ config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ top_block_idx *
+ config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ 0 * config_.q_head_num *
+ config_.head_dim +
+ head_id *
+ config_.head_dim +
+ l * 32 + m]));
+ }
+ }
+ } else if (config_.kv_type ==
+ ggml_type::GGML_TYPE_Q8_0) {
+ for (int l = 0; l < config_.head_dim / 32; l++) {
+ block_q8_0 block = k_cache_q8
+ [layer_id][head_id / n_gqa_][top_block_idx]
+ [top_indice * config_.head_dim / 32 + l];
+ dequantize_row_q8_0(&block, block_fp32.data(),
+ 32);
+ for (int m = 0; m < 32; m++) {
+ anchor_[layer_id * config_.max_block_num *
+ config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ top_block_idx * config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ 0 * config_.q_head_num *
+ config_.head_dim +
+ head_id * config_.head_dim +
+ l * 32 + m] =
+ GGML_FP32_TO_FP16(
+ block_fp32[m] / 4 +
+ GGML_FP16_TO_FP32(
+ anchor_[layer_id *
+ config_
+ .max_block_num *
+ config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ top_block_idx *
+ config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ 0 * config_.q_head_num *
+ config_.head_dim +
+ head_id *
+ config_.head_dim +
+ l * 32 + m]));
+ }
+ }
+ }
+ top_importances.pop();
+ }
+ }
+ } else if (config_.anchor_type == AnchorType::BLOCK_MEAN) {
+ // clear anchor_
+ for (int anchor_id = 0; anchor_id < config_.anchor_num;
+ anchor_id++) {
+ for (int head_id = 0; head_id < config_.q_head_num;
+ head_id++) {
+ for (int l = 0; l < config_.head_dim; l++) {
+ anchor_[layer_id * config_.max_block_num *
+ config_.anchor_num *
+ config_.q_head_num * config_.head_dim +
+ block_idx * config_.anchor_num *
+ config_.q_head_num * config_.head_dim +
+ anchor_id * config_.q_head_num *
+ config_.head_dim +
+ head_id * config_.head_dim + l] = 0;
+ }
+ }
+ }
+
+ // fill anchor_
+ if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
+
+ for (int head_id = 0; head_id < config_.q_head_num;
+ head_id++) {
+ for (int k = 0; k < config_.block_len; k++) {
+ for (int l = 0; l < config_.head_dim; l++) {
+ anchor_[layer_id * config_.max_block_num *
+ config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ block_idx * config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ 0 * config_.q_head_num *
+ config_.head_dim +
+ head_id * config_.head_dim + l] =
+ GGML_FP32_TO_FP16(
+ GGML_FP16_TO_FP32(
+ anchor_[layer_id *
+ config_.max_block_num *
+ config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ block_idx *
+ config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ 0 * config_.q_head_num *
+ config_.head_dim +
+ head_id * config_.head_dim +
+ l]) +
+ GGML_FP16_TO_FP32(
+ k_cache_fp16_[layer_id]
+ [head_id / n_gqa_]
+ [block_idx]
+ [k * config_.head_dim +
+ l]) /
+ config_.block_len);
+ }
+ }
+ }
+ }
+ } else if (config_.anchor_type == AnchorType::BLOCK_MAX) {
+ // clear anchor_
+ for (int anchor_id = 0; anchor_id < config_.anchor_num;
+ anchor_id++) {
+ for (int head_id = 0; head_id < config_.q_head_num;
+ head_id++) {
+ for (int l = 0; l < config_.head_dim; l++) {
+ anchor_[layer_id * config_.max_block_num *
+ config_.anchor_num *
+ config_.q_head_num * config_.head_dim +
+ block_idx * config_.anchor_num *
+ config_.q_head_num * config_.head_dim +
+ anchor_id * config_.q_head_num *
+ config_.head_dim +
+ head_id * config_.head_dim + l] = 0;
+ }
+ }
+ }
+
+ // fill anchor_
+ if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
+
+ for (int head_id = 0; head_id < config_.q_head_num;
+ head_id++) {
+ for (int k = 0; k < config_.block_len; k++) {
+ for (int l = 0; l < config_.head_dim; l++) {
+ anchor_[layer_id * config_.max_block_num *
+ config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ block_idx * config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ 0 * config_.q_head_num *
+ config_.head_dim +
+ head_id * config_.head_dim + l] =
+ GGML_FP32_TO_FP16(std::max(
+ GGML_FP16_TO_FP32(
+ anchor_[layer_id *
+ config_.max_block_num *
+ config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ block_idx *
+ config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ 0 * config_.q_head_num *
+ config_.head_dim +
+ head_id * config_.head_dim +
+ l]),
+ GGML_FP16_TO_FP32(
+ k_cache_fp16_
+ [layer_id][head_id / n_gqa_]
+ [block_idx]
+ [k * config_.head_dim + l])));
+ }
+ }
+ }
+ }
+ } else if (config_.anchor_type == AnchorType::FIXED_ANCHOR) {
+ // clear anchor_
+ for (int anchor_id = 0; anchor_id < 1; anchor_id++) {
+ for (int head_id = 0; head_id < config_.q_head_num;
+ head_id++) {
+ for (int l = 0; l < config_.head_dim; l++) {
+ anchor_[layer_id * config_.max_block_num *
+ config_.anchor_num *
+ config_.q_head_num * config_.head_dim +
+ block_idx * config_.anchor_num *
+ config_.q_head_num * config_.head_dim +
+ anchor_id * config_.q_head_num *
+ config_.head_dim +
+ head_id * config_.head_dim + l] = 0;
+ }
+ }
+ }
+
+ // fill anchor_
+ if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
+
+ int stride = config_.block_len / config_.anchor_num;
+ for (int head_id = 0; head_id < config_.q_head_num;
+ head_id++) {
+ for (int k = 0, tot = 0;
+ k < config_.block_len && tot < config_.anchor_num;
+ k += stride, tot++) {
+ for (int l = 0; l < config_.head_dim; l++) {
+ anchor_[layer_id * config_.max_block_num *
+ config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ block_idx * config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ 0 * config_.q_head_num *
+ config_.head_dim +
+ head_id * config_.head_dim + l] =
+ GGML_FP32_TO_FP16(
+ GGML_FP16_TO_FP32(
+ anchor_[layer_id *
+ config_.max_block_num *
+ config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ block_idx *
+ config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ 0 * config_.q_head_num *
+ config_.head_dim +
+ head_id * config_.head_dim +
+ l]) +
+ GGML_FP16_TO_FP32(
+ k_cache_fp16_[layer_id]
+ [head_id / n_gqa_]
+ [block_idx]
+ [k * config_.head_dim +
+ l]) /
+ config_.anchor_num);
+ }
+ }
+ }
+ }
+
+ } else if (config_.anchor_type == AnchorType::QUEST) {
+ // clear anchor_
+ for (int head_id = 0; head_id < config_.q_head_num; head_id++) {
+ for (int l = 0; l < config_.head_dim; l++) {
+ anchor_[layer_id * config_.max_block_num *
+ config_.anchor_num * config_.q_head_num *
+ config_.head_dim +
+ block_idx * config_.anchor_num *
+ config_.q_head_num * config_.head_dim +
+ 1 * config_.q_head_num * config_.head_dim +
+ head_id * config_.head_dim + l] =
+ GGML_FP32_TO_FP16(
+ std::numeric_limits<float>::max());
+
+ anchor_[layer_id * config_.max_block_num *
+ config_.anchor_num * config_.q_head_num *
+ config_.head_dim +
+ block_idx * config_.anchor_num *
+ config_.q_head_num * config_.head_dim +
+ 0 * config_.q_head_num * config_.head_dim +
+ head_id * config_.head_dim + l] =
+ GGML_FP32_TO_FP16(
+ std::numeric_limits<float>::min());
+ }
+ }
+
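+ // QUEST-style metadata: anchor slot 0 accumulates the element-wise
+ // maximum of the block's keys and slot 1 the element-wise minimum,
+ // in the spirit of Quest-style per-channel score bounding.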
+ // fill anchor_
+
+ if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
+ for (int indice = 0; indice < seq_len_; indice++) {
+ for (int head_id = 0; head_id < config_.kv_head_num;
+ head_id++) {
+ for (int l = 0; l < config_.head_dim; l++) {
+ anchor_[layer_id * config_.max_block_num *
+ config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ block_idx * config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ 0 * config_.q_head_num *
+ config_.head_dim +
+ head_id * config_.head_dim + l] =
+ GGML_FP32_TO_FP16(std::max(
+ GGML_FP16_TO_FP32(
+ k_cache_fp16_
+ [layer_id][head_id][block_idx]
+ [indice * config_.head_dim +
+ l]),
+ GGML_FP16_TO_FP32(
+ anchor_[layer_id *
+ config_.max_block_num *
+ config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ block_idx *
+ config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ 0 * config_.q_head_num *
+ config_.head_dim +
+ head_id * config_.head_dim +
+ l])));
+
+ anchor_[layer_id * config_.max_block_num *
+ config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ block_idx * config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ 1 * config_.q_head_num *
+ config_.head_dim +
+ head_id * config_.head_dim + l] =
+ GGML_FP32_TO_FP16(std::min(
+ GGML_FP16_TO_FP32(
+ k_cache_fp16_
+ [layer_id][head_id][block_idx]
+ [indice * config_.head_dim +
+ l]),
+ GGML_FP16_TO_FP32(
+ anchor_[layer_id *
+ config_.max_block_num *
+ config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ block_idx *
+ config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ 1 * config_.q_head_num *
+ config_.head_dim +
+ head_id * config_.head_dim +
+ l])));
+ }
+ }
+ }
+
+ } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
+ for (int indice = 0; indice < seq_len_; indice++) {
+ for (int head_id = 0; head_id < config_.kv_head_num;
+ head_id++) {
+ for (int l = 0; l < config_.head_dim / 32; l++) {
+ block_q4_0 block =
+ k_cache_q4[layer_id][head_id][block_idx]
+ [indice * config_.head_dim / 32 +
+ l];
+ dequantize_row_q4_0(&block, block_fp32.data(),
+ 32);
+
+ for (int m = 0; m < 32; m++) {
+ for (int gqa_idx = 0; gqa_idx < n_gqa_;
+ gqa_idx++) {
+
+ anchor_[layer_id *
+ config_.max_block_num *
+ config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ block_idx * config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ 0 * config_.q_head_num *
+ config_.head_dim +
+ head_id * config_.head_dim +
+ l * 32 + m] =
+ GGML_FP32_TO_FP16(std::max(
+ block_fp32[m],
+ GGML_FP16_TO_FP32(
+ anchor_
+ [layer_id *
+ config_
+ .max_block_num *
+ config_
+ .anchor_num *
+ config_
+ .q_head_num *
+ config_.head_dim +
+ block_idx *
+ config_
+ .anchor_num *
+ config_
+ .q_head_num *
+ config_.head_dim +
+ 0 *
+ config_
+ .q_head_num *
+ config_.head_dim +
+ head_id *
+ config_.head_dim +
+ l * 32 + m])));
+
+ anchor_[layer_id *
+ config_.max_block_num *
+ config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ block_idx * config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ 1 * config_.q_head_num *
+ config_.head_dim +
+ head_id * config_.head_dim +
+ l * 32 + m] =
+ GGML_FP32_TO_FP16(std::min(
+ block_fp32[m],
+ GGML_FP16_TO_FP32(
+ anchor_
+ [layer_id *
+ config_
+ .max_block_num *
+ config_
+ .anchor_num *
+ config_
+ .q_head_num *
+ config_.head_dim +
+ block_idx *
+ config_
+ .anchor_num *
+ config_
+ .q_head_num *
+ config_.head_dim +
+ 1 *
+ config_
+ .q_head_num *
+ config_.head_dim +
+ head_id *
+ config_.head_dim +
+ l * 32 + m])));
+ }
+ }
+ }
+ }
+ }
+ } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
+ for (int indice = 0; indice < seq_len_; indice++) {
+ for (int head_id = 0; head_id < config_.kv_head_num;
+ head_id++) {
+ for (int l = 0; l < config_.head_dim / 32; l++) {
+ block_q8_0 block =
+ k_cache_q8[layer_id][head_id][block_idx]
+ [indice * config_.head_dim / 32 +
+ l];
+ dequantize_row_q8_0(&block, block_fp32.data(),
+ 32);
+
+ for (int m = 0; m < 32; m++) {
+ for (int gqa_idx = 0; gqa_idx < n_gqa_;
+ gqa_idx++) {
+
+ anchor_[layer_id *
+ config_.max_block_num *
+ config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ block_idx * config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ 0 * config_.q_head_num *
+ config_.head_dim +
+ head_id * config_.head_dim +
+ l * 32 + m] =
+ GGML_FP32_TO_FP16(std::max(
+ block_fp32[m],
+ GGML_FP16_TO_FP32(
+ anchor_
+ [layer_id *
+ config_
+ .max_block_num *
+ config_
+ .anchor_num *
+ config_
+ .q_head_num *
+ config_.head_dim +
+ block_idx *
+ config_
+ .anchor_num *
+ config_
+ .q_head_num *
+ config_.head_dim +
+ 0 *
+ config_
+ .q_head_num *
+ config_.head_dim +
+ head_id *
+ config_.head_dim +
+ l * 32 + m])));
+
+ anchor_[layer_id *
+ config_.max_block_num *
+ config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ block_idx * config_.anchor_num *
+ config_.q_head_num *
+ config_.head_dim +
+ 1 * config_.q_head_num *
+ config_.head_dim +
+ head_id * config_.head_dim +
+ l * 32 + m] =
+ GGML_FP32_TO_FP16(std::min(
+ block_fp32[m],
+ GGML_FP16_TO_FP32(
+ anchor_
+ [layer_id *
+ config_
+ .max_block_num *
+ config_
+ .anchor_num *
+ config_
+ .q_head_num *
+ config_.head_dim +
+ block_idx *
+ config_
+ .anchor_num *
+ config_
+ .q_head_num *
+ config_.head_dim +
+ 1 *
+ config_
+ .q_head_num *
+ config_.head_dim +
+ head_id *
+ config_.head_dim +
+ l * 32 + m])));
+ }
+ }
+ }
+ }
+ }
+ }
+ } else {
+ assert(false);
+ }
+ },
+ nullptr);
+
+ // Timer end
+ auto end = std::chrono::high_resolution_clock::now();
+ std::chrono::duration<double> duration = end - start;
+ // printf("time of calc_anchor_all_layers: %f s\n", duration.count());
+}
+
+void KVCache::clear_importance_all_layers(int *block_table, int *cache_seqlens,
+ int batch_size, int max_block_num,
+ Backend *backend) {
+ // Timer start
+ auto start = std::chrono::high_resolution_clock::now();
+
+ // Each task clears the importance of a certain block
+ seq_len_ = config_.block_len;
+ backend->do_work_stealing_job(
+ config_.layer_num * batch_size * max_block_num, nullptr,
+ [&](int task_id) {
+ int layer_id = task_id / (batch_size * max_block_num);
+ int batch_id = (task_id / max_block_num) % batch_size;
+ int block_id = task_id % max_block_num;
+ // If the block lies beyond the sequence length, skip it. In
+ // particular, the last block of a sequence that is shorter than the
+ // block length should be skipped.
+
+ if (cache_seqlens[batch_id] / config_.block_len < block_id) {
+ return;
+ }
+ int block_idx = block_table[batch_id * max_block_num + block_id];
+
+ if (config_.anchor_type == AnchorType::DYNAMIC) {
+
+ // clear anchor_
+ for (int head_id = 0; head_id < config_.q_head_num; head_id++) {
+ for (int l = 0; l < config_.block_len; l++) {
+ importance_[layer_id][block_idx][l][head_id] = 0;
+ }
+ }
+ }
+ },
+ nullptr);
+
+ // Timer end
+ auto end = std::chrono::high_resolution_clock::now();
+ std::chrono::duration<double> duration = end - start;
+ // printf("time of clear_importance_all_layers: %f s\n",
+ // duration.count());
+}
+
+void KVCache::clear_kvcache_all_layers(int *block_table, int *cache_seqlens,
+ int batch_size, int max_block_num,
+ Backend *backend) {
+ // Timer start
+ auto start = std::chrono::high_resolution_clock::now();
+
+ // Each task clears the KV cache of a certain block for one KV head
+ seq_len_ = config_.block_len;
+ backend->do_work_stealing_job(
+ config_.layer_num * batch_size * max_block_num * config_.kv_head_num,
+ nullptr,
+ [&](int task_id) {
+ int layer_id =
+ task_id / (batch_size * max_block_num * config_.kv_head_num);
+ int batch_id =
+ (task_id / (max_block_num * config_.kv_head_num)) % batch_size;
+ int block_id = task_id / config_.kv_head_num % max_block_num;
+ int head_id = task_id % config_.kv_head_num;
+ // If the block lies beyond the sequence length, skip it. In
+ // particular, the last block of a sequence that is shorter than the
+ // block length should be skipped.
+ if (cache_seqlens[batch_id] / config_.block_len < block_id) {
+ return;
+ }
+ int block_idx = block_table[batch_id * max_block_num + block_id];
+
+ if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
+ for (int l = 0; l < config_.block_len * config_.head_dim; l++) {
+ k_cache_fp16_[layer_id][head_id][block_idx][l] = 0;
+ v_cache_fp16_[layer_id][head_id][block_idx][l] = 0;
+ }
+ } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
+ for (int l = 0; l < config_.block_len * config_.head_dim / 32;
+ l++) {
+ k_cache_q4[layer_id][head_id][block_idx][l].d = 0;
+ v_cache_q4[layer_id][head_id][block_idx][l].d = 0;
+ }
+ } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
+ for (int l = 0; l < config_.block_len * config_.head_dim / 32;
+ l++) {
+ k_cache_q8[layer_id][head_id][block_idx][l].d = 0;
+ v_cache_q8[layer_id][head_id][block_idx][l].d = 0;
+ }
+ }
+ },
+ nullptr);
+
+ // Timer end
+ auto end = std::chrono::high_resolution_clock::now();
+ std::chrono::duration<double> duration = end - start;
+ // printf("time of clear_kvcache_all_layers: %f s\n", duration.count());
+}
+
+void KVCache::get_sincos(ggml_fp16_t *sin, ggml_fp16_t *cos, int seqlen) {
+ // Timer start
+ auto start = std::chrono::high_resolution_clock::now();
+
+ const uint16_t *sin_data = const_cast<const uint16_t *>(sin);
+ const uint16_t *cos_data = const_cast<const uint16_t *>(cos);
+
+ for (int i = 0; i < seqlen; i++) {
+ for (int j = 0; j < config_.head_dim; j++) {
+ sin_[i][j] = sin_data[i * config_.head_dim + j];
+ cos_[i][j] = cos_data[i * config_.head_dim + j];
+ }
+ }
+
+ // Timer end
+ auto end = std::chrono::high_resolution_clock::now();
+ std::chrono::duration<double> duration = end - start;
+ printf("time of get_sincos: %f s\n", duration.count());
+}
+
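+// In-place scale: y[i] *= v for i in [0, n). Dispatches to Apple
+// Accelerate (vDSP), the GGML SIMD macros, or a plain scalar loop
+// depending on build flags; apparently adapted from the helper of the
+// same name inside ggml.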
+void ggml_vec_scale_f32(const int n, float *y, const float v) {
+#if defined(GGML_USE_ACCELERATE)
+ vDSP_vsmul(y, 1, &v, y, 1, n);
+#elif defined(GGML_SIMD)
+ const int np = (n & ~(GGML_F32_STEP - 1));
+
+ GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
+
+ GGML_F32_VEC ay[GGML_F32_ARR];
+
+ for (int i = 0; i < np; i += GGML_F32_STEP) {
+ for (int j = 0; j < GGML_F32_ARR; j++) {
+ ay[j] = GGML_F32_VEC_LOAD(y + i + j * GGML_F32_EPR);
+ ay[j] = GGML_F32_VEC_MUL(ay[j], vx);
+
+ GGML_F32_VEC_STORE(y + i + j * GGML_F32_EPR, ay[j]);
+ }
+ }
+
+ // leftovers
+ for (int i = np; i < n; ++i) {
+ y[i] *= v;
+ }
+#else
+ // scalar
+ for (int i = 0; i < n; ++i) {
+ y[i] *= v;
+ }
+#endif
+}
\ No newline at end of file
diff --git a/ktransformers/ktransformers_ext/operators/llamafile/linear.cpp b/ktransformers/ktransformers_ext/operators/llamafile/linear.cpp
index 81e5006..d1e7967 100644
--- a/ktransformers/ktransformers_ext/operators/llamafile/linear.cpp
+++ b/ktransformers/ktransformers_ext/operators/llamafile/linear.cpp
@@ -3,8 +3,8 @@
* @Author : chenht2022
* @Date : 2024-07-12 10:07:58
* @Version : 1.0.0
- * @LastEditors : chenht2022
- * @LastEditTime : 2024-07-25 10:34:58
+ * @LastEditors : kkk1nak0
+ * @LastEditTime : 2024-08-15 07:45:18
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#include "linear.h"
@@ -24,10 +24,14 @@ Linear::~Linear() {
shared_mem_buffer.dealloc(this);
}
-void Linear::warm_up(Backend* backend) {
+void Linear::warm_up(Backend *backend) {
std::vector<float> input_fp32(config_.input_size);
- std::vector<uint8_t> input(config_.input_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type));
- std::vector<uint8_t> output(config_.output_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type));
+ std::vector<uint8_t> input(config_.input_size *
+ ggml_type_size(config_.hidden_type) /
+ ggml_blck_size(config_.hidden_type));
+ std::vector<uint8_t> output(config_.output_size *
+ ggml_type_size(config_.hidden_type) /
+ ggml_blck_size(config_.hidden_type));
for (int i = 0; i < config_.input_size; i++) {
input_fp32[i] = 0;
}
@@ -45,7 +49,7 @@ void Linear::forward_many(int qlen, const void* input, void* output, Backend* ba
proj_input_ptr = proj_input_;
}
int nth = config_.output_size / config_.stride;
- backend->do_work_stealing_job(nth, [&](int task_id) {
+ backend->do_work_stealing_job(nth, nullptr, [&](int task_id) {
int ith = task_id;
void* proj_ptr = (uint8_t*)proj_ + ith * config_.stride * config_.input_size * ggml_type_size(config_.proj_type) / ggml_blck_size(config_.proj_type);
float* proj_output_ptr = proj_output_ + ith * config_.stride;
@@ -57,7 +61,7 @@ void Linear::forward_many(int qlen, const void* input, void* output, Backend* ba
from_float(output_fp32_ptr, output_ptr, config_.stride, config_.hidden_type);
}
}
- });
+ }, nullptr);
if (config_.stride % ggml_blck_size(config_.hidden_type) != 0) {
from_float(proj_output_, output, qlen * config_.output_size, config_.hidden_type);
}
diff --git a/ktransformers/ktransformers_ext/operators/llamafile/mlp.cpp b/ktransformers/ktransformers_ext/operators/llamafile/mlp.cpp
index abad01e..602fdcb 100644
--- a/ktransformers/ktransformers_ext/operators/llamafile/mlp.cpp
+++ b/ktransformers/ktransformers_ext/operators/llamafile/mlp.cpp
@@ -3,8 +3,8 @@
* @Author : chenht2022
* @Date : 2024-07-16 10:43:18
* @Version : 1.0.0
- * @LastEditors : chenht2022
- * @LastEditTime : 2024-07-25 10:35:04
+ * @LastEditors : kkk1nak0
+ * @LastEditTime : 2024-08-15 07:44:38
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#include "mlp.h"
@@ -31,10 +31,14 @@ MLP::~MLP() {
shared_mem_buffer.dealloc(this);
}
-void MLP::warm_up(Backend* backend) {
+void MLP::warm_up(Backend *backend) {
std::vector<float> input_fp32(config_.hidden_size);
- std::vector<uint8_t> input(config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type));
- std::vector<uint8_t> output(config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type));
+ std::vector<uint8_t> input(config_.hidden_size *
+ ggml_type_size(config_.hidden_type) /
+ ggml_blck_size(config_.hidden_type));
+ std::vector<uint8_t> output(config_.hidden_size *
+ ggml_type_size(config_.hidden_type) /
+ ggml_blck_size(config_.hidden_type));
for (int i = 0; i < config_.hidden_size; i++) {
input_fp32[i] = 0;
}
@@ -42,9 +46,7 @@ void MLP::warm_up(Backend* backend) {
forward_many(1, input.data(), output.data(), backend);
}
-static float act_fn(float x) {
- return x / (1.0f + expf(-x));
-}
+static float act_fn(float x) { return x / (1.0f + expf(-x)); }
void MLP::forward_many(int qlen, const void* input, void* output, Backend* backend) {
const void* gate_input_ptr;
@@ -72,7 +74,7 @@ void MLP::forward_many(int qlen, const void* input, void* output, Backend* backe
}
}
int nth = config_.intermediate_size / config_.stride;
- backend->do_work_stealing_job(nth, [&](int task_id) {
+ backend->do_work_stealing_job(nth, nullptr, [&](int task_id) {
int ith = task_id;
void* gate_proj_ptr = (uint8_t*)gate_proj_ + ith * config_.stride * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);
float* gate_output_ptr = gate_output_ + ith * config_.stride;
@@ -90,12 +92,12 @@ void MLP::forward_many(int qlen, const void* input, void* output, Backend* backe
from_float(intermediate_fp32_ptr, down_input_ptr, config_.stride, ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
}
}
- });
+ }, nullptr);
if (config_.stride % ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) != 0) {
from_float(intermediate_fp32_, down_input_, qlen * config_.intermediate_size, ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
}
nth = config_.hidden_size / config_.stride;
- backend->do_work_stealing_job(nth, [&](int task_id) {
+ backend->do_work_stealing_job(nth, nullptr, [&](int task_id) {
int ith = task_id;
void* down_proj_ptr = (uint8_t*)down_proj_ + ith * config_.stride * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);
float* down_output_ptr = down_output_ + ith * config_.stride;
@@ -107,7 +109,7 @@ void MLP::forward_many(int qlen, const void* input, void* output, Backend* backe
from_float(output_fp32_ptr, output_ptr, config_.stride, config_.hidden_type);
}
}
- });
+ }, nullptr);
if (config_.stride % ggml_blck_size(config_.hidden_type) != 0) {
from_float(down_output_, output, qlen * config_.hidden_size, config_.hidden_type);
}
diff --git a/ktransformers/ktransformers_ext/operators/llamafile/moe.cpp b/ktransformers/ktransformers_ext/operators/llamafile/moe.cpp
index d75db65..0fcf9df 100644
--- a/ktransformers/ktransformers_ext/operators/llamafile/moe.cpp
+++ b/ktransformers/ktransformers_ext/operators/llamafile/moe.cpp
@@ -3,8 +3,8 @@
* @Author : chenht2022
* @Date : 2024-07-22 02:03:22
* @Version : 1.0.0
- * @LastEditors : chenht2022
- * @LastEditTime : 2024-07-25 10:35:07
+ * @LastEditors : kkk1nak0
+ * @LastEditTime : 2024-08-15 07:43:41
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#include "moe.h"
@@ -121,7 +121,7 @@ void MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, c
}
}
int nth = config_.intermediate_size / config_.stride;
- backend->do_work_stealing_job(nth * k, [&](int task_id) {
+ backend->do_work_stealing_job(nth * k, nullptr, [&](int task_id) {
int expert_idx = task_id / nth;
uint64_t expert_id = expert_ids[expert_idx];
int ith = task_id % nth;
@@ -139,14 +139,14 @@ void MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, c
void* down_input_ptr = s_down_input_[expert_idx] + ith * config_.stride * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
from_float(intermediate_fp32_ptr, down_input_ptr, config_.stride, ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
}
- });
+ }, nullptr);
if (config_.stride % ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) != 0) {
for (int i = 0; i < k; i++) {
from_float(s_intermediate_fp32_[i], s_down_input_[i], config_.intermediate_size, ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
}
}
nth = config_.hidden_size / config_.stride;
- backend->do_work_stealing_job(nth, [&](int task_id) {
+ backend->do_work_stealing_job(nth, nullptr, [&](int task_id) {
int ith = task_id;
for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {
s_output_fp32_[i] = 0;
@@ -165,7 +165,7 @@ void MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, c
void* output_ptr = (uint8_t*)output + ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);
from_float(output_fp32_ptr, output_ptr, config_.stride, config_.hidden_type);
}
- });
+ }, nullptr);
if (config_.stride % ggml_blck_size(config_.hidden_type) != 0) {
from_float(s_output_fp32_, output, config_.hidden_size, config_.hidden_type);
}
@@ -191,7 +191,7 @@ void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float*
m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size;
offset += m_local_num_[i];
}
- backend->do_work_stealing_job(qlen, [&](int i) {
+ backend->do_work_stealing_job(qlen, nullptr, [&](int i) {
const void* gate_input_ptr;
const void* up_input_ptr;
if (config_.hidden_type == ggml_internal_get_type_traits(config_.gate_type).vec_dot_type && config_.hidden_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {
@@ -220,10 +220,10 @@ void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float*
memcpy(m_local_gate_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type), gate_input_ptr, config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type));
memcpy(m_local_up_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type), up_input_ptr, config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type));
}
- });
+ }, nullptr);
int stride = QK_K;
int nth = config_.intermediate_size / stride;
- backend->do_work_stealing_job(nth * config_.expert_num, [&](int task_id) {
+ backend->do_work_stealing_job(nth * config_.expert_num, nullptr, [&](int task_id) {
int expert_idx = task_id / nth;
int ith = task_id % nth;
void* gate_input_ptr = m_local_gate_input_ptr_[expert_idx];
@@ -242,18 +242,18 @@ void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float*
void* down_input_ptr = m_local_down_input_ptr_[expert_idx] + i * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) + ith * stride * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
from_float(intermediate_fp32_ptr, down_input_ptr, stride, ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
}
- });
+ }, nullptr);
stride = QK_K;
nth = config_.hidden_size / stride;
- backend->do_work_stealing_job(nth * config_.expert_num, [&](int task_id) {
+ backend->do_work_stealing_job(nth * config_.expert_num, nullptr, [&](int task_id) {
int expert_idx = task_id / nth;
int ith = task_id % nth;
void* down_input_ptr = m_local_down_input_ptr_[expert_idx];
void* down_proj_ptr = (uint8_t*)down_proj_ + (expert_idx * config_.hidden_size + ith * stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);
float* down_output_ptr = m_local_down_output_ptr_[expert_idx] + ith * stride;
llamafile_sgemm(stride, m_local_num_[expert_idx], config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_input_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.hidden_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
- });
- backend->do_work_stealing_job(qlen, [&](int i) {
+ }, nullptr);
+ backend->do_work_stealing_job(qlen, nullptr, [&](int i) {
for (int e = 0; e < config_.hidden_size; e++) {
m_output_fp32_[i][e] = 0;
}
@@ -263,7 +263,7 @@ void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float*
}
}
from_float(m_output_fp32_[i], (uint8_t*)output + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), config_.hidden_size, config_.hidden_type);
- });
+ }, nullptr);
}
void MOE::forward(int qlen, int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend) {
diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py
old mode 100755
new mode 100644
index b5782d1..3fcbf6f
--- a/ktransformers/local_chat.py
+++ b/ktransformers/local_chat.py
@@ -1,20 +1,14 @@
-# Copyright 2024 Shaoyuan Chen
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+"""
+Description :
+Author : Boxin Zhang, Azure-Tang
+Version : 0.1.0
+Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
+"""
import os
import platform
import sys
+
project_dir = os.path.dirname(os.path.dirname(__file__))
sys.path.insert(0, project_dir)
import torch
@@ -31,6 +25,7 @@
from ktransformers.optimize.optimize import optimize_and_load_gguf
from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
+from ktransformers.models.modeling_llama import LlamaForCausalLM
from ktransformers.models.modeling_mixtral import MixtralForCausalLM
from ktransformers.util.utils import prefill_and_generate
from ktransformers.server.config.config import Config
@@ -38,38 +33,56 @@
custom_models = {
"DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
"Qwen2MoeForCausalLM": Qwen2MoeForCausalLM,
+ "LlamaForCausalLM": LlamaForCausalLM,
"MixtralForCausalLM": MixtralForCausalLM,
}
-ktransformer_rules_dir = os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
-default_optimize_rules ={
+ktransformer_rules_dir = (
+ os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
+)
+default_optimize_rules = {
"DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml",
"Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml",
+ "LlamaForCausalLM": ktransformer_rules_dir + "Internlm2_5-7b-Chat-1m.yaml",
"MixtralForCausalLM": ktransformer_rules_dir + "Mixtral.yaml",
}
+
def local_chat(
- model_path: str,
+ model_path: str | None = None,
optimize_rule_path: str = None,
- gguf_path: str = None,
+ gguf_path: str | None = None,
max_new_tokens: int = 1000,
cpu_infer: int = Config().cpu_infer,
use_cuda_graph: bool = True,
+    prompt_file: str | None = None,
+ mode: str = "normal",
):
+
+
torch.set_grad_enabled(False)
-
+
Config().cpu_infer = cpu_infer
- tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
- torch.set_default_dtype(config.torch_dtype)
+ if mode == 'long_context':
+ torch.set_default_dtype(torch.float16)
+ else:
+ torch.set_default_dtype(config.torch_dtype)
with torch.device("meta"):
if config.architectures[0] in custom_models:
print("using custom modeling_xxx.py.")
- if "Qwen2Moe" in config.architectures[0]: # Qwen2Moe must use flash_attention_2 to avoid overflow.
+ if (
+ "Qwen2Moe" in config.architectures[0]
+ ): # Qwen2Moe must use flash_attention_2 to avoid overflow.
config._attn_implementation = "flash_attention_2"
- if "Mixtral" in config.architectures[0]:
+ if "Llama" in config.architectures[0]:
+ config._attn_implementation = "eager"
+ if "Mixtral" in config.architectures[0]:
config._attn_implementation = "flash_attention_2"
+
model = custom_models[config.architectures[0]](config)
else:
model = AutoModelForCausalLM.from_config(
@@ -95,26 +108,50 @@ def local_chat(
if model.generation_config.pad_token_id is None:
model.generation_config.pad_token_id = model.generation_config.eos_token_id
model.eval()
-
logging.basicConfig(level=logging.INFO)
system = platform.system()
- if (system == u'Windows'):
- os.system('cls')
+ if system == "Windows":
+ os.system("cls")
else:
- os.system('clear')
+ os.system("clear")
while True:
content = input("Chat: ")
- if content == "":
- content = "Please write a piece of quicksort code in C++."
+ if content.startswith('"""'): # prefix """
+ # multi lines input
+ content = content[3:] + "\n"
+ while True:
+ line = input("")
+ if line.endswith('"""'):
+ # end multi lines input
+ line = line[:-3] # suffix """
+ if line:
+ content += line + "\n"
+ break
+ else:
+ content += line + "\n"
+ if content == "":
+            if prompt_file is not None:
+ content = open(prompt_file, "r").read()
+ else:
+ content = "Please write a piece of quicksort code in C++."
+ elif os.path.isfile(content):
+ content = open(content, "r").read()
messages = [{"role": "user", "content": content}]
input_tensor = tokenizer.apply_chat_template(
messages, add_generation_prompt=True, return_tensors="pt"
)
- torch.set_default_dtype(torch.bfloat16) # TODO: Remove this, replace dtype using config
- generated = prefill_and_generate(model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph)
+ assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
+ "please change max_seq_len in ~/.ktransformers/config.yaml"
+ torch.set_default_dtype(
+ torch.bfloat16
+ ) # TODO: Remove this, replace dtype using config
+ generated = prefill_and_generate(
+ model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode
+ )
+
if __name__ == "__main__":
- fire.Fire(local_chat)
\ No newline at end of file
+ fire.Fire(local_chat)
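+
+# Illustrative invocation (the paths below are placeholders, not shipped defaults).
+# `fire.Fire` exposes the function arguments, including the new `mode` and
+# `prompt_file` options, as command-line flags:
+#
+#   python ktransformers/local_chat.py \
+#       --model_path <hf_model_dir> --gguf_path <gguf_dir> \
+#       --mode long_context --prompt_file ./prompt.txt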
diff --git a/ktransformers/models/configuration_llama.py b/ktransformers/models/configuration_llama.py
new file mode 100644
index 0000000..2b4f4db
--- /dev/null
+++ b/ktransformers/models/configuration_llama.py
@@ -0,0 +1,203 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""LLaMA model configuration"""
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_rope_utils import rope_config_validation
+
+
+class LlamaConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the LLaMA-7B.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 32000):
+ Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`LlamaModel`]
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 11008):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details checkout [this
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
+ `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
+ The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens,
+ Llama 2 up to 4096, CodeLlama up to 16384.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*):
+ Padding token id.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ End of stream token id.
+ pretraining_tp (`int`, *optional*, defaults to 1):
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
+ document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
+ understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
+ results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie weight embeddings
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+ accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`List[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`List[float]`, *optional*):
+                Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ mlp_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
+
+ ```python
+ >>> from transformers import LlamaModel, LlamaConfig
+
+ >>> # Initializing a LLaMA llama-7b style configuration
+ >>> configuration = LlamaConfig()
+
+ >>> # Initializing a model from the llama-7b style configuration
+ >>> model = LlamaModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "llama"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=32000,
+ hidden_size=4096,
+ intermediate_size=11008,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=None,
+ hidden_act="silu",
+ max_position_embeddings=2048,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ pad_token_id=None,
+ bos_token_id=1,
+ eos_token_id=2,
+ pretraining_tp=1,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ mlp_bias=False,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.pretraining_tp = pretraining_tp
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.mlp_bias = mlp_bias
+
+ # Validate the correctness of rotary position embeddings parameters
+ # BC: if there is a 'type' field, move it to 'rope_type'.
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+ rope_config_validation(self)
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
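+
+
+# Minimal usage sketch of the `rope_scaling` dict documented above (keys follow
+# this class's docstring; the values here are illustrative, not recommendations):
+#
+#   >>> config = LlamaConfig(
+#   ...     max_position_embeddings=4096,
+#   ...     rope_scaling={"rope_type": "linear", "factor": 2.0},
+#   ... )
+#   >>> config.rope_scaling["rope_type"]
+#   'linear'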
diff --git a/ktransformers/models/modeling_llama.py b/ktransformers/models/modeling_llama.py
new file mode 100644
index 0000000..5271ed5
--- /dev/null
+++ b/ktransformers/models/modeling_llama.py
@@ -0,0 +1,1744 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache, StaticCache
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+from transformers.modeling_flash_attention_utils import _flash_attention_forward
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutputWithPast,
+ TokenClassifierOutput,
+)
+from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
+from transformers.modeling_utils import PreTrainedModel
+from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
+from transformers.utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_flash_attn_greater_or_equal_2_10,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_llama import LlamaConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "LlamaConfig"
+
+
+class LlamaRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ LlamaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+
+ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
+
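+# Behaviour sketch for LlamaRMSNorm (illustrative only): the input is scaled by the
+# reciprocal root-mean-square of its last dimension, computed in float32, i.e.
+# y = weight * x / sqrt(mean(x**2, dim=-1) + eps), then cast back to the input dtype.
+#
+#   >>> norm = LlamaRMSNorm(hidden_size=8, eps=1e-6)
+#   >>> norm(torch.randn(2, 4, 8)).shape
+#   torch.Size([2, 4, 8])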
+
+class LlamaRotaryEmbedding(nn.Module):
+ def __init__(
+ self,
+ dim=None,
+ max_position_embeddings=2048,
+ base=10000,
+ device=None,
+ scaling_factor=1.0,
+ rope_type="default",
+ config: Optional[LlamaConfig] = None,
+ ):
+ super().__init__()
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+ self.device = device
+ self.scaling_factor = scaling_factor
+ self.rope_type = rope_type
+ self.config = config
+ # TODO (joao): remove the `if` below, only used for BC
+ self.rope_kwargs = {}
+ if config is None:
+ logger.warning_once(
+ "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the "
+ "`config` argument. All other arguments will be removed in v4.45"
+ )
+ self.rope_kwargs = {
+ "rope_type": rope_type,
+ "factor": scaling_factor,
+ "dim": dim,
+ "base": base,
+ "max_position_embeddings": max_position_embeddings,
+ }
+ self.rope_type = rope_type
+ self.max_seq_len_cached = max_position_embeddings
+ self.original_max_seq_len = max_position_embeddings
+ else:
+ # BC: "rope_type" was originally "type"
+ if config.rope_scaling is not None:
+ self.rope_type = config.rope_scaling.get(
+ "rope_type", config.rope_scaling.get("type")
+ )
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(
+ self.config, device, **self.rope_kwargs
+ )
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ def _dynamic_frequency_update(self, position_ids, device):
+ """
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
+ 1 - growing beyond the cached sequence length (allow scaling)
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
+ """
+ seq_len = torch.max(position_ids) + 1
+ # seq_len = position_ids[0, -1] + 1
+ if seq_len > self.max_seq_len_cached: # growth
+ inv_freq, self.attention_scaling = self.rope_init_fn(
+ self.config, device, seq_len=seq_len, **self.rope_kwargs
+ )
+ self.register_buffer(
+ "inv_freq", inv_freq, persistent=False
+ ) # TODO joao: may break with compilation
+ self.max_seq_len_cached = seq_len
+
+ if (
+ seq_len < self.original_max_seq_len
+ and self.max_seq_len_cached > self.original_max_seq_len
+ ): # reset
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+ self.max_seq_len_cached = self.original_max_seq_len
+
+ @torch.no_grad()
+ def forward(self, x, position_ids):
+ # if "dynamic" in self.rope_type:
+ # self._dynamic_frequency_update(position_ids, device=x.device)
+
+ # Core RoPE block
+ inv_freq_expanded = (
+ self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ )
+ position_ids_expanded = position_ids[:, None, :].float()
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
+ device_type = x.device.type
+ device_type = (
+ device_type
+ if isinstance(device_type, str) and device_type != "mps"
+ else "cpu"
+ )
+ with torch.autocast(device_type=device_type, enabled=False):
+ freqs = (
+ inv_freq_expanded.float() @ position_ids_expanded.float()
+ ).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos()
+ sin = emb.sin()
+
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
+ cos = cos * self.attention_scaling
+ sin = sin * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
+ """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
+
+ def __init__(self, *args, **kwargs):
+ logger.warning_once(
+ "`LlamaLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use "
+ "`LlamaRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
+ )
+ kwargs["rope_type"] = "linear"
+ super().__init__(*args, **kwargs)
+
+
+class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
+ """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+
+ def __init__(self, *args, **kwargs):
+ logger.warning_once(
+ "`LlamaDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use "
+ "`LlamaRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
+ "__init__)."
+ )
+ kwargs["rope_type"] = "dynamic"
+ super().__init__(*args, **kwargs)
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
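+# Shape sketch for apply_rotary_pos_emb (illustrative only; the small sizes are
+# arbitrary): with q/k of shape [batch, heads, seq, head_dim] and cos/sin of shape
+# [batch, seq, head_dim] from LlamaRotaryEmbedding.forward, the default
+# unsqueeze_dim=1 broadcasts cos/sin across the head dimension.
+#
+#   >>> cfg = LlamaConfig(hidden_size=64, num_attention_heads=4)
+#   >>> rope = LlamaRotaryEmbedding(config=cfg)
+#   >>> x = torch.randn(1, 4, 5, 16)              # [batch, heads, seq, head_dim]
+#   >>> cos, sin = rope(x, torch.arange(5)[None, :])
+#   >>> q, k = apply_rotary_pos_emb(x, x, cos, sin)
+#   >>> q.shape
+#   torch.Size([1, 4, 5, 16])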
+
+class LlamaMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(
+ self.hidden_size, self.intermediate_size, bias=config.mlp_bias
+ )
+ self.up_proj = nn.Linear(
+ self.hidden_size, self.intermediate_size, bias=config.mlp_bias
+ )
+ self.down_proj = nn.Linear(
+ self.intermediate_size, self.hidden_size, bias=config.mlp_bias
+ )
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ if self.config.pretraining_tp > 1:
+ slice = self.intermediate_size // self.config.pretraining_tp
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
+
+ gate_proj = torch.cat(
+ [
+ F.linear(x, gate_proj_slices[i])
+ for i in range(self.config.pretraining_tp)
+ ],
+ dim=-1,
+ )
+ up_proj = torch.cat(
+ [
+ F.linear(x, up_proj_slices[i])
+ for i in range(self.config.pretraining_tp)
+ ],
+ dim=-1,
+ )
+
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
+ down_proj = [
+ F.linear(intermediate_states[i], down_proj_slices[i])
+ for i in range(self.config.pretraining_tp)
+ ]
+ down_proj = sum(down_proj)
+ else:
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+ return down_proj
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(
+ batch, num_key_value_heads, n_rep, slen, head_dim
+ )
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
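+# Shape sketch for repeat_kv (illustrative only): with 2 key/value heads and
+# n_rep = num_attention_heads // num_key_value_heads = 4, every query head gets a
+# matching key/value head.
+#
+#   >>> kv = torch.randn(1, 2, 5, 16)             # [batch, kv_heads, seq, head_dim]
+#   >>> repeat_kv(kv, 4).shape
+#   torch.Size([1, 8, 5, 16])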
+
+class LlamaAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.attention_dropout = config.attention_dropout
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
+ self.is_causal = True
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+
+ self.q_proj = nn.Linear(
+ self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.k_proj = nn.Linear(
+ self.hidden_size,
+ self.num_key_value_heads * self.head_dim,
+ bias=config.attention_bias,
+ )
+ self.v_proj = nn.Linear(
+ self.hidden_size,
+ self.num_key_value_heads * self.head_dim,
+ bias=config.attention_bias,
+ )
+ self.o_proj = nn.Linear(
+ self.hidden_size, self.hidden_size, bias=config.attention_bias
+ )
+
+ # TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers)
+ self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[
+ Tuple[torch.Tensor, torch.Tensor]
+ ] = None, # will become mandatory in v4.45
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ if self.config.pretraining_tp > 1:
+ key_value_slicing = (
+ self.num_key_value_heads * self.head_dim
+ ) // self.config.pretraining_tp
+ query_slices = self.q_proj.weight.split(
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
+ )
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+
+ query_states = [
+ F.linear(hidden_states, query_slices[i])
+ for i in range(self.config.pretraining_tp)
+ ]
+ query_states = torch.cat(query_states, dim=-1)
+
+ key_states = [
+ F.linear(hidden_states, key_slices[i])
+ for i in range(self.config.pretraining_tp)
+ ]
+ key_states = torch.cat(key_states, dim=-1)
+
+ value_states = [
+ F.linear(hidden_states, value_slices[i])
+ for i in range(self.config.pretraining_tp)
+ ]
+ value_states = torch.cat(value_states, dim=-1)
+
+ else:
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(
+ bsz, q_len, self.num_heads, self.head_dim
+ ).transpose(1, 2)
+ key_states = key_states.view(
+ bsz, q_len, self.num_key_value_heads, self.head_dim
+ ).transpose(1, 2)
+ value_states = value_states.view(
+ bsz, q_len, self.num_key_value_heads, self.head_dim
+ ).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(
+ query_states, key_states, cos, sin
+ )
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(
+ key_states, value_states, self.layer_idx, cache_kwargs
+ )
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(
+ query_states, key_states.transpose(2, 3)
+ ) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(
+ attn_weights, dim=-1, dtype=torch.float32
+ ).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(
+ attn_weights, p=self.attention_dropout, training=self.training
+ )
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ attn_output = attn_output.reshape(bsz, q_len, -1)
+
+ if self.config.pretraining_tp > 1:
+ attn_output = attn_output.split(
+ self.hidden_size // self.config.pretraining_tp, dim=2
+ )
+ o_proj_slices = self.o_proj.weight.split(
+ self.hidden_size // self.config.pretraining_tp, dim=1
+ )
+ attn_output = sum(
+ [
+ F.linear(attn_output[i], o_proj_slices[i])
+ for i in range(self.config.pretraining_tp)
+ ]
+ )
+ else:
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class LlamaFlashAttention2(LlamaAttention):
+ """
+    Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stay
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+    # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[
+ Tuple[torch.Tensor, torch.Tensor]
+ ] = None, # will become mandatory in v4.45
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if isinstance(past_key_value, StaticCache):
+ raise ValueError(
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
+ )
+
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x head_dim x hidden_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(
+ bsz, q_len, self.num_heads, self.head_dim
+ ).transpose(1, 2)
+ key_states = key_states.view(
+ bsz, q_len, self.num_key_value_heads, self.head_dim
+ ).transpose(1, 2)
+ value_states = value_states.view(
+ bsz, q_len, self.num_key_value_heads, self.head_dim
+ ).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(
+ query_states, key_states, cos, sin
+ )
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(
+ key_states, value_states, self.layer_idx, cache_kwargs
+ )
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ dropout_rate = self.attention_dropout if self.training else 0.0
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (LlamaRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ dropout=dropout_rate,
+ sliding_window=getattr(self, "sliding_window", None),
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class LlamaSdpaAttention(LlamaAttention):
+ """
+ Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+    `LlamaAttention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
+ SDPA API.
+ """
+
+ # Adapted from LlamaAttention.forward
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[
+ Tuple[torch.Tensor, torch.Tensor]
+ ] = None, # will become mandatory in v4.45
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(
+ bsz, q_len, self.num_heads, self.head_dim
+ ).transpose(1, 2)
+ key_states = key_states.view(
+ bsz, q_len, self.num_key_value_heads, self.head_dim
+ ).transpose(1, 2)
+ value_states = value_states.view(
+ bsz, q_len, self.num_key_value_heads, self.head_dim
+ ).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(
+ query_states, key_states, cos, sin
+ )
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(
+ key_states, value_states, self.layer_idx, cache_kwargs
+ )
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ causal_mask = attention_mask
+ if attention_mask is not None:
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and causal_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ is_causal = True if causal_mask is None and q_len > 1 else False
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ is_causal=is_causal,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, -1)
+
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, None, past_key_value
+
+
+LLAMA_ATTENTION_CLASSES = {
+ "eager": LlamaAttention,
+ "flash_attention_2": LlamaFlashAttention2,
+ "sdpa": LlamaSdpaAttention,
+}
+
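+# The decoder layer below selects its attention implementation from this table via
+# `config._attn_implementation`; local_chat.py in this patch forces "eager" for
+# Llama checkpoints. A quick sketch of the dispatch (illustrative only):
+#
+#   >>> LLAMA_ATTENTION_CLASSES["eager"] is LlamaAttention
+#   True
+#   >>> LLAMA_ATTENTION_CLASSES["flash_attention_2"].__name__
+#   'LlamaFlashAttention2'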
+
+class LlamaDecoderLayer(nn.Module):
+ def __init__(self, config: LlamaConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](
+ config=config, layer_idx=layer_idx
+ )
+
+ self.mlp = LlamaMLP(config)
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = LlamaRMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[
+ Tuple[torch.Tensor, torch.Tensor]
+ ] = None, # will become mandatory in v4.45
+ **kwargs,
+ ) -> Tuple[
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
+ ]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*):
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+ query_sequence_length, key_sequence_length)` if default attention is used.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
+ into the model
+ """
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+LLAMA_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`LlamaConfig`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
+ LLAMA_START_DOCSTRING,
+)
+class LlamaPreTrainedModel(PreTrainedModel):
+ config_class = LlamaConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["LlamaDecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ _supports_cache_class = True
+ _supports_quantized_cache = True
+ _supports_static_cache = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+LLAMA_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
+ `past_key_values`).
+
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+ information on the default strategy.
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.n_positions - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+ Two formats are allowed:
+ - a [`~cache_utils.Cache`] instance;
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
+ cache format.
+
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
+ legacy cache format will be returned.
+
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+ of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+ the complete sequence length.
+"""
+
+
+@add_start_docstrings(
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
+ LLAMA_START_DOCSTRING,
+)
+class LlamaModel(LlamaPreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
+
+ Args:
+ config: LlamaConfig
+ """
+
+ def __init__(self, config: LlamaConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(
+ config.vocab_size, config.hidden_size, self.padding_idx
+ )
+ self.layers = nn.ModuleList(
+ [
+ LlamaDecoderLayer(config, layer_idx)
+ for layer_idx in range(config.num_hidden_layers)
+ ]
+ )
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = LlamaRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError(
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ return_legacy_cache = False
+ if (
+ use_cache and not isinstance(past_key_values, Cache) and not self.training
+ ): # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = True
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ logger.warning_once(
+ "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
+ "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
+ )
+
+ if cache_position is None:
+ past_seen_tokens = (
+ past_key_values.get_seq_length() if past_key_values is not None else 0
+ )
+ cache_position = torch.arange(
+ past_seen_tokens,
+ past_seen_tokens + inputs_embeds.shape[1],
+ device=inputs_embeds.device,
+ )
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = self._update_causal_mask(
+ attention_mask,
+ inputs_embeds,
+ cache_position,
+ past_key_values,
+ output_attentions,
+ )
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ causal_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ cache_position,
+ position_embeddings,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if return_legacy_cache:
+ next_cache = next_cache.to_legacy_cache()
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
+ if v is not None
+ )
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ def _update_causal_mask(
+ self,
+ attention_mask: torch.Tensor,
+ input_tensor: torch.Tensor,
+ cache_position: torch.Tensor,
+ past_key_values: Cache,
+ output_attentions: bool,
+ ):
+ # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+ # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
+ # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+ # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
+
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and 0.0 in attention_mask:
+ return attention_mask
+ return None
+
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+ # to infer the attention mask.
+ past_seen_tokens = (
+ past_key_values.get_seq_length() if past_key_values is not None else 0
+ )
+ using_static_cache = isinstance(past_key_values, StaticCache)
+
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+ if (
+ self.config._attn_implementation == "sdpa"
+ and not using_static_cache
+ and not output_attentions
+ ):
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
+ attention_mask,
+ inputs_embeds=input_tensor,
+ past_key_values_length=past_seen_tokens,
+ is_training=self.training,
+ ):
+ return None
+
+ dtype, device = input_tensor.dtype, input_tensor.device
+ min_dtype = torch.finfo(dtype).min
+ sequence_length = input_tensor.shape[1]
+ if using_static_cache:
+ target_length = past_key_values.get_max_length()
+ else:
+ target_length = (
+ attention_mask.shape[-1]
+ if isinstance(attention_mask, torch.Tensor)
+ else past_seen_tokens + sequence_length + 1
+ )
+
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
+ if attention_mask.max() != 0:
+ raise ValueError(
+ "Custom 4D attention mask should be passed in inverted form with max==0`"
+ )
+ causal_mask = attention_mask
+ else:
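+ # Build the dense mask from scratch: start from an all-`min_dtype`
+ # (sequence_length, target_length) matrix, keep only the strictly upper
+ # triangle blocked for multi-token inputs, let the `arange > cache_position`
+ # comparison re-open key columns at or before each query's cache position,
+ # then broadcast to (batch_size, 1, sequence_length, target_length).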
+ causal_mask = torch.full(
+ (sequence_length, target_length),
+ fill_value=min_dtype,
+ dtype=dtype,
+ device=device,
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(
+ target_length, device=device
+ ) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(
+ input_tensor.shape[0], 1, -1, -1
+ )
+ if attention_mask is not None:
+ causal_mask = (
+ causal_mask.clone()
+ ) # copy to contiguous memory for in-place edit
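+ # Fold the 2D padding mask into the causal mask: an entry sums to zero only
+ # where the causal mask allows attention (0) and `attention_mask` marks the
+ # key position as padding (0); exactly those entries get filled with `min_dtype`.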
+ mask_length = attention_mask.shape[-1]
+ padding_mask = (
+ causal_mask[:, :, :, :mask_length]
+ + attention_mask[:, None, None, :]
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[
+ :, :, :, :mask_length
+ ].masked_fill(padding_mask, min_dtype)
+ if (
+ self.config._attn_implementation == "sdpa"
+ and attention_mask is not None
+ and attention_mask.device.type == "cuda"
+ and not output_attentions
+ ):
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ causal_mask = AttentionMaskConverter._unmask_unattended(
+ causal_mask, min_dtype
+ )
+
+ return causal_mask
+
+
+class LlamaForCausalLM(LlamaPreTrainedModel):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = LlamaModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+ @replace_return_docstrings(
+ output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
+ )
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
+
+ >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+
+ hidden_states = outputs[0]
+ if self.config.pretraining_tp > 1:
+ lm_head_slices = self.lm_head.weight.split(
+ self.vocab_size // self.config.pretraining_tp, dim=0
+ )
+ logits = [
+ F.linear(hidden_states, lm_head_slices[i])
+ for i in range(self.config.pretraining_tp)
+ ]
+ logits = torch.cat(logits, dim=-1)
+ else:
+ logits = self.lm_head(hidden_states)
+ # logits = logits.float()
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ cache_position=None,
+ position_ids=None,
+ use_cache=True,
+ **kwargs,
+ ):
+ # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
+ # Exception 1: when passing input_embeds, input_ids may be missing entries
+ # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+ if past_key_values is not None:
+ if inputs_embeds is not None: # Exception 1
+ input_ids = input_ids[:, -cache_position.shape[0] :]
+ elif (
+ input_ids.shape[1] != cache_position.shape[0]
+ ): # Default case (the "else", a no op, is Exception 2)
+ input_ids = input_ids[:, cache_position]
+
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past_key_values:
+ position_ids = position_ids[:, -input_ids.shape[1] :]
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and cache_position[0] == 0:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {
+ "input_ids": input_ids.contiguous()
+ } # `contiguous()` needed for compilation use cases
+
+ model_inputs.update(
+ {
+ "position_ids": position_ids,
+ "cache_position": cache_position,
+ "past_key_values": past_key_values,
+ "use_cache": use_cache,
+ "attention_mask": attention_mask,
+ }
+ )
+ return model_inputs
+
+
+@add_start_docstrings(
+ """
+ The LLaMa Model transformer with a sequence classification head on top (linear layer).
+
+ [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+ (e.g. GPT-2) do.
+
+ Since it does classification on the last token, it needs to know the position of the last token. If a
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+ each row of the batch).
+ """,
+ LLAMA_START_DOCSTRING,
+)
+class LlamaForSequenceClassification(LlamaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = LlamaModel(config)
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ transformer_outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = transformer_outputs[0]
+ logits = self.score(hidden_states)
+
+ if input_ids is not None:
+ batch_size = input_ids.shape[0]
+ else:
+ batch_size = inputs_embeds.shape[0]
+
+ if self.config.pad_token_id is None and batch_size != 1:
+ raise ValueError(
+ "Cannot handle batch sizes > 1 if no padding token is defined."
+ )
+ if self.config.pad_token_id is None:
+ sequence_lengths = -1
+ else:
+ if input_ids is not None:
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
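+ # argmax over `eq(pad_token_id)` returns the index of the first padding token
+ # (or 0 when there is none); subtracting 1 points at the last non-padding
+ # token, and the modulo wraps the no-padding case (-1) to the final position.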
+ sequence_lengths = (
+ torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+ )
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
+ sequence_lengths = sequence_lengths.to(logits.device)
+ else:
+ sequence_lengths = -1
+
+ pooled_logits = logits[
+ torch.arange(batch_size, device=logits.device), sequence_lengths
+ ]
+
+ loss = None
+ if labels is not None:
+ labels = labels.to(logits.device)
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (
+ labels.dtype == torch.long or labels.dtype == torch.int
+ ):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(
+ pooled_logits.view(-1, self.num_labels), labels.view(-1)
+ )
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+ if not return_dict:
+ output = (pooled_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+The Llama Model transformer with a span classification head on top for extractive question-answering tasks like
+SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
+ """,
+ LLAMA_START_DOCSTRING,
+)
+class LlamaForQuestionAnswering(LlamaPreTrainedModel):
+ base_model_prefix = "transformer"
+
+ # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama
+ def __init__(self, config):
+ super().__init__(config)
+ self.transformer = LlamaModel(config)
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.transformer.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.transformer.embed_tokens = value
+
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ start_positions: Optional[torch.LongTensor] = None,
+ end_positions: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
+ r"""
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+ are not taken into account for computing the loss.
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+ are not taken into account for computing the loss.
+ """
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ outputs = self.transformer(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+ # On multi-GPU setups the position tensors may carry an extra dimension; squeeze it and move them to the logits device
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1).to(start_logits.device)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1).to(end_logits.device)
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[2:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return QuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ The Llama Model transformer with a token classification head on top (a linear layer on top of the hidden-states
+ output) e.g. for Named-Entity-Recognition (NER) tasks.
+ """,
+ LLAMA_START_DOCSTRING,
+)
+class LlamaForTokenClassification(LlamaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = LlamaModel(config)
+ if getattr(config, "classifier_dropout", None) is not None:
+ classifier_dropout = config.classifier_dropout
+ elif getattr(config, "hidden_dropout", None) is not None:
+ classifier_dropout = config.hidden_dropout
+ else:
+ classifier_dropout = 0.1
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.score = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, TokenClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`.
+ """
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = outputs[0]
+ sequence_output = self.dropout(sequence_output)
+ logits = self.score(sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
diff --git a/ktransformers/operators/RoPE.py b/ktransformers/operators/RoPE.py
index 9dc233b..dca441d 100644
--- a/ktransformers/operators/RoPE.py
+++ b/ktransformers/operators/RoPE.py
@@ -1,67 +1,128 @@
-'''
+"""
Description :
Author : Boxin Zhang
Version : 0.1.0
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
-'''
+"""
+
from torch import nn
-from ktransformers.models.modeling_deepseek import DeepseekV2YarnRotaryEmbedding, DeepseekV2RotaryEmbedding
+from transformers import ROPE_INIT_FUNCTIONS
+from ktransformers.models.modeling_llama import (
+ LlamaRotaryEmbedding,
+ LlamaLinearScalingRotaryEmbedding,
+ LlamaDynamicNTKScalingRotaryEmbedding,
+)
+from ktransformers.models.modeling_deepseek import (
+ DeepseekV2YarnRotaryEmbedding,
+ DeepseekV2RotaryEmbedding,
+)
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader
from ktransformers.util.utils import InferenceState
from transformers.configuration_utils import PretrainedConfig
+
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2Moe
class RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):
- def __init__(self,
- key: str,
- gguf_loader : GGUFLoader,
- config: PretrainedConfig,
- orig_module: nn.Module,
- # device: str = "cuda",
- generate_device: str = "cuda",
- prefill_device: str = "cuda",
- **kwargs):
- BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
- self.orig_module.__init__(orig_module.dim,
+ def __init__(
+ self,
+ key: str,
+ gguf_loader: GGUFLoader,
+ config: PretrainedConfig,
+ orig_module: nn.Module,
+ # device: str = "cuda",
+ generate_device: str = "cuda",
+ prefill_device: str = "cuda",
+ **kwargs,
+ ):
+ BaseInjectedModule.__init__(
+ self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+ )
+ self.orig_module.__init__(
+ orig_module.dim, orig_module.max_position_embeddings, orig_module.base
+ )
+ self.generate_device = generate_device
+ self.prefill_device = prefill_device
+
+ def load(self):
+ self.orig_module.__init__(
+ self.orig_module.dim,
+ self.orig_module.max_position_embeddings,
+ self.orig_module.base,
+ self.device,
+ )
+
+
+class RotaryEmbeddingV2(BaseInjectedModule, LlamaRotaryEmbedding):
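+ # Injected wrapper for LlamaRotaryEmbedding: __init__ replays the wrapped
+ # module's own hyper-parameters (scaling_factor, rope_type, config), and
+ # load() re-initializes it on the target device.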
+ def __init__(
+ self,
+ key: str,
+ gguf_loader: GGUFLoader,
+ config: PretrainedConfig,
+ orig_module: nn.Module,
+ generate_device: str = "cuda",
+ prefill_device: str = "cuda",
+ **kwargs,
+ ):
+ BaseInjectedModule.__init__(
+ self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+ )
+ self.orig_module.__init__(
+ orig_module.dim,
orig_module.max_position_embeddings,
- orig_module.base)
+ orig_module.base,
+ None,
+ orig_module.scaling_factor,
+ orig_module.rope_type,
+ orig_module.config,
+ )
self.generate_device = generate_device
self.prefill_device = prefill_device
-
+
def load(self):
- self.orig_module.__init__(self.orig_module.dim,
+ self.orig_module.__init__(
+ self.orig_module.dim,
self.orig_module.max_position_embeddings,
self.orig_module.base,
- self.device)
-
+ self.device,
+ self.orig_module.scaling_factor,
+ self.orig_module.rope_type,
+ self.orig_module.config,
+ )
+
class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding):
- def __init__(self,
- key: str,
- gguf_loader : GGUFLoader,
- config: PretrainedConfig,
- orig_module: nn.Module,
- # device: str = "cuda",
- generate_device: str = "cuda",
- prefill_device: str = "cuda",
- **kwargs):
- BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
- self.orig_module.__init__(orig_module.dim,
+ def __init__(
+ self,
+ key: str,
+ gguf_loader: GGUFLoader,
+ config: PretrainedConfig,
+ orig_module: nn.Module,
+ # device: str = "cuda",
+ generate_device: str = "cuda",
+ prefill_device: str = "cuda",
+ **kwargs,
+ ):
+ BaseInjectedModule.__init__(
+ self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+ )
+ self.orig_module.__init__(
+ orig_module.dim,
orig_module.max_position_embeddings,
orig_module.base,
- None, #device
+ None, # device
orig_module.scaling_factor,
orig_module.original_max_position_embeddings,
orig_module.beta_fast,
orig_module.beta_slow,
orig_module.mscale,
- orig_module.mscale_all_dim)
+ orig_module.mscale_all_dim,
+ )
self.generate_device = generate_device
self.prefill_device = prefill_device
-
-
+
def load(self):
- self.orig_module.__init__(self.orig_module.dim,
+ self.orig_module.__init__(
+ self.orig_module.dim,
self.orig_module.max_position_embeddings,
self.orig_module.base,
self.generate_device,
@@ -70,5 +131,42 @@ def load(self):
self.orig_module.beta_fast,
self.orig_module.beta_slow,
self.orig_module.mscale,
- self.orig_module.mscale_all_dim)
-
+ self.orig_module.mscale_all_dim,
+ )
+
+
+class DynamicNTKScalingRotaryEmbedding(
+ BaseInjectedModule, LlamaDynamicNTKScalingRotaryEmbedding
+):
+ def __init__(
+ self,
+ key: str,
+ gguf_loader: GGUFLoader,
+ config: PretrainedConfig,
+ orig_module: nn.Module,
+ device: str = "cuda",
+ **kwargs,
+ ):
+ BaseInjectedModule.__init__(
+ self, key, gguf_loader, config, orig_module, device, **kwargs
+ )
+ self.orig_module.__init__(
+ orig_module.dim,
+ orig_module.max_position_embeddings,
+ orig_module.base,
+ None, # device
+ orig_module.scaling_factor,
+ orig_module.rope_type,
+ orig_module.config,
+ )
+
+ def load(self):
+ self.orig_module.__init__(
+ self.orig_module.dim,
+ self.orig_module.max_position_embeddings,
+ self.orig_module.base,
+ self.orig_module.device,
+ self.orig_module.scaling_factor,
+ self.orig_module.rope_type,
+ self.orig_module.config,
+ )
diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py
index 3cfb9fd..33dd021 100644
--- a/ktransformers/operators/attention.py
+++ b/ktransformers/operators/attention.py
@@ -7,16 +7,22 @@
import torch
from torch import nn
import warnings
+import torch.nn.functional as F
+from ktransformers.operators.models import KLlamaModel
from ktransformers.models.configuration_deepseek import DeepseekV2Config
+from ktransformers.models.configuration_llama import LlamaConfig
+from ktransformers.models.modeling_llama import LlamaRotaryEmbedding
from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb
from typing import Optional, Tuple
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader
+import logging
from transformers.configuration_utils import PretrainedConfig
from transformers.cache_utils import Cache
-
+
+logger = logging.getLogger("attention")
+
+
class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
+ attn_mask: Optional[torch.Tensor] = None
def __init__(self,
key: str,
@@ -24,10 +30,12 @@ def __init__(self,
config: PretrainedConfig,
orig_module: nn.Module,
device: str = "cuda",
+ chunck_size: int = 1000,
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
self.orig_module.__init__(orig_module.config,
orig_module.layer_idx)
+ self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
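+ # chunck_size caps how many query tokens a single attention pass handles;
+ # prompts longer than this are processed in chunck_size slices inside
+ # forward() to keep the per-chunk attention score matrix bounded.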
def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
@@ -157,9 +165,8 @@ def forward(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
bsz, q_len, _ = hidden_states.size()
- chunck_size = 256 # TODO, generate chunck_size automatically.
- if q_len <= chunck_size:
+ if q_len <= self.chunck_size:
return self.forward_chunck(
hidden_states,
attention_mask,
@@ -176,24 +183,170 @@ def forward(
cur_idx = 0
while cur_idx < q_len:
if attention_mask is not None:
- chunk_mask = attention_mask[:, :, cur_idx:min(cur_idx + chunck_size, q_len), ...]
+ chunck_mask = attention_mask[:, :, cur_idx:min(cur_idx + self.chunck_size, q_len), ...]
else:
- chunk_mask = None
+ # generate chunk_mask automatically.
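+ # Lazily allocate a reusable (1, 1, chunck_size, max_cache_len) mask buffer:
+ # columns before cur_idx (already cached tokens) stay fully visible, the
+ # current chunk gets a strictly upper-triangular block of -1e38 so each query
+ # only attends to itself and earlier tokens, columns past the chunk are fully
+ # masked, and the rows are finally narrowed to the actual chunk length.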
+ self.attn_mask = \
+ torch.zeros(1, 1, self.chunck_size, past_key_value.max_cache_len, device=hidden_states.device) \
+ if self.attn_mask is None \
+ else self.attn_mask
+ self.attn_mask[:, :, :, cur_idx:min(cur_idx+self.chunck_size, past_key_value.max_cache_len)] = \
+ -1e+38 * torch.triu(torch.ones(self.chunck_size, self.chunck_size, device=hidden_states.device), diagonal=1)\
+ [:,:min(self.chunck_size, min(past_key_value.max_cache_len-cur_idx, self.chunck_size))]
+ self.attn_mask[:, :, :, cur_idx+self.chunck_size:] = -1e+38
+ self.attn_mask[:, :, :, :cur_idx] = 0
+ chunck_mask = torch.narrow(self.attn_mask, 2, 0, min(self.chunck_size, q_len-cur_idx))
cur_output, _, _ = self.forward_chunck(
- hidden_states[:, cur_idx:min(cur_idx + chunck_size, q_len), ...],
- chunk_mask,
- position_ids[:, cur_idx:min(cur_idx + chunck_size, q_len)],
+ hidden_states[:, cur_idx:min(cur_idx + self.chunck_size, q_len), ...],
+ chunck_mask,
+ position_ids[:, cur_idx:min(cur_idx + self.chunck_size, q_len)],
past_key_value,
output_attentions,
use_cache,
- cache_position[cur_idx:min(cur_idx + chunck_size, q_len)],
+ cache_position[cur_idx:min(cur_idx + self.chunck_size, q_len)],
**kwargs
)
- cur_idx += chunck_size
+ cur_idx += self.chunck_size
if attn_output is None:
attn_output = cur_output
else:
attn_output = torch.cat((attn_output, cur_output), dim=-2)
return attn_output, None, past_key_value
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+
+
+class KLlamaAttention(BaseInjectedModule):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self,
+ key: str,
+ gguf_loader : GGUFLoader,
+ config: PretrainedConfig,
+ orig_module: nn.Module,
+ device: str = "cuda",
+ **kwargs):
+ BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
+ self.orig_module.__init__(orig_module.config,
+ orig_module.layer_idx)
+ def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ if self.config.pretraining_tp > 1:
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
+ query_slices = self.q_proj.weight.split(
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
+ )
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
+ query_states = torch.cat(query_states, dim=-1)
+
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
+ key_states = torch.cat(key_states, dim=-1)
+
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
+ value_states = torch.cat(value_states, dim=-1)
+
+ else:
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+
+ logger.warning(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = self.apply_rotary_pos_emb(query_states, key_states, cos, sin)
+ if q_len == 1:
+ position_ids = position_ids[0][-1].unsqueeze(0).unsqueeze(0)
+ query_states = query_states[:, :, -1:]
+ key_states = key_states[:, :, -1:]
+
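+ # Hand attention off to the shared KLlamaModel.dynamic_sdpa operator with
+ # fp16 q/k/v in (bsz, seq_len, num_heads, head_dim) layout; "prefill" is used
+ # for multi-token chunks and "generate" for single-token decode steps.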
+ attn_output = KLlamaModel.dynamic_sdpa.apply(
+ self.layer_idx,
+ bsz,
+ position_ids[0][0],
+ query_states.transpose(1, 2).to(torch.float16),
+ key_states.transpose(1, 2).to(torch.float16),
+ value_states.transpose(1, 2).to(torch.float16),
+ mode="prefill" if q_len > 1 else "generate",
+ )
+
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ attn_output = attn_output.reshape(bsz, q_len, -1)
+
+ if self.config.pretraining_tp > 1:
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
+ else:
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
\ No newline at end of file
diff --git a/ktransformers/operators/cpuinfer.py b/ktransformers/operators/cpuinfer.py
index 027cc8b..74b8fa8 100644
--- a/ktransformers/operators/cpuinfer.py
+++ b/ktransformers/operators/cpuinfer.py
@@ -1,18 +1,746 @@
+#!/usr/bin/env python
+# coding=utf-8
+"""
+Description : This script defines the `CPUInferKVCache` and `CPUInfer` classes for performing inference
+ with a Key-Value Cache on the CPU. The `CPUInferKVCache` class is responsible for configuring
+ and managing key-value caches, updating and retrieving cache data, and handling attention
+ operations. It supports different cache types (e.g., Q4_0, FP16) and retrieval strategies
+ (e.g., shared, separate). The `CPUInfer` class handles task submission and synchronization
+ on the CPU, with optional CUDA stream integration for tasks involving GPU acceleration.
+ These classes facilitate efficient caching and memory management for deep learning models
+ that leverage key-value attention mechanisms, particularly on CPU-based systems.
+Author : djw
+Date : 2024-08-26 23:25:24
+Version : 1.0.0
+LastEditors : djw
+LastEditTime : 2024-08-26 23:25:24
+Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
+"""
import sys, os
from typing import Any
+import torch
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Debug"))
import cpuinfer_ext
from ktransformers.server.config.config import Config
+
+
+class CPUInferKVCache:
+ def __init__(
+ self,
+ layer_num: int = 32,
+ kv_head_num: int = 8,
+ q_head_num: int = 32,
+ head_dim: int = 128,
+ block_len: int = 256,
+ anchor_num: int = 4,
+ anchor_type: str = "FIXED",
+ kv_type: str = "Q4_0",
+ retrieval_type: str = "SHARED",
+ layer_step: int = 1,
+ token_step: int = 1,
+ layer_offset: int = 0,
+ max_thread_num: int = 32,
+ max_batch_size: int = 4,
+ max_block_num: int = 512,
+ ):
+
+ if anchor_type == "FIXED":
+ anchor_type = cpuinfer_ext.kvcache.AnchorType.FIXED
+ elif anchor_type == "QUEST":
+ anchor_type = cpuinfer_ext.kvcache.AnchorType.QUEST
+ elif anchor_type == "DYNAMIC":
+ anchor_type = cpuinfer_ext.kvcache.AnchorType.DYNAMIC
+ elif anchor_type == "BLOCK_MEAN":
+ anchor_type = cpuinfer_ext.kvcache.AnchorType.BLOCK_MEAN
+ elif anchor_type == "BLOCK_MAX":
+ anchor_type = cpuinfer_ext.kvcache.AnchorType.BLOCK_MAX
+ else:
+ raise ValueError(f"Unknown anchor type: {anchor_type}")
+
+ if kv_type == "FP16":
+ kv_type = cpuinfer_ext.kvcache.ggml_type.FP16
+ elif kv_type == "FP32":
+ assert False, "FP32 is not supported yet."
+ kv_type = cpuinfer_ext.kvcache.ggml_type.FP32
+ elif kv_type == "Q4_0":
+ kv_type = cpuinfer_ext.kvcache.ggml_type.Q4_0
+ elif kv_type == "Q8_0":
+ kv_type = cpuinfer_ext.kvcache.ggml_type.Q8_0
+ else:
+ raise ValueError(f"Unknown kv type: {kv_type}")
+
+ if retrieval_type == "SHARED":
+ retrieval_type = cpuinfer_ext.kvcache.RetrievalType.LAYER
+ elif retrieval_type == "INDIVIDUAL":
+ retrieval_type = cpuinfer_ext.kvcache.RetrievalType.QHEAD
+ elif retrieval_type == "SEPARATE":
+ retrieval_type = cpuinfer_ext.kvcache.RetrievalType.KVHEAD
+ else:
+ raise ValueError(f"Unknown retrieval type: {retrieval_type}")
+
+ self.config = cpuinfer_ext.kvcache.KVCacheConfig(
+ layer_num,
+ kv_head_num,
+ q_head_num,
+ head_dim,
+ block_len,
+ anchor_num,
+ anchor_type,
+ kv_type,
+ retrieval_type,
+ layer_step,
+ token_step,
+ layer_offset,
+ max_block_num,
+ max_batch_size,
+ max_thread_num,
+ )
+ self.kvcache = cpuinfer_ext.kvcache.KVCache(self.config)
+
+ def load_kvcache(self, tensor_file_path: str):
+ if not os.path.exists(tensor_file_path):
+ raise FileNotFoundError(f"The file {tensor_file_path} does not exist.")
+ return self.kvcache.load_kvcache(tensor_file_path,)
+
+ def dump_kvcache(
+ self, block_table: torch.Tensor, cache_total_len: int, tensor_file_path: str
+ ):
+ assert (
+ block_table.dim() == 1
+ and block_table.dtype == torch.int
+ and block_table.is_contiguous()
+ and block_table.device == torch.device("cpu")
+ ), "block_table dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
+ block_table.dim(),
+ block_table.size(),
+ block_table.dtype,
+ block_table.is_contiguous(),
+ block_table.device,
+ )
+
+ assert (
+ cache_total_len > 0
+ and cache_total_len <= self.config.block_len * block_table.size(0)
+ ), "cache_total_len: {}".format(cache_total_len)
+
+ if not os.path.exists(os.path.dirname(tensor_file_path)):
+ os.makedirs(os.path.dirname(tensor_file_path))
+
+ return self.kvcache.dump_kvcache(
+ block_table.data_ptr(),
+ cache_total_len,
+ tensor_file_path,
+ )
+
+ def update_cache_total_len(self, cache_total_len: int):
+ assert cache_total_len > 0, "cache_total_len: {}".format(cache_total_len)
+ self.kvcache.update_cache_total_len(cache_total_len)
+
+ # q_in: (bsz, q_len, q_head_num, head_dim)
+ # output: (bsz, q_len, q_head_num, head_dim)
+ # attn_lse: (bsz, q_len, q_head_num)
+ # block_table: (bsz, max_block_num)
+ def attn(
+ self,
+ q_in: torch.Tensor,
+ output: torch.Tensor,
+ attn_lse: torch.Tensor,
+ layer_idx: int,
+ generate_token_idx: int,
+ block_table: torch.Tensor | None = None,
+ cache_seqlens: torch.Tensor | None = None,
+ pick_block_num: int | None = None,
+ init_block_num: int | None = None,
+ local_block_num: int | None = None,
+ ):
+
+ assert (
+ q_in.dim() == 4
+ and q_in.size(2) == self.config.q_head_num
+ and q_in.size(3) == self.config.head_dim
+ and q_in.dtype == torch.float16
+ and q_in.is_contiguous()
+ and q_in.device == torch.device("cpu")
+ ), "q_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
+ q_in.dim(), q_in.size(), q_in.dtype, q_in.is_contiguous(), q_in.device
+ )
+
+ batch_size = q_in.size(0)
+ q_len = q_in.size(1)
+
+ assert (block_table is None) or (
+ block_table.dim() == 2
+ and block_table.size(0) == batch_size
+ and block_table.dtype == torch.int
+ and block_table.is_contiguous()
+ and block_table.device == torch.device("cpu")
+ ), "block_table dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
+ block_table.dim(),
+ block_table.size(),
+ block_table.dtype,
+ block_table.is_contiguous(),
+ block_table.device,
+ )
+
+ max_block_num = block_table.size(1) if block_table is not None else 0
+
+ assert (
+ output.dim() == 4
+ and output.size(0) == batch_size
+ and output.size(2) == self.config.q_head_num
+ and output.size(1) == q_len
+ and output.size(3) == self.config.head_dim
+ and output.dtype == torch.float16
+ and output.is_contiguous()
+ and output.device == torch.device("cpu")
+ ), "output dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
+ output.dim(),
+ output.size(),
+ output.dtype,
+ output.is_contiguous(),
+ output.device,
+ )
+
+ assert (
+ attn_lse.dim() == 3
+ and attn_lse.size(0) == batch_size
+ and attn_lse.size(1) == q_len
+ and attn_lse.size(2) == self.config.q_head_num
+ and attn_lse.dtype == torch.float32
+ and attn_lse.is_contiguous()
+ and attn_lse.device == torch.device("cpu")
+ ), "attn_lse dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
+ attn_lse.dim(),
+ attn_lse.size(),
+ attn_lse.dtype,
+ attn_lse.is_contiguous(),
+ attn_lse.device,
+ )
+
+ assert (
+ layer_idx >= 0 and layer_idx < self.config.layer_num
+ ), "layer_idx: {}".format(layer_idx)
+
+ assert (cache_seqlens is None) or (
+ cache_seqlens.dim() == 1
+ and cache_seqlens.size(0) == batch_size
+ and cache_seqlens.dtype == torch.int
+ and cache_seqlens.is_contiguous()
+ and cache_seqlens.device == torch.device("cpu")
+ ), "cache_seqlens dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
+ cache_seqlens.dim(),
+ cache_seqlens.size(),
+ cache_seqlens.dtype,
+ cache_seqlens.is_contiguous(),
+ cache_seqlens.device,
+ )
+
+ return self.kvcache.attn(
+ q_in.data_ptr(),
+ output.data_ptr(),
+ attn_lse.data_ptr(),
+ layer_idx,
+ generate_token_idx,
+ q_len,
+ batch_size,
+ max_block_num,
+ block_table.data_ptr() if block_table is not None else 0,
+ cache_seqlens.data_ptr() if cache_seqlens is not None else 0,
+ pick_block_num,
+ init_block_num,
+ local_block_num,
+ )
+
+ # k_in: (block_len, kv_head_num, head_dim)
+ # v_in: (block_len, kv_head_num, head_dim)
+ def update_kvcache_one_block_fp16(
+ self, k_in: torch.Tensor, v_in: torch.Tensor, layer_id: int, block_idx: int
+ ):
+ assert (
+ k_in.dim() == 3
+ and k_in.size(1) == self.config.block_len
+ and k_in.size(0) == self.config.kv_head_num
+ and k_in.size(2) == self.config.head_dim
+ and k_in.dtype == torch.float16
+ and k_in.is_contiguous()
+ and k_in.device == torch.device("cpu")
+ ), "k_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
+ k_in.dim(), k_in.size(), k_in.dtype, k_in.is_contiguous(), k_in.device
+ )
+ assert (
+ v_in.dim() == 3
+ and v_in.size(1) == self.config.block_len
+ and v_in.size(0) == self.config.kv_head_num
+ and v_in.size(2) == self.config.head_dim
+ and v_in.dtype == torch.float16
+ and v_in.is_contiguous()
+ and v_in.device == torch.device("cpu")
+ ), "v_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
+ v_in.dim(), v_in.size(), v_in.dtype, v_in.is_contiguous(), v_in.device
+ )
+ assert (
+ layer_id >= 0 and layer_id < self.config.layer_num
+ ), "layer_id: {}".format(layer_id)
+ assert block_idx >= 0, "block_idx: {}".format(block_idx)
+ return self.kvcache.update_one_block_fp16(
+ k_in.data_ptr(),
+ v_in.data_ptr(),
+ layer_id,
+ block_idx,
+ )
+
+ def get_kvcache_one_block_fp16(
+ self, k_in: torch.Tensor, v_in: torch.Tensor, layer_id: int, block_idx: int
+ ):
+ assert (
+ k_in.dim() == 3
+ and k_in.size(1) == self.config.block_len
+ and k_in.size(0) == self.config.kv_head_num
+ and k_in.size(2) == self.config.head_dim
+ and k_in.dtype == torch.float16
+ and k_in.is_contiguous()
+ and k_in.device == torch.device("cpu")
+ ), "k_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
+ k_in.dim(), k_in.size(), k_in.dtype, k_in.is_contiguous(), k_in.device
+ )
+ assert (
+ v_in.dim() == 3
+ and v_in.size(1) == self.config.block_len
+ and v_in.size(0) == self.config.kv_head_num
+ and v_in.size(2) == self.config.head_dim
+ and v_in.dtype == torch.float16
+ and v_in.is_contiguous()
+ and v_in.device == torch.device("cpu")
+ ), "v_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
+ v_in.dim(), v_in.size(), v_in.dtype, v_in.is_contiguous(), v_in.device
+ )
+ assert (
+ layer_id >= 0 and layer_id < self.config.layer_num
+ ), "layer_id: {}".format(layer_id)
+ assert block_idx >= 0, "block_idx: {}".format(block_idx)
+ return self.kvcache.get_one_block_fp16(
+ k_in.data_ptr(),
+ v_in.data_ptr(),
+ layer_id,
+ block_idx,
+ )
+
+ def update_importance_one_block(
+ self, importance: torch.Tensor, layer_id: int, block_idx: int
+ ):
+ assert (
+ importance.dim() == 1
+ and importance.size(0) == self.config.block_len
+ and importance.dtype == torch.float16
+ and importance.is_contiguous()
+ and importance.device == torch.device("cpu")
+ ), "importance dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
+ importance.dim(),
+ importance.size(),
+ importance.dtype,
+ importance.is_contiguous(),
+ importance.device,
+ )
+ assert (
+ layer_id >= 0 and layer_id < self.config.layer_num
+ ), "layer_id: {}".format(layer_id)
+ assert block_idx >= 0, "block_idx: {}".format(block_idx)
+ return self.kvcache.update_importance_one_block(
+ importance.data_ptr(),
+ layer_id,
+ block_idx,
+ )
+
+ def get_importance_one_block(
+ self, importance: torch.Tensor, layer_id: int, block_idx: int
+ ):
+ assert (
+ importance.dim() == 1
+ and importance.size(0) == self.config.block_len
+ and importance.dtype == torch.float16
+ and importance.is_contiguous()
+ and importance.device == torch.device("cpu")
+ ), "importance dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
+ importance.dim(),
+ importance.size(),
+ importance.dtype,
+ importance.is_contiguous(),
+ importance.device,
+ )
+ assert (
+ layer_id >= 0 and layer_id < self.config.layer_num
+ ), "layer_id: {}".format(layer_id)
+ assert block_idx >= 0, "block_idx: {}".format(block_idx)
+ return self.kvcache.get_importance_one_block(
+ importance.data_ptr(),
+ layer_id,
+ block_idx,
+ )
+
+ def get_anchor_one_block(self, anchor: torch.Tensor, layer_id: int, block_idx: int):
+ assert (
+ anchor.dim() == 3
+ and anchor.size(0) == self.config.kv_head_num
+ and anchor.size(1) == self.config.anchor_num
+ and anchor.size(2) == self.config.head_dim
+ and anchor.dtype == torch.float16
+ and anchor.is_contiguous()
+ and anchor.device == torch.device("cpu")
+ ), "anchor dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
+ anchor.dim(),
+ anchor.size(),
+ anchor.dtype,
+ anchor.is_contiguous(),
+ anchor.device,
+ )
+ assert (
+ layer_id >= 0 and layer_id < self.config.layer_num
+ ), "layer_id: {}".format(layer_id)
+ assert block_idx >= 0, "block_idx: {}".format(block_idx)
+ return self.kvcache.get_anchor_one_block(
+ anchor.data_ptr(),
+ layer_id,
+ block_idx,
+ )
+
+ def update_anchor_one_block(
+ self, anchor: torch.Tensor, layer_id: int, block_idx: int
+ ):
+ assert (
+ anchor.dim() == 3
+ and anchor.size(0) == self.config.kv_head_num
+ and anchor.size(1) == self.config.anchor_num
+ and anchor.size(2) == self.config.head_dim
+ and anchor.dtype == torch.float16
+ and anchor.is_contiguous()
+ and anchor.device == torch.device("cpu")
+ ), "anchor dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
+ anchor.dim(),
+ anchor.size(),
+ anchor.dtype,
+ anchor.is_contiguous(),
+ anchor.device,
+ )
+ assert (
+ layer_id >= 0 and layer_id < self.config.layer_num
+ ), "layer_id: {}".format(layer_id)
+ assert block_idx >= 0, "block_idx: {}".format(block_idx)
+ return self.kvcache.update_anchor_one_block(
+ anchor.data_ptr(),
+ layer_id,
+ block_idx,
+ )
+
+ def calc_anchor_all_layers(
+ self,
+ block_table: torch.Tensor,
+ cache_seqlens: torch.Tensor,
+ ):
+ assert (
+ block_table.dim() == 2
+ and block_table.size(0) == cache_seqlens.size(0)
+ and block_table.dtype == torch.int
+ and block_table.is_contiguous()
+ and block_table.device == torch.device("cpu")
+ ), "block_table dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
+ block_table.dim(),
+ block_table.size(),
+ block_table.dtype,
+ block_table.is_contiguous(),
+ block_table.device,
+ )
+ assert (
+ cache_seqlens.dim() == 1
+ and cache_seqlens.dtype == torch.int
+ and cache_seqlens.is_contiguous()
+ and cache_seqlens.device == torch.device("cpu")
+ ), "cache_seqlens dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
+ cache_seqlens.dim(),
+ cache_seqlens.size(),
+ cache_seqlens.dtype,
+ cache_seqlens.is_contiguous(),
+ cache_seqlens.device,
+ )
+ batch_size = block_table.size(0)
+ max_block_num = block_table.size(1)
+ return self.kvcache.calc_anchor_all_layers(
+ block_table.data_ptr(),
+ cache_seqlens.data_ptr(),
+ batch_size,
+ max_block_num,
+ )
+
+ def clear_importance_all_layers(
+ self,
+ block_table: torch.Tensor,
+ cache_seqlens: torch.Tensor,
+ ):
+ assert (
+ block_table.dim() == 2
+ and block_table.size(0) == cache_seqlens.size(0)
+ and block_table.dtype == torch.int
+ and block_table.is_contiguous()
+ and block_table.device == torch.device("cpu")
+ ), "block_table dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
+ block_table.dim(),
+ block_table.size(),
+ block_table.dtype,
+ block_table.is_contiguous(),
+ block_table.device,
+ )
+ assert (
+ cache_seqlens.dim() == 1
+ and cache_seqlens.dtype == torch.int
+ and cache_seqlens.is_contiguous()
+ and cache_seqlens.device == torch.device("cpu")
+ ), "cache_seqlens dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
+ cache_seqlens.dim(),
+ cache_seqlens.size(),
+ cache_seqlens.dtype,
+ cache_seqlens.is_contiguous(),
+ cache_seqlens.device,
+ )
+ batch_size = block_table.size(0)
+ max_block_num = block_table.size(1)
+ return self.kvcache.clear_importance_all_layers(
+ block_table.data_ptr(),
+ cache_seqlens.data_ptr(),
+ batch_size,
+ max_block_num,
+ )
+
+ def get_cache_total_len(self):
+ return self.kvcache.get_cache_total_len()
+
+ def update_kvcache_q4(
+ self,
+ k_in: torch.Tensor,
+ k_scales: torch.Tensor,
+ v_in: torch.Tensor,
+ v_scales: torch.Tensor,
+ layer_id: int,
+ seq_offset: int | None = None,
+ seq_len: int | None = None,
+ block_table: torch.Tensor | None = None,
+ ):
+ raise NotImplementedError
+
+ def update_kvcache_fp16(
+ self,
+ k_in: torch.Tensor,
+ v_in: torch.Tensor,
+ layer_idx,
+ block_table: torch.Tensor,
+ max_block_num,
+ past_len: torch.Tensor,
+ q_len,
+ ):
+ batch_size = block_table.size(0)
+ return self.kvcache.get_kvcache_fp16(
+ k_in.data_ptr(),
+ v_in.data_ptr(),
+ layer_idx,
+ block_table.data_ptr(),
+ batch_size,
+ max_block_num,
+ past_len.data_ptr(),
+ q_len
+ )
+
+ def get_kvcache_q4(
+ self,
+ k_in: torch.Tensor,
+ k_scales: torch.Tensor,
+ v_in: torch.Tensor,
+ v_scales: torch.Tensor,
+ layer_id: int,
+ seq_offset: int | None = None,
+ seq_len: int | None = None,
+ block_table: torch.Tensor | None = None,
+ ):
+ raise NotImplementedError
+
+ def get_kvcache_fp16(
+ self,
+ k_in: torch.Tensor,
+ v_in: torch.Tensor,
+ layer_id: int,
+ layer_idx,
+ block_table: torch.Tensor,
+ max_block_num,
+ past_len: torch.Tensor,
+ ):
+ batch_size = block_table.size(0)
+ return self.kvcache.get_kvcache_fp16(
+ k_in.data_ptr(),
+ v_in.data_ptr(),
+ layer_idx,
+ block_table.data_ptr(),
+ batch_size,
+ max_block_num,
+ past_len.data_ptr(),
+ )
+
+ def get_and_update_kvcache_fp16(
+ self,
+ k_cache_cpu: torch.Tensor,
+ v_cache_cpu: torch.Tensor,
+ layer_idx,
+ block_table: torch.Tensor,
+ max_block_num,
+ past_len: torch.Tensor,
+ q_len,
+ ):
+ batch_size = block_table.size(0)
+ return self.kvcache.get_and_update_kvcache_fp16(
+ k_cache_cpu.data_ptr(),
+ v_cache_cpu.data_ptr(),
+ layer_idx,
+ block_table.data_ptr(),
+ batch_size,
+ max_block_num,
+ past_len.data_ptr(),
+ q_len,
+ )
+
+ def update_importance(
+ self,
+ importance_cache: torch.Tensor,
+ layer_idx,
+ block_table: torch.Tensor,
+ max_block_num,
+ offset: torch.Tensor,
+ width,
+ ):
+ batch_size = block_table.size(0)
+ return self.kvcache.update_importance(
+ importance_cache.data_ptr(),
+ layer_idx,
+ block_table.data_ptr(),
+ batch_size,
+ max_block_num,
+ offset.data_ptr(),
+ width,
+ )
+
+ # attn_sparsity: ((bsz, q_len, q_head_num), dtype = torch.float32)
+ def get_attn_sparsity(
+ self,
+ q_in: torch.Tensor,
+ attn_sparsity: torch.Tensor,
+ layer_idx: int,
+ block_table: torch.Tensor,
+ cache_seqlens: torch.Tensor,
+ block_table_origin: torch.Tensor,
+ cache_seqlens_origin: torch.Tensor,
+ generate_token_idx: int = 0,
+ topk: int | None = None,
+ local: int | None = None,
+ ):
+ batch_size = block_table.size(0)
+ max_block_num = block_table.size(1)
+ max_block_num_origin = block_table_origin.size(1)
+ q_len = q_in.size(1)
+
+ if topk is None or local is None or topk + local >= max_block_num:
+ topk = -1
+ local = -1
+ return self.kvcache.get_attn_sparsity(
+ q_in.data_ptr(),
+ attn_sparsity.data_ptr(),
+ layer_idx,
+ generate_token_idx,
+ q_len,
+ batch_size,
+ max_block_num,
+ block_table.data_ptr(),
+ cache_seqlens.data_ptr(),
+ block_table_origin.data_ptr(),
+ cache_seqlens_origin.data_ptr(),
+ max_block_num_origin,
+ topk,
+ local,
+ )
+
+ def attn_with_kvcache(
+ self,
+ q_in: torch.Tensor,
+ k_in: torch.Tensor,
+ v_in: torch.Tensor,
+ output: torch.Tensor,
+ attn_lse: torch.Tensor,
+ layer_idx: int,
+ block_table: torch.Tensor,
+ cache_seqlens: torch.Tensor,
+ generate_token_idx: int = 0,
+ topk: int | None = None,
+ local: int | None = None,
+ ):
+
+ batch_size = block_table.size(0)
+ max_block_num = block_table.size(1)
+ q_len = q_in.size(1)
+
+ if topk is None or local is None or topk + local >= max_block_num:
+ topk = -1
+ local = -1
+ return self.kvcache.attn_with_kvcache(
+ q_in.data_ptr(),
+ k_in.data_ptr(),
+ v_in.data_ptr(),
+ output.data_ptr(),
+ attn_lse.data_ptr(),
+ layer_idx,
+ generate_token_idx,
+ q_len,
+ batch_size,
+ max_block_num,
+ block_table.data_ptr(),
+ cache_seqlens.data_ptr(),
+ topk,
+ local,
+ )
+
+ def get_all_kvcache_one_layer(
+ self, k_in: torch.Tensor, v_in: torch.Tensor, layer_id: int
+ ):
+ return self.kvcache.get_all_kvcache_one_layer(
+ k_in.data_ptr(),
+ v_in.data_ptr(),
+ layer_id,
+ )
+
+ def get_importance(
+ self,
+ importance: torch.Tensor,
+ block_table: torch.Tensor,
+ ):
+ raise NotImplementedError
+
+ def get_anchor(
+ self,
+ anchor: torch.Tensor,
+ block_table: torch.Tensor,
+ ):
+ raise NotImplementedError
+
+
class CPUInfer:
- cpu_infer = None
- def __init__(self, cpu_infer:int = Config().cpu_infer):
- if CPUInfer.cpu_infer is None:
- CPUInfer.cpu_infer = cpuinfer_ext.CPUInfer(cpu_infer)
+ cpuinfer = None
+ def __init__(self, thread_num):
+ CPUInfer.cpuinfer = cpuinfer_ext.CPUInfer(thread_num)
+
+ def submit(self, task):
+ CPUInfer.cpuinfer.submit(task)
+
+ def submit_with_cuda_stream(self, current_cuda_stream, task):
+ CPUInfer.cpuinfer.submit_with_cuda_stream(current_cuda_stream, task)
+
+ def sync(self):
+ CPUInfer.cpuinfer.sync()
+
+ def sync_with_cuda_stream(self, current_cuda_stream):
+ CPUInfer.cpuinfer.sync_with_cuda_stream(current_cuda_stream)
+
+
- def __getattribute__(self, __name: str) -> Any:
- return CPUInfer.cpu_infer.__getattribute__(__name)
-
- def __setattr__(self, __name: str, __value: Any) -> None:
- return CPUInfer.cpu_infer.__setattr__(__name, __value)
\ No newline at end of file
diff --git a/ktransformers/operators/dynamic_attention.py b/ktransformers/operators/dynamic_attention.py
new file mode 100644
index 0000000..13a74b4
--- /dev/null
+++ b/ktransformers/operators/dynamic_attention.py
@@ -0,0 +1,775 @@
+#!/usr/bin/env python
+# coding=utf-8
+"""
+Description : Dynamic block-sparse attention that offloads KV-cache storage and selected-block attention to the CPU backend.
+Author : Jianwei Dong
+Date : 2024-08-26 23:25:24
+Version : 1.0.0
+LastEditors : Jianwei Dong
+LastEditTime : 2024-08-26 23:25:24
+Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
+"""
+
+import torch
+from transformers import AutoConfig
+import sys, os
+import logging
+logger = logging.getLogger("dynamic_attention")
+sys.path.append(os.path.dirname(__file__) + "/../ktransformers_ext/cpu_backend")
+from ktransformers.operators.cpuinfer import CPUInfer, CPUInferKVCache
+from flash_attn import flash_attn_func, flash_attn_with_kvcache
+
+
+import math
+import json
+
+
+class DynamicScaledDotProductAttention:
+ remaining_length: int
+
+ def __init__(
+ self,
+ max_seq_len: int,
+ block_size: int,
+ config: AutoConfig,
+ device: torch.device,
+ local_windows_len: int,
+ topk: int,
+ threads_num: int,
+ anchor_type: str = "DYNAMIC",
+ kv_type: str = "FP16",
+ dense_layer_num: int = 0,
+ anchor_num: int = 1,
+ block_selection_mode: str = "SHARED",
+ layer_step: int = 1,
+ token_step: int = 1,
+ preselect_block: bool = False,
+ preselect_block_count: int = 96,
+ prefill_chunk_size: int = 20480,
+ use_attn_sparsity: bool = False,
+ ):
+ # assert anchor_num == 1
+ # assert anchor_type == "DYNAMIC"
+ self.remaining_length = 0
+ valid_anchor_types = ["DYNAMIC", "FIXED", "BLOCK_MEAN", "BLOCK_MAX", "QUEST"]
+ assert anchor_type in valid_anchor_types
+ if anchor_type == "QUEST":
+ assert anchor_num == 2
+ elif anchor_type != "FIXED" and anchor_type != "DYNAMIC":
+ assert anchor_num == 1
+
+ valid_kv_types = ["FP16", "FP32", "Q4_0", "Q8_0"]
+ assert kv_type in valid_kv_types
+ if kv_type != "FP16" and kv_type != "FP32":
+ assert block_size % 32 == 0
+
+        valid_block_selection_modes = ["SHARED", "SEPARATE"]  # SEPARATE = individual (per-head) selection
+ assert block_selection_mode in valid_block_selection_modes
+
+ self.max_seq_len = max_seq_len
+ self.block_num = max_seq_len // block_size
+ self.block_size = block_size
+ self.anchor_type = anchor_type
+ self.kv_type = kv_type
+ self.anchor_num = anchor_num
+ self.threads_num = threads_num
+ self.layer_step = layer_step
+ self.token_step = token_step
+ self.preselect_block = preselect_block
+ self.preselect_block_count = preselect_block_count
+ self.block_selection_mode = block_selection_mode
+ self.use_attn_sparsity = use_attn_sparsity
+
+ # model config
+ self.kv_head_num = config.num_key_value_heads
+ self.q_head_num = config.num_attention_heads
+ self.head_dim = config.hidden_size // config.num_attention_heads
+ self.layer_num = config.num_hidden_layers
+
+ self.device = device
+ self.local_windows_len = local_windows_len
+ self.local_block_num = self.local_windows_len // self.block_size + 1
+ self.prefill_chunk_size = prefill_chunk_size
+
+ self.topk = topk
+ self.dense_layer_num = dense_layer_num
+ # self.dense_layer_num = 32
+ self.cache_key_states = torch.zeros(
+ (self.block_num, block_size, self.kv_head_num, self.head_dim),
+ device=device,
+ dtype=torch.float16,
+ )
+ self.cache_value_states = torch.zeros(
+ (self.block_num, block_size, self.kv_head_num, self.head_dim),
+ device=device,
+ dtype=torch.float16,
+ )
+ # [max_num_block, block_size, head_num]
+ self.cache_importance = torch.zeros(
+ (self.block_num, block_size, self.q_head_num),
+ device=device,
+ dtype=torch.float16,
+ )
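+        # GPU staging buffers: prefill chunks are written here first, mirrored to
+        # pinned CPU memory, and then handed to the CPU KV-cache backend.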
+
+ # key_states: [bsz, q_len, kv_head_num, head_dim]
+ # value_states: [bsz, q_len, kv_head_num, head_dim]
+ # query_states: [bsz, q_len, q_head_num, head_dim]
+ self.q_in_cpu = torch.zeros(
+ (1, 1, self.q_head_num, self.head_dim),
+ device="cpu",
+ dtype=torch.float16,
+ pin_memory=True,
+ )
+ self.k_in_cpu = torch.zeros(
+ (1, 1, self.kv_head_num, self.head_dim),
+ device="cpu",
+ dtype=torch.float16,
+ pin_memory=True,
+ )
+ self.v_in_cpu = torch.zeros(
+ (1, 1, self.kv_head_num, self.head_dim),
+ device="cpu",
+ dtype=torch.float16,
+ pin_memory=True,
+ )
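+        # Single-token decode buffers are pinned so the host/device copies can be
+        # issued with non_blocking=True on the current CUDA stream.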
+
+ self.cache_seqlens_cpu = torch.empty(
+ (1,), device="cpu", dtype=torch.int32, pin_memory=True
+ )
+
+ self.cache_seqlens_cuda = torch.empty((1,), device=device, dtype=torch.int32)
+
+ self.prefix_block_table = torch.arange(
+ self.block_num, device="cpu", dtype=torch.int32, pin_memory=True
+ ).view(1, -1)
+
+ self.block_table_cpu = torch.arange(
+ self.block_num, device="cpu", dtype=torch.int32, pin_memory=True
+ ).view(1, -1)
+
+ # assert (
+ # self.local_windows_len // self.block_size + 1 + self.preselect_block_count
+ # <= self.block_num
+ # )
+
+ self.output_cpu = torch.empty(
+ (1, 1, self.q_head_num, self.head_dim),
+ device="cpu",
+ dtype=torch.float16,
+ pin_memory=True,
+ )
+ self.lse_cpu = torch.empty(
+ (1, 1, self.q_head_num), device="cpu", dtype=torch.float32, pin_memory=True
+ )
+
+ self.output_cuda = torch.empty(
+ (1, 1, self.q_head_num, self.head_dim), device=device, dtype=torch.float16
+ )
+
+ self.attn_sparsity = torch.zeros(
+ (1, 1, self.q_head_num), device="cpu", dtype=torch.float32, pin_memory=True
+ )
+
+        if preselect_block:
+ self.preselect_block_table = torch.zeros(
+ self.layer_num,
+ self.preselect_block_count,
+ device=device,
+ dtype=torch.int32,
+ )
+            self.prefill_block_num = 0  # number of prefix blocks before preselection
+ self.evict_tokens = 0
+
+ self.cpu_infer = CPUInfer(threads_num)
+ self.local_thread = CPUInferKVCache(
+ self.layer_num,
+ self.kv_head_num,
+ self.q_head_num,
+ self.head_dim,
+ self.block_size,
+ anchor_num=self.anchor_num,
+ anchor_type=anchor_type,
+ kv_type=self.kv_type,
+ retrieval_type=self.block_selection_mode,
+ layer_step=self.layer_step,
+ token_step=self.token_step,
+ layer_offset=self.dense_layer_num % self.layer_step,
+ max_batch_size=1,
+ max_block_num=self.block_num,
+ max_thread_num=self.threads_num,
+ )
+
+ print(
+ f"local_windows_len: {local_windows_len}, topk: {topk}, dense_layer_num: {dense_layer_num}, kv_type: {self.kv_type}, anchor_type: {self.anchor_type}, preselect_block: {self.preselect_block}, preselect_block_count: {self.preselect_block_count}, token_step: {self.token_step}, layer_step: {self.layer_step}"
+ )
+
+ self.shape_mask = (
+ self.q_head_num,
+ self.block_size,
+ self.block_size,
+ )
+
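+        # Build a per-head causal mask for one block: tril_mask keeps the lower
+        # triangle (keys at or before each query position inside the block);
+        # triu_mask is its complement, used for the block that only partially
+        # overlaps the local window.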
+ mask = torch.zeros(
+ self.shape_mask, dtype=torch.uint8, device=device
+ ).contiguous()
+ elm_idx = torch.arange(self.block_size, device=device)
+
+ for i in range(mask.size(-2)):
+ idx = i + mask.size(-1) - mask.size(-2) - elm_idx
+ idx = idx[idx >= 0]
+ mask[..., i, idx] = 1
+
+ self.tril_mask = mask
+ self.triu_mask = mask ^ 1
+
+ self.generate_token_idx = 0
+
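+    # Typical per-layer usage (a sketch; `dyn_sdpa` is a hypothetical instance,
+    # tensors are fp16 with shape [bsz, q_len, num_heads, head_dim]):
+    #   out = dyn_sdpa.apply(layer_idx, bsz, past_len, q, k, v, mode="prefill")   # chunked prefill
+    #   out = dyn_sdpa.apply(layer_idx, bsz, past_len, q, k, v, mode="generate")  # one-token decode
+    # apply() returns the attention output transposed to [bsz, num_heads, q_len, head_dim].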
+ def get_attn_score_one_block(
+ self,
+ batch_idx: int,
+ max_block_num: int,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ offset: int,
+ width: int,
+ mask_mode: str | None = None,
+ use_softmax: bool = True,
+ ):
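+        # Accumulate per-head QK scores of `query` against a contiguous key span
+        # into self.cache_importance[offset : offset + width] for this batch slot.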
+        n_gqa = self.q_head_num // self.kv_head_num
+        importance = self.cache_importance.view(-1, self.q_head_num)
+        importance = importance.narrow(0, batch_idx * max_block_num + offset, width)
+        for head_idx in range(self.q_head_num):
+            key_item = key[..., head_idx // n_gqa, :].view(key.size(0), -1)
+            qk = torch.einsum(
+                "qd,kd->qk", query[:, head_idx, :], key_item
+            )  # (len_q, len_k) for this head
+
+ if mask_mode == "tril":
+ mask = self.tril_mask
+ mask = mask[0, -qk.size(-2) :, -qk.size(-1) :]
+ qk = qk * mask
+ elif mask_mode == "triu":
+ mask = self.triu_mask
+ mask = mask[0, -qk.size(-2) :, -qk.size(-1) :]
+ qk = qk * mask
+
+ if use_softmax:
+ qk = torch.nn.functional.softmax(
+ qk / math.sqrt(self.head_dim), dim=-1, dtype=torch.float32
+ ).to(torch.float16)
+
+ qk = torch.sum(qk, dim=-2)
+            importance[..., head_idx] += qk
+
+ def get_preselect_block_table_and_attn_score(
+ self,
+ layer_idx: int,
+ batch_size: int,
+ offset: torch.Tensor,
+ width: int,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ union_with_last_layer: bool = True,
+ ):
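+        # Score every prefix block with the tail of the query chunk; optionally keep
+        # the top `preselect_block_count` blocks per layer and, for DYNAMIC anchors,
+        # push the accumulated scores to the CPU importance cache.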
+ max_seqs_len = offset.max().item() + width
+ max_block_num = (max_seqs_len + self.block_size - 1) // self.block_size
+
+ for batch_idx in range(batch_size):
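+            # Only the last 128 query tokens of the chunk are used for block scoring.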
+ query_cur = query[batch_idx][-128:]
+ self.get_attn_score_one_block(
+ batch_idx,
+ max_block_num,
+ query_cur,
+ key[batch_idx][: offset[batch_idx].item() + width],
+ 0,
+ offset[batch_idx].item() + width,
+ mask_mode=None,
+ )
+
+ if self.preselect_block:
+ self.prefill_block_num = max(
+ 0, max_block_num - self.local_windows_len // self.block_size
+ )
+ self.evict_tokens = (
+ max(self.prefill_block_num - self.preselect_block_count, 0)
+ * self.block_size
+ )
+
+ if self.prefill_block_num != 0:
+ importance_cache = self.cache_importance.narrow(
+ 0, 0, self.prefill_block_num * batch_size
+ ).view(
+ batch_size, self.prefill_block_num, self.block_size, self.q_head_num
+ )
+
+ importance_r = importance_cache[:, 1:, : self.block_size // 4]
+ pad_r = torch.zeros_like(importance_r[:, :1])
+ importance_r = torch.cat((importance_r, pad_r), dim=1)
+ importance_l = importance_cache[:, :-1, -self.block_size // 4 :]
+ pad_l = torch.zeros_like(importance_l[:, :1])
+ importance_l = torch.cat((pad_l, importance_l), dim=1)
+ importance = torch.cat(
+ (importance_l, importance_cache, importance_r), dim=2
+ )
+ importance = importance.mean(dim=-1)
+ importance = importance.mean(dim=-1)
+ # importance: (batch_size, max_block_num)
+ topk = min(self.preselect_block_count, self.prefill_block_num)
+ values, indices = torch.topk(
+ importance,
+ k=topk,
+ dim=1,
+ )
+
+ self.preselect_block_table[
+ layer_idx : layer_idx + 1,
+ :topk,
+ ].copy_(indices)
+
+                if union_with_last_layer and layer_idx == self.layer_num - 1:  # last layer
+ for tmp_layer_idx in range(self.layer_num - 1):
+ for i in range(1, min(topk, 6)):
+ x = self.preselect_block_table[-1, i]
+ if x not in self.preselect_block_table[tmp_layer_idx]:
+ self.preselect_block_table[tmp_layer_idx, topk - i] = x
+ if self.anchor_type == "DYNAMIC":
+ importance_cache = self.cache_importance.narrow(
+ 0, 0, max_block_num * batch_size
+ ).view(batch_size, max_block_num * self.block_size, self.q_head_num)
+ importance_cache_cpu = torch.empty_like(
+ importance_cache, device="cpu", pin_memory=True
+ )
+
+ importance_cache_cpu.copy_(importance_cache)
+
+ block_table_cpu = self.prefix_block_table[:, :max_block_num].to("cpu")
+ offset_cpu = offset.contiguous().to("cpu")
+
+ self.cpu_infer.submit(
+ self.local_thread.update_importance(
+ importance_cache_cpu,
+ layer_idx,
+ block_table_cpu,
+ max_block_num,
+ offset_cpu,
+ width,
+ )
+ )
+ self.cpu_infer.sync()
+
+ importance_cache = self.cache_importance.narrow(
+ 0, 0, max_block_num * batch_size
+ ).view(batch_size, max_block_num * self.block_size, self.q_head_num)
+ importance_cache.zero_()
+
+ # key: [bsz, past_len, head_num, head_dim] float16
+ # query: [bsz, q_len, q_head_num, head_dim] float16
+ def get_attn_score(
+ self,
+ layer_idx: int,
+ batch_size: int,
+ offset: torch.Tensor,
+ width: int,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ ):
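+        # Accumulate raw (pre-softmax) QK scores for each block of the new chunk,
+        # then flush them to the CPU importance cache via update_importance.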
+ max_seqs_len = offset.max().item() + width
+ max_block_num = (max_seqs_len + self.block_size - 1) // self.block_size
+
+ for batch_idx in range(batch_size):
+ for idx in range(width // self.block_size):
+ offset_cur = idx * self.block_size
+ query_cur = query[batch_idx, offset_cur : offset_cur + self.block_size]
+ self.get_attn_score_one_block(
+ batch_idx,
+ max_block_num,
+ query_cur,
+ key[
+ batch_idx,
+ offset[batch_idx]
+ + offset_cur : offset[batch_idx]
+ + offset_cur
+ + self.block_size,
+ ],
+ offset[batch_idx].item() + offset_cur,
+ self.block_size,
+ mask_mode="tril",
+ use_softmax=False,
+ )
+
+ offset_key = (
+ offset[batch_idx].item()
+ + idx * self.block_size
+ - self.local_windows_len
+ )
+ if offset_key >= 0:
+ self.get_attn_score_one_block(
+ batch_idx,
+ max_block_num,
+ query_cur,
+ key[batch_idx, offset_key : offset_key + self.block_size],
+ offset_key,
+ self.block_size,
+ mask_mode="triu",
+ use_softmax=False,
+ )
+
+ offset_key = max(0, offset_key + self.block_size)
+ width_key = (
+ offset[batch_idx].item() + idx * self.block_size - offset_key
+ )
+ if width_key > 0:
+ self.get_attn_score_one_block(
+ batch_idx,
+ max_block_num,
+ query_cur,
+ key[batch_idx, offset_key : offset_key + width_key],
+ offset_key,
+ width_key,
+ mask_mode=None,
+ use_softmax=False,
+ )
+
+ importance_cache = self.cache_importance.narrow(
+ 0, 0, max_block_num * batch_size
+ ).view(batch_size, max_block_num * self.block_size, self.q_head_num)
+ importance_cache_cpu = torch.empty_like(
+ importance_cache, device="cpu", pin_memory=True
+ )
+
+ importance_cache_cpu.copy_(importance_cache)
+
+ block_table_cpu = self.prefix_block_table[:, :max_block_num].to("cpu")
+ offset_cpu = offset.contiguous().to("cpu")
+
+ self.cpu_infer.submit(
+ self.local_thread.update_importance(
+ importance_cache_cpu,
+ layer_idx,
+ block_table_cpu,
+ max_block_num,
+ offset_cpu,
+ width,
+ )
+ )
+ self.cpu_infer.sync()
+ importance_cache.zero_()
+
+ # key: [bsz, q_len, head_num, head_dim] float16
+ # value: [bsz, q_len, head_num, head_dim] float16
+ def swap_in_and_swap_out(self, layer_idx, past_len, q_len, key, value):
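+        # Write the chunk's K/V into the GPU staging cache, mirror the prefix to
+        # pinned CPU buffers, hand it to the CPU KV-cache backend
+        # (get_and_update_kvcache_fp16), then copy the result back so GPU flash
+        # attention sees the full fp16 prefix.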
+ batch_size = 1
+ max_seqs_len = past_len.max().item() + q_len
+ max_block_num = (max_seqs_len + self.block_size - 1) // self.block_size
+ k_cache = self.cache_key_states.narrow(0, 0, max_block_num * batch_size).view(
+ batch_size, max_block_num * self.block_size, self.kv_head_num, self.head_dim
+ )
+ v_cache = self.cache_value_states.narrow(0, 0, max_block_num * batch_size).view(
+ batch_size, max_block_num * self.block_size, self.kv_head_num, self.head_dim
+ )
+
+ for batch_idx in range(batch_size):
+ offset = past_len[batch_idx]
+ width = q_len
+ k_cache[batch_idx][offset : offset + width].copy_(
+ key[batch_idx].view(-1, self.kv_head_num, self.head_dim)
+ )
+ v_cache[batch_idx][offset : offset + width].copy_(
+ value[batch_idx].view(-1, self.kv_head_num, self.head_dim)
+ )
+
+ k_cache_cpu = torch.empty_like(k_cache, device="cpu", pin_memory=True)
+ v_cache_cpu = torch.empty_like(v_cache, device="cpu", pin_memory=True)
+
+ k_cache_cpu.copy_(k_cache)
+ v_cache_cpu.copy_(v_cache)
+
+ cur_block_num = (
+ q_len + past_len[0].item() + self.block_size - 1
+ ) // self.block_size
+ block_table_cpu = self.prefix_block_table[:, :cur_block_num].to("cpu")
+ past_len_cpu = past_len.contiguous().to("cpu")
+
+ self.cpu_infer.submit(
+ self.local_thread.get_and_update_kvcache_fp16(
+ k_cache_cpu,
+ v_cache_cpu,
+ layer_idx,
+ block_table_cpu,
+ max_block_num,
+ past_len_cpu,
+ q_len,
+ )
+ )
+
+ self.cpu_infer.sync()
+ k_cache.copy_(k_cache_cpu)
+ v_cache.copy_(v_cache_cpu)
+
+ return k_cache, v_cache
+
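+    # The helpers below build a block table covering the current sequence and run
+    # the corresponding CPU kernel synchronously (submit + sync).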
+ def calc_anchor(self, cache_seqlens: int):
+ cur_block_num = (cache_seqlens + self.block_size - 1) // self.block_size
+ block_table_cpu = self.prefix_block_table[:, :cur_block_num].to("cpu")
+ cache_seqlens_cpu = torch.tensor(
+ [cache_seqlens], device="cpu", dtype=torch.int32
+ )
+
+ self.cpu_infer.submit(
+ self.local_thread.calc_anchor_all_layers(
+ block_table_cpu,
+ cache_seqlens_cpu,
+ )
+ )
+ self.cpu_infer.sync()
+
+ def clear_importance(self, cache_seqlens: int):
+ print(f"clear importance: {cache_seqlens}")
+ cur_block_num = (cache_seqlens + self.block_size - 1) // self.block_size
+ block_table_cpu = self.prefix_block_table[:, :cur_block_num].to("cpu")
+ cache_seqlens_cpu = torch.tensor(
+ [cache_seqlens], device="cpu", dtype=torch.int32
+ )
+
+ self.cpu_infer.submit(
+ self.local_thread.clear_importance_all_layers(
+ block_table_cpu,
+ cache_seqlens_cpu,
+ )
+ )
+ self.cpu_infer.sync()
+
+ def clear_kvcache(self, cache_seqlens: int):
+ cur_block_num = (cache_seqlens + self.block_size - 1) // self.block_size
+ block_table_cpu = self.prefix_block_table[:, :cur_block_num].to("cpu")
+ cache_seqlens_cpu = torch.tensor(
+ [cache_seqlens], device="cpu", dtype=torch.int32
+ )
+
+ self.cpu_infer.submit(
+ self.local_thread.clear_kvcache_all_layers(
+ block_table_cpu,
+ cache_seqlens_cpu,
+ )
+ )
+ self.cpu_infer.sync()
+
+ def get_attn_sparsity(
+ self,
+ q_in: torch.Tensor,
+ layer_idx: int,
+ block_table: torch.Tensor,
+ cache_seqlens: torch.Tensor,
+ block_table_origin: torch.Tensor,
+ cache_seqlens_origin: torch.Tensor,
+ generate_token_idx: int = 0,
+ topk: int | None = None,
+ local: int | None = None,
+ output_path: str = "./attn_sparsity.json",
+ ):
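+        # Instruments the backend: runs get_attn_sparsity over the selected block
+        # table vs. the original one and appends one JSON record per head to
+        # output_path.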
+ self.attn_sparsity.zero_()
+        self.cpu_infer.submit(
+ self.local_thread.get_attn_sparsity(
+ q_in,
+ self.attn_sparsity,
+ layer_idx,
+ block_table,
+ cache_seqlens,
+ block_table_origin,
+ cache_seqlens_origin,
+ generate_token_idx,
+ topk,
+ local,
+ )
+ )
+ self.cpu_infer.sync()
+ with open(output_path, "a") as file:
+ for head_idx in range(self.q_head_num):
+ sparsity = self.attn_sparsity[0][0][head_idx].item()
+ json_obj = {
+ "token_idx": generate_token_idx,
+ "layer_idx": layer_idx,
+ "head_idx": head_idx,
+ "sparsity": sparsity,
+ }
+ json.dump(json_obj, file)
+ file.write("\n")
+
+ def apply(
+ self,
+ layer_idx: int,
+ bsz: int,
+ past_len: int,
+ query_states: torch.Tensor,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ mode: str = "prefill",
+ generate_token_idx: int = -1,
+ ):
+
+ # key_states: [bsz, q_len, kv_head_num, head_dim]
+ # value_states: [bsz, q_len, kv_head_num, head_dim]
+ # query_states: [bsz, q_len, q_head_num, head_dim]
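+        # Dispatch: "prefill" swaps the chunk into the CPU KV cache and runs GPU
+        # flash attention over the staged fp16 cache; "generate" copies the single
+        # token to pinned buffers and runs block-sparse attention on the CPU.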
+ assert query_states.dtype == torch.float16
+ assert key_states.dtype == torch.float16
+ assert value_states.dtype == torch.float16
+
+ assert key_states.size(2) == self.kv_head_num
+ assert value_states.size(2) == self.kv_head_num
+ assert query_states.size(2) == self.q_head_num
+
+ q_len = query_states.size(1)
+ batch_size = query_states.size(0)
+ self.cache_seqlens_cuda.fill_(past_len)
+ last_chunk = False
+ if self.remaining_length <= self.prefill_chunk_size and q_len != 1:
+ last_chunk = True
+ device = query_states.device
+ if layer_idx == 0:
+ if q_len == 1:
+ self.generate_token_idx += 1
+ elif last_chunk:
+ self.generate_token_idx = -1
+
+ if mode == "prefill":
+ key, value = self.swap_in_and_swap_out(
+ layer_idx,
+ self.cache_seqlens_cuda,
+ q_len,
+ key_states,
+ value_states,
+ )
+
+ if last_chunk and (self.anchor_type == "DYNAMIC" or self.preselect_block):
+ self.get_preselect_block_table_and_attn_score(
+ layer_idx,
+ bsz,
+ self.cache_seqlens_cuda,
+ q_len,
+ query_states,
+ key,
+ )
+ output = flash_attn_with_kvcache(
+ q=query_states,
+ k_cache=key,
+ v_cache=value,
+ cache_seqlens=self.cache_seqlens_cuda + q_len,
+ causal=True,
+ )
+ return output.transpose(1, 2)
+
+ elif mode == "generate":
+ assert self.generate_token_idx >= 0
+ self.q_in_cpu.copy_(query_states, non_blocking=True)
+ self.k_in_cpu.copy_(key_states, non_blocking=True)
+ self.v_in_cpu.copy_(value_states, non_blocking=True)
+ self.cache_seqlens_cpu.copy_(self.cache_seqlens_cuda, non_blocking=True)
+ # print(layer_idx)
+ if layer_idx < self.dense_layer_num:
+ self.block_table_cpu.copy_(self.prefix_block_table, non_blocking=True)
+ self.cpu_infer.submit_with_cuda_stream(
+ torch.cuda.current_stream("cuda").cuda_stream,
+ self.local_thread.attn_with_kvcache(
+ q_in=self.q_in_cpu,
+ k_in=self.k_in_cpu,
+ v_in=self.v_in_cpu,
+ output=self.output_cpu,
+ attn_lse=self.lse_cpu,
+ layer_idx=layer_idx,
+ block_table=self.block_table_cpu,
+ cache_seqlens=self.cache_seqlens_cpu,
+ ),
+ )
+ else:
+ if self.preselect_block:
+ self.cache_seqlens_cpu.copy_(
+ self.cache_seqlens_cuda - self.evict_tokens, non_blocking=True
+ )
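+                    # Decode over a compacted table: preselected prefix blocks
+                    # first, then the local-window blocks; cache_seqlens is shrunk
+                    # by evict_tokens to account for the evicted prefix.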
+ if self.preselect_block_count < self.prefill_block_num:
+ self.block_table_cpu[:, : self.preselect_block_count].copy_(
+ self.preselect_block_table[layer_idx : layer_idx + 1],
+ non_blocking=True,
+ )
+
+ self.block_table_cpu[
+ :,
+ self.preselect_block_count : self.preselect_block_count
+ + self.local_block_num,
+ ].copy_(
+ self.prefix_block_table[
+ :,
+ self.prefill_block_num : self.prefill_block_num
+ + self.local_block_num,
+ ],
+ non_blocking=True,
+ )
+ # print("submit_with_cuda_stream")
+ self.cpu_infer.submit_with_cuda_stream(
+ torch.cuda.current_stream("cuda").cuda_stream,
+ self.local_thread.attn_with_kvcache(
+ q_in=self.q_in_cpu,
+ k_in=self.k_in_cpu,
+ v_in=self.v_in_cpu,
+ output=self.output_cpu,
+ attn_lse=self.lse_cpu,
+ layer_idx=layer_idx,
+ generate_token_idx=self.generate_token_idx,
+ block_table=self.block_table_cpu,
+ cache_seqlens=self.cache_seqlens_cpu,
+ topk=(
+ self.topk
+ if self.topk <= self.preselect_block_count
+ else None
+ ),
+ local=self.local_windows_len // self.block_size,
+ ),
+ )
+ # print("submit_with_cuda_stream enqueue\n")
+ else:
+ self.block_table_cpu.copy_(
+ self.prefix_block_table, non_blocking=True
+ )
+ self.cpu_infer.submit_with_cuda_stream(
+ torch.cuda.current_stream("cuda").cuda_stream,
+ self.local_thread.attn_with_kvcache(
+ q_in=self.q_in_cpu,
+ k_in=self.k_in_cpu,
+ v_in=self.v_in_cpu,
+ output=self.output_cpu,
+ attn_lse=self.lse_cpu,
+ layer_idx=layer_idx,
+ generate_token_idx=self.generate_token_idx,
+ block_table=self.block_table_cpu,
+ cache_seqlens=self.cache_seqlens_cpu,
+ topk=self.topk,
+ local=self.local_windows_len // self.block_size,
+ ),
+ )
+ self.cpu_infer.sync_with_cuda_stream(
+ torch.cuda.current_stream("cuda").cuda_stream
+ )
+ # print("submit_with_cuda_stream finished\n")
+ self.output_cuda.copy_(self.output_cpu, non_blocking=True)
+ return self.output_cuda.transpose(1, 2)
+
+ def save(self, path: str, length: int):
+ cur_block_num = (length + self.block_size - 1) // self.block_size
+ block_table_cpu = self.prefix_block_table[0, :cur_block_num].to("cpu")
+ cache_seqlens_cpu = torch.tensor([length], device="cpu", dtype=torch.int32)
+ self.cpu_infer.submit(
+ self.local_thread.dump_kvcache(
+ block_table_cpu,
+ cache_seqlens_cpu,
+ path,
+ )
+ )
+ self.cpu_infer.sync()
+
+ def load(self, path: str, length: int):
+ self.cpu_infer.submit(
+ self.local_thread.load_kvcache(
+ path,
+ )
+ )
+ self.cpu_infer.sync()
diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py
index 864c4b7..d84b063 100644
--- a/ktransformers/operators/experts.py
+++ b/ktransformers/operators/experts.py
@@ -6,7 +6,7 @@
Date : 2024-07-25 11:25:24
Version : 0.1.0
LastEditors : Azure
-LastEditTime : 2024-08-15 02:36:29
+LastEditTime : 2024-08-27 03:50:23
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
@@ -436,7 +436,7 @@ def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.T
final_hidden_states.index_add_(0, top_x, current_hidden_states)
- return final_hidden_states.to(org_dtype, device=org_device)
+ return final_hidden_states.to(dtype=org_dtype, device=org_device)
EXPERTS_MAP = {
"KExpertsCPU": KExpertsCPU,
diff --git a/ktransformers/operators/models.py b/ktransformers/operators/models.py
index c95e1ee..d6cdc47 100644
--- a/ktransformers/operators/models.py
+++ b/ktransformers/operators/models.py
@@ -1,14 +1,14 @@
#!/usr/bin/env python
# coding=utf-8
-'''
+"""
Description :
Author : Azure-Tang
Date : 2024-07-25 11:25:24
Version : 1.0.0
LastEditors : Azure
-LastEditTime : 2024-08-14 14:53:05
+LastEditTime : 2024-08-27 07:29:04
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
-'''
+"""
import inspect
import math
@@ -19,7 +19,10 @@
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-
+from ktransformers.operators.dynamic_attention import DynamicScaledDotProductAttention
+from ktransformers.server.config.config import Config
+import os
+import yaml
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_attn_mask_utils import (
@@ -40,19 +43,35 @@
logging,
replace_return_docstrings,
)
-from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock, Qwen2MoeMLP, Qwen2MoeDecoderLayer
-from ktransformers.models.modeling_deepseek import BaseModelOutputWithPast, DeepseekV2DecoderLayer, DeepseekV2MoE
+from ktransformers.models.modeling_qwen2_moe import (
+ Qwen2MoeSparseMoeBlock,
+ Qwen2MoeMLP,
+ Qwen2MoeDecoderLayer,
+)
+from ktransformers.models.modeling_deepseek import (
+ BaseModelOutputWithPast,
+ DeepseekV2DecoderLayer,
+ DeepseekV2MoE,
+)
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
+from ktransformers.models.configuration_llama import LlamaConfig
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.utils import InferenceState
from ktransformers.util.custom_gguf import GGUFLoader
from transformers.configuration_utils import PretrainedConfig
+from ktransformers.models.modeling_llama import (
+ LlamaDecoderLayer,
+ LlamaRMSNorm,
+ LlamaRotaryEmbedding,
+)
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
- _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
+ _flash_supports_window_size = "window_size" in list(
+ inspect.signature(flash_attn_func).parameters
+ )
logger = logging.get_logger(__name__)
@@ -151,6 +170,7 @@
the complete sequence length.
"""
+
@add_start_docstrings(
"The bare Qwen2MoE Model outputting raw hidden-states without any specific head on top.",
QWEN2MOE_START_DOCSTRING,
@@ -162,18 +182,21 @@ class KQwen2MoeModel(BaseInjectedModule):
Args:
config: Qwen2MoeConfig
"""
+
def __init__(
self,
key: str,
- gguf_loader : GGUFLoader,
+ gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
device: str = "cuda",
- per_layer_prefill_intput_threshold: int = 30000, # if None, no per-layer prefill
+ per_layer_prefill_intput_threshold: int = 30000, # if None, no per-layer prefill
transfer_map: dict = None,
**kwargs,
):
- BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
+ BaseInjectedModule.__init__(
+ self, key, gguf_loader, config, orig_module, device, **kwargs
+ )
self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold
self.transfer_map = transfer_map
self.stream_device_map = dict()
@@ -192,29 +215,47 @@ def forward(
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
- per_layer_prefill_intput_threshold: int | None = None, # if None or 0, close per-layer prefill
+ per_layer_prefill_intput_threshold: (
+ int | None
+        ) = None,  # if None or 0, disable per-layer prefill
) -> Union[Tuple, MoeModelOutputWithPast]:
# print(f'Total length of input_ids: {input_ids.size(1)}, {input_ids.size()}')
- if per_layer_prefill_intput_threshold is None: per_layer_prefill_intput_threshold = self.per_layer_prefill_intput_threshold
+ if per_layer_prefill_intput_threshold is None:
+ per_layer_prefill_intput_threshold = self.per_layer_prefill_intput_threshold
per_layer_prefill_flag = False
- seq_lenth = inputs_embeds.size(1) if inputs_embeds is not None else input_ids.size(1)
- if per_layer_prefill_intput_threshold and per_layer_prefill_intput_threshold < seq_lenth:
+ seq_lenth = (
+ inputs_embeds.size(1) if inputs_embeds is not None else input_ids.size(1)
+ )
+ if (
+ per_layer_prefill_intput_threshold
+ and per_layer_prefill_intput_threshold < seq_lenth
+ ):
per_layer_prefill_flag = True
for layer in self.layers:
self.load_layer_to(layer, InferenceState.UNLOAD)
else:
pass
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
output_router_logits = (
- output_router_logits if output_router_logits is not None else self.config.output_router_logits
+ output_router_logits
+ if output_router_logits is not None
+ else self.config.output_router_logits
)
output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
@@ -243,15 +284,23 @@ def forward(
inputs_embeds = inputs_embeds.to("cuda")
if cache_position is None:
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ past_seen_tokens = (
+ past_key_values.get_seq_length() if past_key_values is not None else 0
+ )
cache_position = torch.arange(
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ past_seen_tokens,
+ past_seen_tokens + inputs_embeds.shape[1],
+ device=inputs_embeds.device,
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
- attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+ attention_mask,
+ inputs_embeds,
+ cache_position,
+ past_key_values,
+ output_attentions,
)
hidden_states = inputs_embeds
@@ -263,7 +312,7 @@ def forward(
next_decoder_cache = None
for i, decoder_layer in enumerate(self.layers):
- if self.transfer_map is not None and i in self.transfer_map:
+ if self.transfer_map is not None and i in self.transfer_map:
prev_stream = torch.cuda.current_stream()
cur_device = self.transfer_map[i]
if cur_device not in self.stream_device_map:
@@ -271,11 +320,25 @@ def forward(
torch.cuda.set_device(cur_device)
self.stream_device_map[cur_device].wait_stream(prev_stream)
torch.cuda.set_stream(self.stream_device_map[cur_device])
- hidden_states = hidden_states.to(self.transfer_map[i], non_blocking = True)
- causal_mask = causal_mask.to(self.transfer_map[i], non_blocking = True) if causal_mask is not None else None
- position_ids = position_ids.to(self.transfer_map[i], non_blocking = True) if position_ids is not None else None
- cache_position = cache_position.to(self.transfer_map[i], non_blocking = True) if cache_position is not None else None
-
+ hidden_states = hidden_states.to(
+ self.transfer_map[i], non_blocking=True
+ )
+ causal_mask = (
+ causal_mask.to(self.transfer_map[i], non_blocking=True)
+ if causal_mask is not None
+ else None
+ )
+ position_ids = (
+ position_ids.to(self.transfer_map[i], non_blocking=True)
+ if position_ids is not None
+ else None
+ )
+ cache_position = (
+ cache_position.to(self.transfer_map[i], non_blocking=True)
+ if cache_position is not None
+ else None
+ )
+
if output_hidden_states:
all_hidden_states += (hidden_states,)
@@ -323,7 +386,6 @@ def forward(
hidden_states = self.norm(hidden_states)
-
if per_layer_prefill_flag:
per_layer_prefill_flag = False
for layer in self.layers:
@@ -333,12 +395,22 @@ def forward(
next_cache = None
if use_cache:
- next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+ next_cache = (
+ next_decoder_cache.to_legacy_cache()
+ if use_legacy_cache
+ else next_decoder_cache
+ )
if not return_dict:
return tuple(
v
- for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
+ for v in [
+ hidden_states,
+ next_cache,
+ all_hidden_states,
+ all_self_attns,
+ all_router_logits,
+ ]
if v is not None
)
return MoeModelOutputWithPast(
@@ -349,11 +421,13 @@ def forward(
router_logits=all_router_logits,
)
- def load_layer_to(self, layer:Qwen2MoeDecoderLayer, target: InferenceState):
- assert isinstance(layer, Qwen2MoeDecoderLayer), "module should be nn.ModuleList of decoder layers"
+ def load_layer_to(self, layer: Qwen2MoeDecoderLayer, target: InferenceState):
+ assert isinstance(
+ layer, Qwen2MoeDecoderLayer
+ ), "module should be nn.ModuleList of decoder layers"
# TODO Support restore to original device, not only cuda
- device = "cpu" if target == InferenceState.UNLOAD else "cuda"
+ device = "cpu" if target == InferenceState.UNLOAD else "cuda"
# attn
layer.self_attn.q_proj.set_inference_mode(target)
@@ -458,18 +532,21 @@ class KDeepseekV2Model(BaseInjectedModule):
Args:
config: DeepseekV2Config
"""
+
def __init__(
self,
key: str,
- gguf_loader : GGUFLoader,
+ gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
device: str = "cuda",
- per_layer_prefill_intput_threshold: int = 30000, # if None, no per-layer prefill
+ per_layer_prefill_intput_threshold: int = 30000, # if None, no per-layer prefill
transfer_map: dict = None,
**kwargs,
):
- BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
+ BaseInjectedModule.__init__(
+ self, key, gguf_loader, config, orig_module, device, **kwargs
+ )
self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold
self.transfer_map = transfer_map
self.stream_device_map = dict()
@@ -487,15 +564,23 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
- per_layer_prefill_intput_threshold: int | None = None, # if None, no per-layer prefill
+ per_layer_prefill_intput_threshold: (
+ int | None
+ ) = None, # if None, no per-layer prefill
) -> Union[Tuple, BaseModelOutputWithPast]:
- if per_layer_prefill_intput_threshold is None: per_layer_prefill_intput_threshold = self.per_layer_prefill_intput_threshold
+ if per_layer_prefill_intput_threshold is None:
+ per_layer_prefill_intput_threshold = self.per_layer_prefill_intput_threshold
per_layer_prefill_flag = False
- seq_lenth = inputs_embeds.size(1) if inputs_embeds is not None else input_ids.size(1)
- if per_layer_prefill_intput_threshold and per_layer_prefill_intput_threshold < seq_lenth:
+ seq_lenth = (
+ inputs_embeds.size(1) if inputs_embeds is not None else input_ids.size(1)
+ )
+ if (
+ per_layer_prefill_intput_threshold
+ and per_layer_prefill_intput_threshold < seq_lenth
+ ):
per_layer_prefill_flag = True
for layer in self.layers:
- self.load_layer_to(layer, InferenceState.UNLOAD)
+ self.load_layer_to(layer, InferenceState.UNLOAD)
torch.cuda.empty_cache()
else:
pass
@@ -542,9 +627,13 @@ def forward(
past_key_values_length = past_key_values.get_usable_length(seq_length)
if cache_position is None:
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ past_seen_tokens = (
+ past_key_values.get_seq_length() if past_key_values is not None else 0
+ )
cache_position = torch.arange(
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ past_seen_tokens,
+ past_seen_tokens + inputs_embeds.shape[1],
+ device=inputs_embeds.device,
)
if position_ids is None:
@@ -556,15 +645,17 @@ def forward(
inputs_embeds = self.embed_tokens(input_ids)
input_ids = input_ids.to(org_device)
-
- causal_mask = self._update_causal_mask(
- attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
- )
+ if per_layer_prefill_flag:
+ causal_mask = None
+ else:
+ causal_mask = self._update_causal_mask(
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+ )
# embed positions
hidden_states = inputs_embeds
if per_layer_prefill_flag:
- print(f'Total length of input_ids: {hidden_states.size(1)}')
+ print(f"Total length of input_ids: {hidden_states.size(1)}")
# decoder layers
all_hidden_states = () if output_hidden_states else None
@@ -576,7 +667,7 @@ def forward(
t_f = 0
for i, decoder_layer in enumerate(self.layers):
- if self.transfer_map is not None and i in self.transfer_map:
+ if self.transfer_map is not None and i in self.transfer_map:
prev_stream = torch.cuda.current_stream()
cur_device = self.transfer_map[i]
if cur_device not in self.stream_device_map:
@@ -584,10 +675,24 @@ def forward(
torch.cuda.set_device(cur_device)
self.stream_device_map[cur_device].wait_stream(prev_stream)
torch.cuda.set_stream(self.stream_device_map[cur_device])
- hidden_states = hidden_states.to(self.transfer_map[i], non_blocking = True)
- causal_mask = causal_mask.to(self.transfer_map[i], non_blocking = True) if causal_mask is not None else None
- position_ids = position_ids.to(self.transfer_map[i], non_blocking = True) if position_ids is not None else None
- cache_position = cache_position.to(self.transfer_map[i], non_blocking = True) if cache_position is not None else None
+ hidden_states = hidden_states.to(
+ self.transfer_map[i], non_blocking=True
+ )
+ causal_mask = (
+ causal_mask.to(self.transfer_map[i], non_blocking=True)
+ if causal_mask is not None
+ else None
+ )
+ position_ids = (
+ position_ids.to(self.transfer_map[i], non_blocking=True)
+ if position_ids is not None
+ else None
+ )
+ cache_position = (
+ cache_position.to(self.transfer_map[i], non_blocking=True)
+ if cache_position is not None
+ else None
+ )
if output_hidden_states:
all_hidden_states += (hidden_states,)
@@ -622,12 +727,12 @@ def forward(
t5 = time.time()
if per_layer_prefill_flag:
# print(f"to cpu")
- self.load_layer_to(decoder_layer, InferenceState.UNLOAD)
+ self.load_layer_to(decoder_layer, InferenceState.UNLOAD)
torch.cuda.empty_cache()
t6 = time.time()
- t_gpu += t4-t3
- t_cpu += t6-t5
- t_f += t5-t4
+ t_gpu += t4 - t3
+ t_cpu += t6 - t5
+ t_f += t5 - t4
hidden_states = layer_outputs[0]
@@ -648,7 +753,9 @@ def forward(
torch.cuda.empty_cache()
t7 = time.time()
- print(f"total time: {t7-t3}, \n layer num{len(self.layers)}, gpu time: {t_gpu}, cpu time: {t_cpu}, forward time: {t_f}, restore time: {t7-t6}")
+ print(
+            f"total time: {t7-t3}, \n layer num: {len(self.layers)}, gpu time: {t_gpu}, cpu time: {t_cpu}, forward time: {t_f}, restore time: {t7-t6}"
+ )
# add hidden states from the last decoder layer
if output_hidden_states:
@@ -674,16 +781,18 @@ def forward(
attentions=all_self_attns,
)
- def load_layer_to(self, layer: DeepseekV2DecoderLayer, target: InferenceState):
- assert isinstance(layer, DeepseekV2DecoderLayer), "module should be nn.ModuleList of decoder layers"
+ def load_layer_to(self, layer: DeepseekV2DecoderLayer, target: InferenceState):
+ assert isinstance(
+ layer, DeepseekV2DecoderLayer
+ ), "module should be nn.ModuleList of decoder layers"
# TODO Support restore to original device, not only cuda
- device = "cpu" if target == InferenceState.UNLOAD else "cuda"
+ device = "cpu" if target == InferenceState.UNLOAD else "cuda"
# TODO Support DFS to auto use {to, set_inference_mode} according to the module type
# attn
- layer.self_attn.to(device) #
+ layer.self_attn.to(device) #
# mlp
if isinstance(layer.mlp, DeepseekV2MoE):
@@ -702,3 +811,526 @@ def load_layer_to(self, layer: DeepseekV2DecoderLayer, target: InferenceState):
# layer norm
layer.input_layernorm.to(device)
layer.post_attention_layernorm.to(device)
+
+
+LLAMA_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`LlamaConfig`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+LLAMA_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
+ `past_key_values`).
+
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+ information on the default strategy.
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.n_positions - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+ Two formats are allowed:
+ - a [`~cache_utils.Cache`] instance;
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
+ cache format.
+
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
+ legacy cache format will be returned.
+
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+ of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+ the complete sequence length.
+"""
+
+
+@add_start_docstrings(
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
+ LLAMA_START_DOCSTRING,
+)
+class LlamaPreTrainedModel(PreTrainedModel):
+ config_class = LlamaConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["LlamaDecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ _supports_cache_class = True
+ _supports_quantized_cache = True
+ _supports_static_cache = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+class KLlamaModel(BaseInjectedModule):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
+
+ Args:
+ config: LlamaConfig
+ """
+
+ dynamic_sdpa = None
+
+ def __init__(
+ self,
+ key: str,
+ gguf_loader: GGUFLoader,
+ config: PretrainedConfig,
+ orig_module: nn.Module,
+ device: str = "cuda",
+ per_layer_prefill_intput_threshold: int = 30000, # if None, no per-layer prefill
+ transfer_map: dict = None,
+ **kwargs,
+ ):
+
+ BaseInjectedModule.__init__(
+ self, key, gguf_loader, config, orig_module, device, **kwargs
+ )
+ self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold
+ self.transfer_map = transfer_map
+ self.stream_device_map = dict()
+        user_path: str = os.path.expanduser("~")
+        localstore_path: str = os.path.join(user_path, ".ktransformers")
+        config_path: str = os.path.join(localstore_path, Config.CONFIG_FILE_NAME)
+        with open(config_path, "r") as file:
+            config_yaml = yaml.safe_load(file.read())
+ self.long_context_config = config_yaml.get("long_context")
+ self.ext_config = config_yaml.get("ext")
+
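+        # One process-wide DynamicScaledDotProductAttention, configured from the
+        # "long_context"/"ext" sections of the user config under ~/.ktransformers
+        # and shared by all instances through the class attribute.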
+ KLlamaModel.dynamic_sdpa = DynamicScaledDotProductAttention(
+ max_seq_len=self.long_context_config["max_seq_len"],
+ block_size=self.long_context_config["block_size"],
+ config=config,
+ device=torch.device("cuda"),
+ local_windows_len=self.long_context_config["local_windows_len"],
+ topk=self.long_context_config["second_select_num"],
+ threads_num=self.ext_config["cpu_infer"],
+ anchor_type=self.long_context_config["anchor_type"],
+ kv_type=self.long_context_config["kv_type"],
+ dense_layer_num=self.long_context_config["dense_layer_num"],
+ anchor_num=self.long_context_config["anchor_num"],
+ preselect_block=self.long_context_config["preselect_block"],
+ block_selection_mode=self.long_context_config["head_select_mode"],
+ preselect_block_count=self.long_context_config["preselect_block_count"],
+ layer_step=self.long_context_config["layer_step"],
+ token_step=self.long_context_config["token_step"],
+ prefill_chunk_size=self.long_context_config["chunk_size"],
+ use_attn_sparsity=False,
+ )
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError(
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ return_legacy_cache = False
+ if (
+ use_cache and not isinstance(past_key_values, Cache) and not self.training
+ ): # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = True
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ logger.warning_once(
+ "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
+ "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
+ )
+
+ if cache_position is None:
+ past_seen_tokens = (
+ past_key_values.get_seq_length() if past_key_values is not None else 0
+ )
+ cache_position = torch.arange(
+ past_seen_tokens,
+ past_seen_tokens + inputs_embeds.shape[1],
+ device="cuda",
+ )
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = None
+ chunck_size = self.long_context_config["chunk_size"]
+ cur_idx = 0
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids.to("cpu"))
+ q_len = cache_position.size(0)
+
+ # generate
+ if q_len == 1:
+ x = inputs_embeds[:, -1:, :]
+ position_ids = position_ids[:, -1:]
+ return self.forward_chunk(
+ x,
+ causal_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ cache_position,
+ output_hidden_states,
+ return_dict,
+ )
+ elif q_len <= chunck_size:
+ inputs_embeds = inputs_embeds.to('cuda')
+ output = self.forward_chunk(
+ inputs_embeds,
+ causal_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ cache_position,
+ output_hidden_states,
+ return_dict,
+ )
+ KLlamaModel.dynamic_sdpa.calc_anchor(cache_position[-1] + 1)
+ KLlamaModel.dynamic_sdpa.clear_importance(cache_position[-1] + 1)
+ return output
+ cur_idx = 0
+ assert (
+            not output_attentions
+ ), "output_attentions is not supported when using chunked attention"
+ attn_output = None
+ # prefill
+ KLlamaModel.dynamic_sdpa.remaining_length = q_len
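+        # Prefill in chunk_size pieces; remaining_length lets the attention op
+        # detect the final chunk, where anchors and preselection are computed.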
+ while cur_idx < q_len:
+ print(f'current prefill length: {cur_idx}')
+ chunk_mask = None
+ if inputs_embeds.device.type == 'cpu':
+ tmp_inputs_embeds = inputs_embeds[:, cur_idx : min(cur_idx + chunck_size, q_len)].to("cuda")
+ else:
+ tmp_inputs_embeds = inputs_embeds[:, cur_idx : min(cur_idx + chunck_size, q_len)]
+ output_with_past = self.forward_chunk(
+ tmp_inputs_embeds,
+ chunk_mask,
+ position_ids[:, cur_idx : min(cur_idx + chunck_size, q_len)],
+ past_key_values,
+ output_attentions,
+ use_cache,
+ cache_position[cur_idx : min(cur_idx + chunck_size, q_len)],
+ )
+ cur_output = output_with_past.last_hidden_state
+ KLlamaModel.dynamic_sdpa.remaining_length -= (
+ min(cur_idx + chunck_size, q_len) - cur_idx
+ )
+ cur_idx += chunck_size
+ # if attn_output is None:
+ attn_output = cur_output
+ # else:
+ # attn_output = torch.cat((attn_output, cur_output), dim=-2)
+
+ KLlamaModel.dynamic_sdpa.calc_anchor(cache_position[-1] + 1)
+ KLlamaModel.dynamic_sdpa.clear_importance(cache_position[-1] + 1)
+ return BaseModelOutputWithPast(last_hidden_state=attn_output)
+
+ def forward_chunk(
+ self,
+ inputs_embeds,
+ causal_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ cache_position,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
+
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ return_legacy_cache = False
+ if use_cache and not isinstance(
+ past_key_values, Cache
+ ): # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = True
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ causal_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ cache_position,
+ position_embeddings,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if return_legacy_cache:
+ next_cache = next_cache.to_legacy_cache()
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
+ if v is not None
+ )
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ def _update_causal_mask(
+ self,
+ attention_mask: torch.Tensor,
+ input_tensor: torch.Tensor,
+ cache_position: torch.Tensor,
+ past_key_values: Cache,
+ output_attentions: bool,
+ ):
+ # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+ # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
+ # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+ # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
+
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and 0.0 in attention_mask:
+ return attention_mask
+ return None
+
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+ # to infer the attention mask.
+ past_seen_tokens = (
+ past_key_values.get_seq_length() if past_key_values is not None else 0
+ )
+ using_static_cache = isinstance(past_key_values, StaticCache)
+
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+ if (
+ self.config._attn_implementation == "sdpa"
+ and not using_static_cache
+ and not output_attentions
+ ):
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
+ attention_mask,
+ inputs_embeds=input_tensor,
+ past_key_values_length=past_seen_tokens,
+ is_training=self.training,
+ ):
+ return None
+
+ dtype, device = input_tensor.dtype, input_tensor.device
+ min_dtype = torch.finfo(dtype).min
+ sequence_length = input_tensor.shape[1]
+ if using_static_cache:
+ target_length = past_key_values.get_max_length()
+ else:
+ target_length = (
+ attention_mask.shape[-1]
+ if isinstance(attention_mask, torch.Tensor)
+ else past_seen_tokens + sequence_length + 1
+ )
+
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
+ if attention_mask.max() != 0:
+ raise ValueError(
+ "Custom 4D attention mask should be passed in inverted form with max==0`"
+ )
+ causal_mask = attention_mask
+ else:
+ causal_mask = torch.full(
+ (sequence_length, target_length),
+ fill_value=min_dtype,
+ dtype=dtype,
+ device=device,
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(
+ target_length, device=device
+ ) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(
+ input_tensor.shape[0], 1, -1, -1
+ )
+ if attention_mask is not None:
+ causal_mask = (
+ causal_mask.clone()
+ ) # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = (
+ causal_mask[:, :, :, :mask_length]
+ + attention_mask[:, None, None, :]
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[
+ :, :, :, :mask_length
+ ].masked_fill(padding_mask, min_dtype)
+ if (
+ self.config._attn_implementation == "sdpa"
+ and attention_mask is not None
+ and attention_mask.device.type == "cuda"
+ and not output_attentions
+ ):
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ causal_mask = AttentionMaskConverter._unmask_unattended(
+ causal_mask, min_dtype
+ )
+
+ return causal_mask
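Note on the `_update_causal_mask` addition above: the mask starts fully masked and is then re-opened per query position using `cache_position`. Below is a minimal, self-contained sketch of that construction (the branch with no 4D or padding mask), using tiny made-up sizes purely for illustration; the real method derives all of these values from the inputs and the cache.

```python
# Minimal sketch of the causal-mask construction above (illustrative sizes only).
import torch

dtype, device = torch.float32, "cpu"
min_dtype = torch.finfo(dtype).min

sequence_length = 3                                   # new query tokens in this forward pass
past_seen_tokens = 2                                  # tokens already held in the KV cache
target_length = past_seen_tokens + sequence_length    # keys the queries may attend over
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + sequence_length)

# Start from a fully masked (min_dtype) matrix, then zero out the allowed positions.
causal_mask = torch.full(
    (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
if sequence_length != 1:
    causal_mask = torch.triu(causal_mask, diagonal=1)
# Keep the mask only where the key index exceeds the query's absolute cache position,
# i.e. each query can see itself and everything before it.
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)

print(causal_mask)  # rows = queries at positions 2, 3, 4; 0 means visible, min_dtype means masked
```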
diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu-4.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu-4.yaml
index d7adfa2..07f173f 100644
--- a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu-4.yaml
+++ b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu-4.yaml
@@ -225,4 +225,4 @@
class: "default"
kwargs:
generate_device: "cuda:3"
- prefill_device: "cuda:3"
\ No newline at end of file
+ prefill_device: "cuda:3"
diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu.yaml
index a21b22d..3884077 100644
--- a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu.yaml
+++ b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu.yaml
@@ -123,4 +123,4 @@
class: "default"
kwargs:
generate_device: "cuda:1"
- prefill_device: "cuda:1"
\ No newline at end of file
+ prefill_device: "cuda:1"
diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat.yaml
index a2701e1..52db7dd 100644
--- a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat.yaml
+++ b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat.yaml
@@ -6,7 +6,7 @@
generate_device: "cuda"
prefill_device: "cuda"
- match:
- name: "^model\\.layers\\.(?!.*self_attn).*$" # regular expression
+ name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
@@ -41,6 +41,12 @@
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
+- match:
+ name: "^model$"
+ replace:
+ class: "ktransformers.operators.models.KDeepseekV2Model"
+ kwargs:
+ per_layer_prefill_intput_threshold: 2000 # 0 disables layer-wise prefill
- match:
name: "^model.embed_tokens"
replace:
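The regex change in this rule file is easier to read against concrete module names. A quick check of the old versus new pattern (the module names are illustrative of the DeepSeek-V2 layout; the real rule additionally requires the module class to be `torch.nn.Linear`, which is not checked here):

```python
import re

old = r"^model\.layers\.(?!.*self_attn).*$"             # previous rule: skip anything under self_attn
new = r"^model\.layers\.(?!.*self_attn\.kv_b_proj).*$"  # new rule: only skip self_attn.kv_b_proj

for name in [
    "model.layers.0.self_attn.q_a_proj",
    "model.layers.0.self_attn.kv_b_proj",
    "model.layers.0.mlp.gate_proj",
]:
    print(name, "old:", bool(re.match(old, name)), "new:", bool(re.match(new, name)))

# model.layers.0.self_attn.q_a_proj   old: False  new: True   -> now replaced by KTransformersLinear
# model.layers.0.self_attn.kv_b_proj  old: False  new: False  -> still left as-is
# model.layers.0.mlp.gate_proj        old: True   new: True   -> replaced before and after
```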
diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-multi-gpu.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-multi-gpu.yaml
index cfd77dc..99d01c0 100644
--- a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-multi-gpu.yaml
+++ b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-multi-gpu.yaml
@@ -123,4 +123,4 @@
class: "default"
kwargs:
generate_device: "cuda:1"
- prefill_device: "cuda:1"
\ No newline at end of file
+ prefill_device: "cuda:1"
diff --git a/ktransformers/optimize/optimize_rules/Internlm2_5-7b-Chat-1m.yaml b/ktransformers/optimize/optimize_rules/Internlm2_5-7b-Chat-1m.yaml
new file mode 100644
index 0000000..51a8142
--- /dev/null
+++ b/ktransformers/optimize/optimize_rules/Internlm2_5-7b-Chat-1m.yaml
@@ -0,0 +1,28 @@
+- match:
+ class: ktransformers.models.modeling_llama.LlamaRotaryEmbedding
+ replace:
+ class: ktransformers.operators.RoPE.RotaryEmbeddingV2
+- match:
+ name: "^model.embed_tokens"
+ replace:
+ class: "default"
+ kwargs:
+ generate_device: "cpu"
+ prefill_device: "cpu"
+- match:
+ class: ktransformers.models.modeling_llama.LlamaModel
+ replace:
+ class: ktransformers.operators.models.KLlamaModel
+ kwargs:
+ generate_device: "cuda"
+ prefill_device: "cuda"
+ per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
+
+- match:
+ name: "^model\\.layers\\..*\\.self_attn$"
+ replace:
+ class: ktransformers.operators.attention.KLlamaAttention
+ kwargs:
+ generate_device: "cuda"
+ prefill_device: "cuda"
+
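For readers new to these rule files: each entry pairs a `match` block (a `name` regex and/or a `class`) with a `replace` block. The sketch below shows how such a file could be probed offline to see which rule a given module path would hit on the name/class criteria alone; it is a deliberate simplification for illustration, not the actual injection logic in `ktransformers.optimize`, and assumes it is run from the repository root.

```python
import re
import yaml

def first_matching_rule(rules, module_name, module_class=""):
    """Return the first rule whose match criteria accept this module (name regex / class string only)."""
    for rule in rules:
        match = rule.get("match", {})
        name_pattern = match.get("name")
        class_pattern = match.get("class")
        if name_pattern and not re.match(name_pattern, module_name):
            continue
        if class_pattern and class_pattern != module_class:
            continue
        return rule
    return None

# Path is relative to the repository root.
with open("ktransformers/optimize/optimize_rules/Internlm2_5-7b-Chat-1m.yaml", encoding="utf-8") as fp:
    rules = yaml.safe_load(fp)

rule = first_matching_rule(rules, "model.layers.0.self_attn")
if rule is not None:
    print(rule["replace"]["class"])  # expected: ktransformers.operators.attention.KLlamaAttention
```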
diff --git a/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct-multi-gpu.yaml b/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct-multi-gpu.yaml
index bfa60b7..da4fb4a 100644
--- a/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct-multi-gpu.yaml
+++ b/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct-multi-gpu.yaml
@@ -109,4 +109,4 @@
class: "default"
kwargs:
generate_device: "cuda:1"
- prefill_device: "cuda:1"
\ No newline at end of file
+ prefill_device: "cuda:1"
diff --git a/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct.yaml b/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct.yaml
index 073332c..989e4b8 100644
--- a/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct.yaml
+++ b/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct.yaml
@@ -1,3 +1,10 @@
+- match:
+ name: "^model\\.layers\\..*\\."
+ replace:
+ class: "default"
+ kwargs:
+ generate_device: "cuda"
+ prefill_device: "cuda"
- match:
class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding
replace:
@@ -54,4 +61,4 @@
class: "default"
kwargs:
generate_device: "cuda"
- prefill_device: "cuda"
\ No newline at end of file
+ prefill_device: "cuda"
diff --git a/ktransformers/server/config/config.py b/ktransformers/server/config/config.py
index e17d215..d391d66 100644
--- a/ktransformers/server/config/config.py
+++ b/ktransformers/server/config/config.py
@@ -5,10 +5,11 @@
Author : unicornchan
Date : 2024-06-11 16:35:42
Version : 1.0.0
-LastEditors : chenxl
-LastEditTime : 2024-07-27 01:55:42
+LastEditors : WuHao
+LastEditTime : 2024-08-12 06:31:14
'''
import os
+import shutil
import yaml
from ktransformers.server.config.singleton import Singleton
@@ -30,10 +31,18 @@ def load() -> dict:
os.path.dirname(os.path.dirname(__file__)))
config_yaml: str = os.path.join(
base_path, "configs", Config.CONFIG_FILE_NAME)
+
+ user_path: str = os.path.expanduser('~')
+ localstore_path: str = os.path.join(user_path,'.ktransformers')
+ config_path: str = os.path.join(localstore_path,Config.CONFIG_FILE_NAME)
if not os.path.exists(config_yaml):
print(f"Can't find config file, {config_yaml}")
exit(-1)
- with open(config_yaml, 'r', encoding="utf-8") as fp:
+ if not os.path.exists(localstore_path):
+ os.mkdir(localstore_path)
+ if not os.path.exists(config_path):
+ shutil.copyfile(config_yaml,config_path)
+ with open(config_path, 'r', encoding="utf-8") as fp:
config = yaml.safe_load(fp)
return config
@@ -51,6 +60,8 @@ def __init__(self):
cfg = Config.load()
self.base_path = os.path.dirname(
os.path.dirname(os.path.dirname(__file__)))
+ self.user_path: str = os.path.expanduser('~')
+ self.localstore_path: str = os.path.join(self.user_path,'.ktransformers')
# log configs
self.log_dir = os.path.join(self.base_path, Config.to_path(cfg["log"]["dir"]))
self.log_file = cfg["log"]["file"]
@@ -83,11 +94,20 @@ def __init__(self):
self.model_name: str = self.model.get("name", "")
self.model_device: str = self.model.get("device", "cuda:0")
self.gguf_path: str = self.model.get("gguf_path", "")
+ self.model_cache_lens = self.model.get("cache_lens")
# web config
self.web: dict = cfg.get("web", {})
self.web_cross_domain: bool = self.web.get("open_cross_domain", True)
self.mount_web: bool = self.web.get("mount", False)
-
+
self.ext: dict = cfg.get("ext", {})
self.cpu_infer = self.ext.get("cpu_infer", 10)
+
+ # file config
+ self.local_store_configs: dict = cfg.get("local_store",{})
+ self.file_upload_dir: str = os.path.join(self.localstore_path,self.local_store_configs.get("file_upload_dir",""))
+ self.assistant_store_dir: str = os.path.join(self.localstore_path,self.local_store_configs.get("assistant_store_dir",""))
+
+ # long context config
+ self.long_context_config: dict = cfg.get("long_context",{})
\ No newline at end of file
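The net effect of this config change is that the packaged config.yaml is copied into a per-user store (`~/.ktransformers`) on first run and read from there afterwards, so local edits survive reinstalls. A condensed sketch of that resolution order, mirroring the code above (behavior is taken from the diff; `os.makedirs` is used here for brevity and error handling is omitted):

```python
import os
import shutil
import yaml

def resolve_user_config(packaged_config_yaml: str, filename: str = "config.yaml") -> dict:
    """Copy the packaged config into ~/.ktransformers on first use, then always read the user copy."""
    localstore_path = os.path.join(os.path.expanduser("~"), ".ktransformers")
    config_path = os.path.join(localstore_path, filename)
    os.makedirs(localstore_path, exist_ok=True)
    if not os.path.exists(config_path):
        shutil.copyfile(packaged_config_yaml, config_path)  # first run: seed the user copy
    with open(config_path, "r", encoding="utf-8") as fp:
        return yaml.safe_load(fp)
```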
diff --git a/ktransformers/util/cuda_graph_runner.py b/ktransformers/util/cuda_graph_runner.py
index c7a9c87..b4b0adc 100644
--- a/ktransformers/util/cuda_graph_runner.py
+++ b/ktransformers/util/cuda_graph_runner.py
@@ -46,7 +46,8 @@ def capture(
capture_stream.wait_stream(torch.cuda.current_stream())
torch.cuda.set_device(main_device)
torch.cuda.set_stream(capture_stream)
- past_key_values.change_seq_length(-1)
+ if past_key_values is not None:
+ past_key_values.change_seq_length(-1)
torch.cuda.synchronize(self.main_device)
#self.graph.debug_dump("cuda_graph_hooked.dot")
diff --git a/ktransformers/util/custom_gguf.py b/ktransformers/util/custom_gguf.py
index b3929be..04ce0ae 100644
--- a/ktransformers/util/custom_gguf.py
+++ b/ktransformers/util/custom_gguf.py
@@ -6,7 +6,7 @@
Date : 2024-07-26 08:48:54
Version : 1.0.0
LastEditors : kkk1nak0
-LastEditTime : 2024-08-12 07:21:55
+LastEditTime : 2024-08-14 08:20:45
Adapted from https://github.com/99991/pygguf/blob/main/gguf.py
Copyright (c) 2023-2024 The ggml authors
Copyright (c) 2024 Thomas Germer
@@ -294,7 +294,6 @@ def load_gguf_tensor(self, name: str, device:str = "cpu")->torch.Tensor:
else:
values = GGML_DEQUANTIZE[ggml_name](data)
values = torch.from_numpy(values)
-
values = values.view(shape[::-1])
if "attn_q" in name and self.gguf_file_meta['general.architecture'] in ["llama"]:
n_head = self.gguf_file_meta['llama.attention.head_count']
diff --git a/ktransformers/util/utils.py b/ktransformers/util/utils.py
index 8c91d47..f85b66e 100644
--- a/ktransformers/util/utils.py
+++ b/ktransformers/util/utils.py
@@ -84,7 +84,8 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
else:
module.load()
-def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True):
+def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True,
+ mode = 'normal'):
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
torch._dynamo.config.suppress_errors = True
@@ -110,7 +111,8 @@ def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position
cache_position=cache_position,
past_key_values=past_key_values,
return_dict=False, use_cache=True)[0]
- past_key_values.change_seq_length(1)
+ if past_key_values is not None:
+ past_key_values.change_seq_length(1)
for device in all_cuda_device:
torch.cuda.synchronize(device)
#print(logits)
@@ -125,18 +127,26 @@ def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position
torch.cuda.set_device(torch_device)
with torch.no_grad():
stream = TextStreamer(tokenizer)
- past_key_values = StaticCache(
- config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype
- )
+ if mode != 'long_context':
+ past_key_values = StaticCache(
+ config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype
+ )
+ else:
+ past_key_values = None
cache_position = torch.arange(seq_length, device=torch_device)
generated_ids = torch.zeros(
batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device
)
generated_ids[:, cache_position] = inputs.to(torch_device).to(torch.int)
- past_key_values.cur_idx=cache_position
+ if past_key_values is not None:
+ past_key_values.cur_idx = cache_position
start_time = time.time()
inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
+ if mode == "long_context":
+ inputs_embeds = model.model.embed_tokens(inputs.to("cpu"))
+ else:
+ inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
logits = model(
inputs_embeds = inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
)[0][:,-1,:].unsqueeze(0).clone().to(torch_device)
@@ -184,7 +194,7 @@ def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position
tokens.append(next_token.int())
seq_length += 1
- if next_token[0].item() == tokenizer.eos_token_id:
+ if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token) == '<|im_end|>':
print(stream.end(), end="", flush=True)
break
else:
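With the new `mode` parameter on `prefill_and_generate`, a caller opting into the long-context path would look roughly like the sketch below. The `model` and `tokenizer` objects are assumed to have been built elsewhere (e.g. via the usual optimize-and-load path) and are not shown; only the keyword added in this change is of interest.

```python
# Hypothetical call site; `model` and `tokenizer` are assumed to exist already.
from ktransformers.util.utils import prefill_and_generate

inputs = tokenizer("Summarize the attached report.", return_tensors="pt").input_ids

# mode defaults to 'normal'; 'long_context' skips the StaticCache (past_key_values is None)
# and keeps the prefill embeddings on CPU, as in the diff above.
prefill_and_generate(
    model,
    tokenizer,
    inputs,
    max_new_tokens=512,
    use_cuda_graph=False,  # kept off here for simplicity
    mode="long_context",
)
```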
diff --git a/pyproject.toml b/pyproject.toml
index 863fcb4..adeb8a9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,7 +27,8 @@ dependencies = [
"wheel",
"colorlog",
"build",
- "fire"
+ "fire",
+ "protobuf"
]
requires-python = ">=3.10"
diff --git a/requirements-local_chat.txt b/requirements-local_chat.txt
index 17cb0f1..50b1f65 100644
--- a/requirements-local_chat.txt
+++ b/requirements-local_chat.txt
@@ -3,4 +3,5 @@ transformers
numpy
torch>=2.3.0
packaging
-cpufeature
\ No newline at end of file
+cpufeature
+protobuf
\ No newline at end of file