This RFC discusses our plan for implementing automatic prefix caching in vLLM.
High-level idea
We observe that every block in the KV cache can be uniquely identified by
`hash(prefix tokens, tokens in this block)`
With this, we can add another indirection in vLLM's KV cache management:
Logical block table --> hash table --> physical block table.
Then, all the sharing in vLLM, including prefix sharing, can be achieved by having logical blocks point to the physical block with the same hash value. Automatic prefix caching follows from not immediately freeing blocks once their reference count drops to zero; they remain cached until evicted. Specifically, this design lets us manage KV blocks the way an operating system manages an ordinary cache.
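As a concrete, purely illustrative sketch of this identity (the function name and argument types are assumptions, not vLLM's actual internals):

```python
# Illustrative sketch only, not vLLM's actual code.
def block_hash(prefix_tokens: tuple[int, ...],
               block_tokens: tuple[int, ...]) -> int:
    # Two blocks can share KV cache only if they hold the same tokens AND
    # sit behind the same prefix, so both go into the hash.
    return hash((prefix_tokens, block_tokens))
```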
We can maintain the following information in every block (sketched in code after this list):
- Block's hash
- Reference count
- Last accessed time
- Total access count
- The prefix length of this block
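These fields might be grouped as follows; this is a hypothetical sketch, and the names are assumptions rather than vLLM's actual attributes:

```python
import time
from dataclasses import dataclass, field

# Hypothetical per-block bookkeeping matching the list above.
@dataclass
class BlockMetadata:
    block_hash: int      # hash(prefix tokens, tokens in this block)
    ref_count: int = 0   # logical blocks currently pointing at this block
    last_accessed: float = field(default_factory=time.monotonic)
    access_count: int = 0
    prefix_len: int = 0  # number of prefix tokens preceding this block
```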
Then, for example, the following cache eviction policy reproduces the behavior of RadixAttention (see the sketch after this list):
- Check the reference count first. Only evict blocks with `ref count == 0`.
- Then check the last accessed time. Prefer to evict older blocks, following LRU.
- If the last accessed times are equal, check the prefix length. Evict the block with the longer prefix first.
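Reusing the hypothetical `BlockMetadata` above, the three rules can be expressed as a single sort key over evictable blocks (a sketch, not the actual implementation):

```python
def pick_victim(blocks: list[BlockMetadata]) -> BlockMetadata | None:
    # 1. Only blocks that no request references are evictable.
    evictable = [b for b in blocks if b.ref_count == 0]
    if not evictable:
        return None
    # 2. LRU: evict the least recently accessed block first.
    # 3. Tie-break: evict the block with the longer prefix first.
    return min(evictable, key=lambda b: (b.last_accessed, -b.prefix_len))
```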
Major benefits of this design over a KV block Trie
- Sometimes, caching is not limited to prefix caching:
- With Mistral's sliding window attention, we only need to cache the last tokens in the sliding window.
- With attention sinks, we need to cache the first few tokens and the most recent tokens (both patterns are sketched after this list).
- Maintaining a hash table is simpler than maintaining a tree.
- Extensible to more advanced caching policies (the one above is just an example).
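For illustration, the token ranges each pattern would cache could be described as follows; the mode names and default values are assumptions, not part of vLLM:

```python
# Hypothetical illustration of caching patterns expressible with
# per-block hashing (half-open [start, end) token ranges).
def cacheable_ranges(seq_len: int, mode: str,
                     window: int = 4096,
                     num_sinks: int = 4) -> list[tuple[int, int]]:
    if mode == "prefix":          # plain prefix caching
        return [(0, seq_len)]
    if mode == "sliding_window":  # e.g. Mistral: only the last `window` tokens
        return [(max(0, seq_len - window), seq_len)]
    if mode == "attention_sink":  # first few tokens + the most recent tokens
        return [(0, min(num_sinks, seq_len)),
                (max(num_sinks, seq_len - window), seq_len)]
    raise ValueError(f"unknown mode: {mode}")
```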
Notes
- An arbitrary caching policy may free a block in the middle of a prefix. We would then need an attention kernel that can compute attention over sequences like “ooooxxxxxooooxxxoooxxxooooxxx”, computing attention for all the “x” tokens. Such a kernel can be implemented, but it is not required for the first version.
- We only cache complete blocks; partial blocks are kept out of the hash table (a sketch follows this list).
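A sketch of that split, with `block_size` as an illustrative parameter:

```python
# Only the leading complete blocks of a sequence are hashable and
# cacheable; the trailing partial block stays out of the hash table.
def split_blocks(token_ids: list[int], block_size: int):
    num_complete = len(token_ids) // block_size
    complete = [tuple(token_ids[i * block_size:(i + 1) * block_size])
                for i in range(num_complete)]
    partial = token_ids[num_complete * block_size:]  # not yet cacheable
    return complete, partial
```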
Deliverables
P0
- Make every complete KV block cacheable. Do not immediately free KV blocks with ref count 0.
- Implement the cache eviction policy above, behind a well-abstracted eviction-policy class.
  - The eviction-policy class should take a sequence of token IDs (or block hashes) and return a list of blocks. Some of the returned blocks may already be in the cache, and others may be blocks the policy has just evicted and reallocated (see the interface sketch after this list).
- Refactor the current block table to use hash:
  - Add the attributes above to every block object in vLLM.
  - For every request, keep its list of block objects as the block table.
- [new] Implement global tables of blocks
  - Two tables: a complete block table and a partial block table.
  - When an incomplete block becomes complete, merge it with an existing complete block of the same hash or promote it to a new complete block.
- The preemption policy is kept the same as before.
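Putting the P0 items together, one possible shape for the eviction-policy abstraction and the two global tables, again reusing the hypothetical `BlockMetadata` from earlier; every name here is an assumption for illustration, not vLLM's actual interface:

```python
from abc import ABC, abstractmethod

class EvictionPolicy(ABC):
    @abstractmethod
    def get_blocks(self, block_hashes: list[int]) -> list[BlockMetadata]:
        """Return one block per hash. A returned block may be a cache hit
        (already resident) or a block just freed by evicting a victim."""

class GlobalBlockTables:
    def __init__(self) -> None:
        # Complete blocks are shared and keyed by their content hash;
        # partial blocks are private to a request until they fill up.
        self.complete: dict[int, BlockMetadata] = {}
        self.partial: dict[int, BlockMetadata] = {}  # keyed by request id

    def promote(self, request_id: int, full_hash: int) -> BlockMetadata:
        """Called when a partial block becomes complete: merge it into an
        existing complete block with the same hash, or register it anew."""
        block = self.partial.pop(request_id)
        if full_hash in self.complete:    # merge: reuse the existing block
            existing = self.complete[full_hash]
            existing.ref_count += 1
            return existing
        block.block_hash = full_hash      # promote to a new complete entry
        self.complete[full_hash] = block
        return block
```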
P1
- Make sure the policy works for sliding windows.
- Faster hash function
P2
- Kernel optimization for "ooxxxoooxxxoooxxx" case.
- Better preemption strategy for OOM cases.
- Support block swapping.