Commit 67645d0
Implemented sliding window attention to maintain KV cache only for the window size to enable infinite decoding. (#995)
* Revert "Transpose kv cache for better decode performance (#979)"
This reverts commit b130416.
* Update golden configs
* Implemented sliding window attention to maintain KV cache only for the window size to enable infinite decoding.
Currently, when using `MultiheadAttention` or `GroupedQueryAttention` for
sliding window attention, the KV cache is kept for the full sequence length
(`seq_len`) instead of the window length (`window_len`).
For example, a model with `window_len=1k` and `seq_len=2M` keeps a KV cache
for the full 2M tokens. It then masks out 1999k out-of-window KV tokens via the
attention bias before computing attention, resulting in a computational complexity of **O(2M²)**
instead of the desired **O(1k²)**.
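To make the masking concrete, here is a minimal sketch (not the AXLearn
implementation) of a sliding-window causal bias: keys more than `window_len`
positions behind the query are biased to `-inf`, yet their KV entries still
occupy the full-length cache.

```python
import jax.numpy as jnp

def sliding_window_causal_bias(seq_len: int, window_len: int) -> jnp.ndarray:
    """Returns a [seq_len, seq_len] bias: 0 inside the window, -inf outside.

    A query at position i may attend to keys j with i - window_len < j <= i.
    The bias spans the full seq_len, so a full-length KV cache is still
    materialized even though most of its entries are masked out.
    """
    q_pos = jnp.arange(seq_len)[:, None]
    k_pos = jnp.arange(seq_len)[None, :]
    in_window = (k_pos <= q_pos) & (q_pos - k_pos < window_len)
    return jnp.where(in_window, 0.0, -jnp.inf)
```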
This issue persists even when using flash attention. Flash attention uses the
KV cache allocated in HBM as its input. While unnecessary blocks are discarded
during computation, the KV cache still occupies HBM inefficiently for the full
2M tokens.
To address this, when `MultiheadAttention` detects a sliding window mask, it
stores the key-value (KV) cache in a ring buffer inside the input linear layer.
As a result, downstream projects using `MultiheadAttention` automatically
benefit from efficient KV cache handling in `init_states` and `extend_step`.
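A rough sketch of the ring-buffer idea (simplified single-token decoding; the
actual `init_states`/`extend_step` bookkeeping in the input linear layer
differs): the cache is allocated with length `window_len`, and each new KV
entry is written at `time_step % window_len`, overwriting entries that have
fallen out of the window.

```python
import jax
import jax.numpy as jnp

def init_kv_cache(batch: int, window_len: int, num_kv_heads: int, head_dim: int):
    """Allocates a window-sized (rather than seq_len-sized) KV ring buffer."""
    shape = (batch, window_len, num_kv_heads, head_dim)
    return dict(
        key=jnp.zeros(shape),
        value=jnp.zeros(shape),
        time_step=jnp.zeros([], dtype=jnp.int32),
    )

def extend_kv_cache(cache, k_step, v_step):
    """Writes one decode step's K/V at position time_step % window_len.

    k_step, v_step: [batch, 1, num_kv_heads, head_dim]. Entries older than
    window_len steps are overwritten, so HBM usage and per-step attention cost
    stay O(window_len) no matter how long decoding runs.
    """
    window_len = cache["key"].shape[1]
    idx = cache["time_step"] % window_len
    return dict(
        key=jax.lax.dynamic_update_slice_in_dim(cache["key"], k_step, idx, axis=1),
        value=jax.lax.dynamic_update_slice_in_dim(cache["value"], v_step, idx, axis=1),
        time_step=cache["time_step"] + 1,
    )
```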
Additionally, for use cases like local-global attention in LLMs, it is
recommended to use sliding window masks even for the global attention layers.
For example, to train an LLM with a context length of 8k, you can set the
sliding window size to 8k during training. This enables functionally
infinite decoding during inference, though accuracy beyond the trained 8k
context would not be good.
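As a concrete illustration of that recommendation (all numbers and names here
are illustrative, not an AXLearn config), a local-global stack can give its
"global" layers a window equal to the 8k training context, so every layer's KV
cache stays bounded during decoding:

```python
# Illustrative local-global layout for an 8k-context model: local layers use a
# 1k window, "global" layers use a window equal to the training context (8k).
TRAIN_CONTEXT_LEN = 8 * 1024
layer_window_lens = [1024, 1024, 1024, TRAIN_CONTEXT_LEN] * 8  # 32 layers.

# Per-layer KV cache size during decoding, assuming 8 KV heads, head_dim=128,
# and bf16 (2 bytes per element); K and V are each cached once per token.
bytes_per_token = 2 * 8 * 128 * 2
cache_bytes = [w * bytes_per_token for w in layer_window_lens]
# Each entry is fixed (at most ~33.6 MB for the 8k layers) regardless of how
# many tokens are decoded, which is what makes decoding functionally infinite.
```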
Note:
* `query_positions` in `QKVLinear.forward()` was introduced by
#914 and is now returned to the caller. This PR moves it here from the
downstream speech/streaming/sliding_window_attention.py.
File tree (106 files changed: +1468 −580 lines)
- axlearn
  - audio
  - common
    - flash_attention
  - experiments
    - testdata
      - axlearn.experiments.text.gpt.c4_trainer
      - axlearn.experiments.text.gpt.deterministic_trainer
      - axlearn.experiments.text.gpt.pajama_sigmoid_trainer
      - axlearn.experiments.text.gpt.pajama_trainer
    - text/gpt
  - vision