Commit 1b11e3e

Qualcomm AI Engine Direct - documentation for KV cache update (#8134)
Summary:
- Visualize the KV cache update mechanism for better understanding
- Add an assets folder for storing diagrams
1 parent 9832db9 commit 1b11e3e

3 files changed: +57 -3 lines changed
examples/qualcomm/oss_scripts/llama/README.md

Lines changed: 57 additions & 3 deletions
@@ -5,9 +5,10 @@ This file provides you the instructions to run LLAMA model with different parame
1. LLAMA2 Stories 110M
2. LLAMA3.2 1B
3. LLAMA3.2 3B (WIP)

We offer the following modes to execute the model:

Prefill Mode: This is also known as batch prefill mode, where the model takes in a list of tokens as input and generates the next token along with the key-value (KV) cache for all tokens. This mode is efficient for encoding the user's prompt.

KV Cache Mode: In KV Cache mode, the model takes in a single previous token and generates the next predicted token along with its KV cache. It is efficient for generating subsequent tokens after the initial prompt.
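For orientation, here is a minimal, framework-agnostic Python sketch of how the two modes combine during generation. `prefill_forward`, `decode_forward`, and `pick_token` are hypothetical placeholders for illustration, not the actual runner API.

```python
# Hypothetical sketch: prefill once over the prompt, then decode one token at a
# time with the KV cache. Placeholder functions, not the actual runner API.
def pick_token(logits):
    # Greedy decoding for simplicity; a real runner may sample instead.
    return max(range(len(logits)), key=lambda i: logits[i])

def generate(prompt_tokens, max_new_tokens, prefill_forward, decode_forward):
    # Prefill Mode: encode the whole prompt in one pass and build its KV cache.
    logits, kv_cache = prefill_forward(prompt_tokens)
    next_token = pick_token(logits)
    generated = [next_token]

    # KV Cache Mode: feed one previous token at a time, reusing the cache.
    for _ in range(max_new_tokens - 1):
        logits, kv_cache = decode_forward(next_token, kv_cache)
        next_token = pick_token(logits)
        generated.append(next_token)
    return generated
```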

@@ -41,7 +42,7 @@ python -m extension.llm.tokenizer.tokenizer -t tokenizer.model -o tokenizer.bin
echo '{"dim": 768, "multiple_of": 32, "n_heads": 12, "n_layers": 12, "norm_eps": 1e-05, "vocab_size": 32000}' > params.json
```

#### LLAMA3.2
Follow the [instructions](https://www.llama.com/) to download models.
At the end of this step, users should have the following files ready: `consolidated.00.pth`, `params.json`, and `tokenizer.model`.

@@ -58,6 +59,53 @@ Default example using hybrid mode.
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_seq_len 32 --kv_seq_len 128 --prompt "what is 1+1"
```

### KV Cache update mechanism
We provide two distinct mechanisms for updating the key-value (KV) cache, both of which can be selected at runtime: Shift Pointer and Smart Mask.

#### Shift Pointer mechanism

<figure>
<img src="./assets/ShiftPointer.png" alt="Shift Pointer mechanism">
<figcaption>
The figure illustrates how the key and value caches are updated at each inference step. For the key cache, we initially allocate, for each layer, <code>num_head</code> buffers of size <code>(head_dim + 1) * (seq_len - 1)</code>. After a single inference, the new key cache is copied from the key output pointer <code>k_out</code> and appended to the key cache; the buffer start pointer <code>k_in</code> of the key cache then moves to the next token, leaving its previous position unused. This process is repeated for each subsequent inference step.
For the value cache, we first allocate a contiguous memory region of size <code>(num_head + 1) * head_dim * (seq_len - 1)</code> for each layer, with the last head reserved for I/O shifting. After the first inference, the cache is updated by simply shifting the pointers of all heads to the next token position, leaving only the previous <code>head_dim * 1</code> section in front of the first head's buffer start pointer <code>v_in</code> unused. This process is repeated for each subsequent inference step.
</figcaption>
</figure>
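To make this more concrete, below is a small, illustrative Python sketch of one Shift Pointer update step for a single layer. The buffer sizes follow the caption above, but the variable names and index arithmetic are simplified assumptions, not the actual runner implementation.

```python
# Illustrative sketch only: one Shift Pointer update step for a single layer.
# Sizes follow the description above; names and strides are assumptions.
NUM_HEAD, HEAD_DIM, SEQ_LEN = 8, 64, 128

# Key cache: per head, an over-allocated buffer so the window can shift.
k_cache = [[0.0] * ((HEAD_DIM + 1) * (SEQ_LEN - 1)) for _ in range(NUM_HEAD)]
k_in = [0] * NUM_HEAD        # start offset of each head's live key window
k_write = [0] * NUM_HEAD     # append position for the next key vector

# Value cache: one contiguous region, with the last head reserved for shifting.
v_cache = [0.0] * ((NUM_HEAD + 1) * HEAD_DIM * (SEQ_LEN - 1))
v_in = [h * HEAD_DIM * (SEQ_LEN - 1) for h in range(NUM_HEAD)]

def shift_pointer_update(k_out):
    """k_out[h] holds the HEAD_DIM new key values for head h from this step."""
    for h in range(NUM_HEAD):
        # Key side: copy the new vector from k_out and append it to the cache
        # (~ num_head * head_dim element copies per layer), then move the window
        # start pointer to the next token, leaving its previous slot unused.
        k_cache[h][k_write[h]:k_write[h] + HEAD_DIM] = k_out[h]
        k_write[h] += HEAD_DIM
        k_in[h] += 1             # exact shift stride depends on the real layout
        # Value side: the backend already wrote the new values in place, so the
        # update is only a pointer shift -- no data is copied.
        v_in[h] += HEAD_DIM
```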

#### Smart Mask mechanism
<figure>
<img src="./assets/SmartMask.png" alt="Smart Mask mechanism">
<figcaption>The Smart Mask mechanism streamlines the process of updating tokens in the cache. Unlike the Shift Pointer mechanism, which requires moving the buffer start pointers <code>k_in</code>/<code>v_in</code> of the cache, the Smart Mask mechanism updates only the new token at the specified position. This approach eliminates the need to adjust the buffer start pointers. It is beneficial for shared buffers, but it requires CPU memory copying.</figcaption>
</figure>
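For comparison, here is a small, illustrative Python sketch of a Smart Mask style update for a single layer. Shapes, names, and the mask convention are simplified assumptions, not the actual runner implementation.

```python
# Illustrative sketch only: Smart Mask update for a single layer.
import numpy as np

NUM_HEAD, HEAD_DIM, SEQ_LEN = 8, 64, 128

# Pre-allocated caches whose start pointers never move.
k_cache = np.zeros((NUM_HEAD, SEQ_LEN, HEAD_DIM), dtype=np.float32)
v_cache = np.zeros((NUM_HEAD, SEQ_LEN, HEAD_DIM), dtype=np.float32)
# Attention mask starts fully masked; slots are opened as tokens arrive.
attn_mask = np.full((SEQ_LEN,), -np.inf, dtype=np.float32)

def smart_mask_update(k_out, v_out, pos):
    """k_out/v_out: (NUM_HEAD, HEAD_DIM) outputs of this inference for token `pos`."""
    # Only the slot for the new token is written (a CPU copy of
    # num_head * head_dim elements for K and again for V) ...
    k_cache[:, pos, :] = k_out
    v_cache[:, pos, :] = v_out
    # ... and the mask for that position is flipped so the next inference
    # can attend to the newly cached token. No pointers move.
    attn_mask[pos] = 0.0
```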

#### Analysis of the KV Cache Update Mechanism for Each Layer at Each Inference
<table>
  <tr>
    <th>Mechanism</th>
    <th colspan="2" style="text-align:center;">Time Complexity</th>
    <th colspan="2" style="text-align:center;">Space Complexity</th>
  </tr>
  <tr>
    <th></th>
    <th style="text-align:center;">K</th>
    <th style="text-align:center;">V</th>
    <th style="text-align:center;">K</th>
    <th style="text-align:center;">V</th>
  </tr>
  <tr>
    <td style="text-align:center;">Shift Pointer</td>
    <td style="text-align:center;">num_head * head_dim</td>
    <td style="text-align:center;">1</td>
    <td style="text-align:center;">num_head * (head_dim + 1) * seq_len</td>
    <td style="text-align:center;">(num_head + 1) * head_dim * (seq_len - 1)</td>
  </tr>
  <tr>
    <td style="text-align:center;">Smart Mask</td>
    <td style="text-align:center;">num_head * head_dim</td>
    <td style="text-align:center;">num_head * head_dim</td>
    <td style="text-align:center;">num_head * seq_len * head_dim</td>
    <td style="text-align:center;">num_head * seq_len * head_dim</td>
  </tr>
</table>

### Additional Configs when running the script
If you would like to compile the model only, we have provided the flag `--compile_only`. Taking LLAMA3.2 as an example:
```bash
@@ -67,4 +115,10 @@ python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -m ${SOC_MO
On the other hand, if you already have a pre-compiled .pte model, you can perform inference by providing the flag `--pre_gen_pte` and specifying the folder that contains the .pte model. Taking LLAMA3.2 as an example:
```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_seq_len 32 --kv_seq_len 128 --prompt "what is 1+1" --pre_gen_pte ${FOLDER_TO_PRE_GEN_PTE}
```

You can select the KV cache update mechanism at runtime by setting the `KV_UPDATER` variable to either "shift_pointer" or "smart_mask"; by default, it is set to "smart_mask". For example, to use Shift Pointer:
```bash
KV_UPDATER="shift_pointer"
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_seq_len 32 --kv_seq_len 128 --prompt "what is 1+1" --kv_updator ${KV_UPDATER}
```