
[RFC]: Implement disaggregated prefilling via KV cache transfer #5557

Open
KuntaiDu opened this issue Jun 14, 2024 · 16 comments

@KuntaiDu
Collaborator

Motivation.

There are more and more use cases where we need to transfer KV caches between vLLM instances or store KV caches for future use. Some concrete use cases:

  • Disaggregated prefilling. In this case, the KV cache needs to be transferred from the prefilling instances to the decoding instances.
  • The user wants to query a fixed set of long documents (examples: software manuals, internal documents, etc.). In this case, GPU memory + CPU memory may not be enough to store the KV cache of all documents, and we may want to store the KV cache of these documents and move it to the GPU on demand.

Proposed Change.

My current thought is to introduce two new abstractions: a communicator and a KV database. The workflow will be

vllm <--> communicator <--> KV database

where

  • The communicator transfers data from src to dst, where both src and dst can be a KV block in vLLM or an entry in the database.
  • The KV database is keyed by the block hash (generated by automatic prefix caching), with the corresponding KV cache tensor as the value. (A rough interface sketch is given below.)
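
To make the abstraction concrete, here is a minimal interface sketch. All names (KVDatabase, Communicator, put/get, send_block/recv_block) are hypothetical illustrations, not existing vLLM APIs; the block hash is assumed to come from automatic prefix caching.

# Hypothetical interface sketch -- none of these classes exist in vLLM today.
from abc import ABC, abstractmethod
from typing import Optional

import torch


class KVDatabase(ABC):
    """Stores KV cache tensors keyed by the prefix-caching block hash."""

    @abstractmethod
    def put(self, block_hash: int, kv_block: torch.Tensor) -> None:
        """Persist one KV block (e.g. to CPU memory, disk, or a remote store)."""

    @abstractmethod
    def get(self, block_hash: int) -> Optional[torch.Tensor]:
        """Return the KV block if present, else None."""


class Communicator(ABC):
    """Moves KV blocks between a vLLM instance and a database (or another instance)."""

    @abstractmethod
    def send_block(self, block_hash: int, kv_block: torch.Tensor) -> None:
        """Transfer one block from this vLLM instance to the destination."""

    @abstractmethod
    def recv_block(self, block_hash: int) -> torch.Tensor:
        """Fetch one block from the source into this instance's GPU memory."""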

This will be a large framework with a wide range of challenging (but fun!) questions, including but not limited to:

  • How to leverage infrastructures like NVLink to transfer KV cache faster?
  • How to properly pipeline the KV cache transfer?
  • How to make sure the blocks are not swapped out when the communicator is working?
  • Compress KV cache during transfer or not? If so, which compression algorithm? Who compresses the cache?

Feel free to post any thoughts on the design! Is it good? Can this abstraction achieve optimal performance in your use cases?

Feedback Period.

Several weeks

CC List.

@simon-mo @youkaichao @zhuohan123 @cadedaniel @ywang96 @WoosukKwon @LiuXiaoxuanPKU

Any Other Things.

No response

@KuntaiDu KuntaiDu added the RFC label Jun 14, 2024
@KuntaiDu
Collaborator Author

KuntaiDu commented Jun 15, 2024

After some discussion, it may be better for us to focus on disaggregated prefilling first; then it will be much easier to tell how we should make the high-level architecture changes.

For disaggregated prefilling, does the following workflow sound reasonable?

For an upcoming request:

  • We send the request to a vLLM instance to do prefilling (by setting the number of output tokens to 1).
  • Then we find a vLLM instance for decoding and reserve KV cache there (by sending the request to that decoding instance but keeping the request preempted).
  • When the prefilling vLLM instance finishes computing the KV block for one LLM layer, we transfer the KV block of this layer to the decoding vLLM instance.
  • After the prefilling instance finishes (we know it has finished when we receive the HTTP response from the prefilling instance), we stop preempting the request in the decoding instance.
  • The decoding vLLM instance then uses automatic prefix caching to retrieve the prefilled KV blocks and performs inference. (A rough sketch of the request routing is shown below.)
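
To make the routing concrete, a minimal proxy-level sketch, assuming two OpenAI-compatible vLLM servers on hypothetical ports 8100 (prefill) and 8200 (decode), a placeholder model name, and the KV transfer itself happening out of band between the two instances:

# Sketch only: shows just the request routing described above; the KV-cache
# streaming between the two instances is assumed to happen out of band.
import requests

PREFILL_URL = "http://localhost:8100/v1/completions"  # hypothetical prefill server
DECODE_URL = "http://localhost:8200/v1/completions"   # hypothetical decode server
MODEL = "meta-llama/Llama-2-7b-hf"                    # placeholder model name


def disaggregated_generate(prompt: str, max_tokens: int) -> str:
    # Step 1: run prefill only, by capping generation at a single token.
    prefill_req = {"model": MODEL, "prompt": prompt, "max_tokens": 1}
    requests.post(PREFILL_URL, json=prefill_req).raise_for_status()

    # Steps 2-4 (reserving blocks on the decode instance and streaming the KV
    # cache layer by layer) happen inside vLLM and are not shown here.

    # Step 5: send the same prompt to the decode instance, which reuses the
    # transferred KV blocks via automatic prefix caching.
    decode_req = {"model": MODEL, "prompt": prompt, "max_tokens": max_tokens}
    resp = requests.post(DECODE_URL, json=decode_req)
    resp.raise_for_status()
    return resp.json()["choices"][0]["text"]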

@leiwen83
Contributor

leiwen83 commented Jun 17, 2024

Sounds very interesting!

For the second use case, I have a question:

The user wants to query a fixed set of long documents (examples: software manuals, internal documents, etc.). In this case, GPU memory + CPU memory may not be enough to store the KV cache of all documents, and we may want to store the KV cache of these documents and move it to the GPU on demand.

It seems to leverage the prefix caching mechanism, which requires the document to be at the top of the prompt, with only the query part differing at the bottom, right? That way it could handle the case of long document pieces combined with many different queries, with the KV cache of the shared top part stored in CPU memory?

And it would be better to also take into consideration GPUs without NVLink, like the 4090...

For KV compression, I think KV cache quantization to 4/2 bits would make this whole subsystem more valuable.

@richardliaw
Collaborator

Would it make sense to first get some simple design on abstractions for handling the KV cache, before designing the transport?

For example, having something like:

input_state = engine.prefill(input)
save(input_state, file)
----
input_state = read(file)
engine = engine.insert_state(input_state)
engine.generate(...)

Would be a nice starting point.

Then later maybe it can be async/lazy so that we would pipeline the state automatically

@cadedaniel
Collaborator

I gave a comment offline, pasting it here:

The concept makes sense in vLLM but I am concerned we are starting with the infra first instead of the impactful feature or performance optimization. What usually happens is because the infra is built without a narrow use-case in mind, it is very difficult to prioritize design choices and infra features. Can we flip this on its head and instead build one of the user-impacting features/performance improvements, and work backwards from that to the infra features necessary?
My thoughts are that prefill disagg has really tight performance constraints for KV transfer. It would be a big waste if the eventual implementation couldn’t use this work because the performance requirements weren’t known ahead of time.

@AnikinNN

I have found one more use case for storing the KV cache somewhere. I suppose it would be nice to have this feature when working with agents, such as chain-of-thought pipelines. These have repeatable phases of generation followed by appending tool outputs. As of now, every time generation stops due to a tool invocation and the tool's output is appended to the prompt, the LLM is called again. We have a growing leading part of the prompt that stays the same within one chain-of-thought call.

@KuntaiDu
Collaborator Author

Sounds very interesting!

For the second use case, I have a question:

The user wants to query a fixed set of long documents (examples: software manuals, internal documents, etc.). In this case, GPU memory + CPU memory may not be enough to store the KV cache of all documents, and we may want to store the KV cache of these documents and move it to the GPU on demand.

It seems to leverage the prefix caching mechanism, which requires the document to be at the top of the prompt, with only the query part differing at the bottom, right? That way it could handle the case of long document pieces combined with many different queries, with the KV cache of the shared top part stored in CPU memory?

And it would be better to also take into consideration GPUs without NVLink, like the 4090...

For KV compression, I think KV cache quantization to 4/2 bits would make this whole subsystem more valuable.

In the long-document reuse case, sure, the CPU can be used as a layer of cache. But there are two scenarios where using the CPU as a KV cache is NOT efficient:

  • When CPU memory is not enough to store the KV caches of all documents; in that case we may need to load the KV cache from an external device. For example, the KV cache of a 200-page document is roughly 30 GB for a 7B model. With 1000 documents, the KV cache can be 30 TB and definitely will not fit into CPU memory. (A back-of-envelope calculation is sketched below the list.)
  • When requests with the SAME prefix are forwarded to DIFFERENT vLLM instances, which can be a common case when load balancing across multiple vLLM instances.
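
A back-of-envelope calculation behind the 30 GB figure, assuming a Llama-2-7B-like configuration (32 layers, 32 KV heads, head dim 128, FP16) and roughly 300 tokens per page; exact numbers depend on the model config and dtype:

# Rough estimate of KV cache size for a long document.
num_layers = 32
num_kv_heads = 32
head_dim = 128
bytes_per_elem = 2  # FP16

# K and V per token, summed across all layers.
kv_bytes_per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem
# = 2 * 32 * 32 * 128 * 2 = 524,288 bytes (~0.5 MB per token)

tokens_per_doc = 200 * 300  # ~200 pages, ~300 tokens per page (assumption)
doc_kv_gb = kv_bytes_per_token * tokens_per_doc / 1e9   # ~31 GB per document
corpus_kv_tb = doc_kv_gb * 1000 / 1e3                   # ~31 TB for 1000 documents
print(f"{doc_kv_gb:.0f} GB per document, {corpus_kv_tb:.0f} TB for 1000 documents")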

For devices without NVLink, I agree with you, it would be nice if we can support them. But let's focus on making the KV transfer REALLY fast using NVLink first (which is a cool feature that TRT/TGI/LMDeploy do not have), so that we can gauge more interest from other developers.

For KV compression, there is a line of research that explores alternative opportunities besides simple quantization. Some pointers:
https://arxiv.org/abs/2306.14048 (token filtering)
https://arxiv.org/pdf/2310.07240 (leveraging similarity between consecutive tokens for compression)
So there are a lot of exciting opportunities beyond simple quantization.

@KuntaiDu
Collaborator Author

KuntaiDu commented Jun 19, 2024

Would it make sense to first get some simple design on abstractions for handling the KV cache, before designing the transport?

For example, having something like:

input_state = engine.prefill(input)
save(input_state, file)
----
input_state = read(file)
engine = engine.insert_state(input_state)
engine.generate(...)

Would be a nice starting point.

Then later maybe it can be async/lazy so that we would pipeline the state automatically

Agree!!! A nuance here is what the granularity of KV cache read/write should be: per vLLM block or per query. My current preference is per vLLM block, as the time when we need to read/save the KV cache is typically tied to the decisions of the block manager (e.g., we may need to read the KV cache when the block manager allocates a new block, or write the KV cache to disk when a block is swapped out of CPU memory by the block manager), so it is better to align the granularity with the block manager. (A rough sketch of what such per-block hooks could look like is below.)
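
A sketch of what per-block granularity could look like, expressed as hypothetical callbacks invoked by the block manager; none of these hooks exist in vLLM today, and the toy store below is only an illustration of the granularity choice:

# Hypothetical block-manager hooks -- illustration of per-block granularity only.
from typing import Dict, Optional

import torch


class KVBlockStore:
    """Toy external store for KV blocks, keyed by block hash."""

    def __init__(self) -> None:
        self._blocks: Dict[int, torch.Tensor] = {}

    def on_block_allocated(self, block_hash: int) -> Optional[torch.Tensor]:
        # Called when the block manager allocates a new block: if we already
        # hold this block externally, return it so the block can be filled
        # without recomputation.
        return self._blocks.get(block_hash)

    def on_block_evicted(self, block_hash: int, kv_block: torch.Tensor) -> None:
        # Called when the block manager swaps a block out of CPU memory:
        # persist it so it can be restored later.
        self._blocks[block_hash] = kv_block.to("cpu")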

@Jeffwan
Contributor

Jeffwan commented Jun 25, 2024

Great to see the proposal! We are running experiments that offload reusable KV contents to an external cache store. Happy to discuss more details.

@KuntaiDu KuntaiDu changed the title [RFC]: Implement KV cache transferring mechanism in vLLM [RFC]: Implement disaggregated prefilling via KV cache transfer Jun 30, 2024
@KuntaiDu
Collaborator Author

My current plan is to focus on implementing disaggregated prefilling using cross-vLLM-instance KV cache transfer, for two reasons:

  • Disaggregated prefilling is already widely adopted by industry (so there will be more people willing to contribute).
  • It helps us build a good abstraction for KV cache transfer (including reading, writing, and streaming the KV cache), which is useful for future research.

@KuntaiDu
Collaborator Author

KuntaiDu commented Jun 30, 2024

Base implementation:
4 processes: prefilling instance, decoding instance
For a new incoming request:

  • Pad the request so that the number of tokens is divisible by block_size (a padding sketch is shown after this list)
  • Send the incoming request to the prefilling instance with max_tokens=1
  • Wait for the prefilling instance to finish
  • Stream the KV cache blocks from the prefilling instance to the decoding instance
  • Send the incoming request to the decoding instance
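
A minimal sketch of the padding step; the choice of padding token and whether to pad on the left or the right are open details of this proposal, so the version below (right-padding with a caller-supplied pad token) is just one possible reading:

# Sketch of step 1: pad the token IDs so their length is a multiple of block_size.
def pad_to_block_size(token_ids: list, block_size: int, pad_token_id: int) -> list:
    remainder = len(token_ids) % block_size
    if remainder == 0:
        return token_ids
    # Right-padding shown here; the padding side and pad token are design
    # choices that both instances must agree on.
    num_pad = block_size - remainder
    return token_ids + [pad_token_id] * num_pad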

Foreseeable overheads (compared to an implementation):

  • Padding may take some time.
    • Can be reduced by adding an extra parameter to the vLLM engine (e.g., padding=True) and letting vLLM pad the input tokens by itself.
  • A slightly larger prefilling time due to the added padding tokens.
    • This overhead should be marginal in the disaggregated prefilling use case (long input lengths), so there is likely no need to reduce it.
  • KV cache streaming time.
    • Can be reduced by pipelining the KV cache transfer (layer-by-layer or token-by-token transfer).
  • The decoding instance needs to call the prefilling function again (though it will be much faster using the transferred KV cache).
    • Can be reduced, but not easily. We need to measure it to show that it is much smaller than TTFT.

My very first step: measure the overhead of calling the prefilling function again with the KV cache. (A rough measurement sketch is below.)
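
A minimal sketch of that measurement using vLLM's offline API with prefix caching enabled (model name and prompt are placeholders). The second generate() call replays the prefill with cached KV blocks, so comparing the two timings approximates the "prefill again with the KV cache" overhead:

# Sketch: compare prefill time with and without cached KV blocks.
import time

from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_prefix_caching=True)
params = SamplingParams(max_tokens=1)  # prefill only, one output token
prompt = "<a long document goes here> " * 100  # placeholder long prompt

t0 = time.perf_counter()
llm.generate([prompt], params)      # cold prefill: computes the KV cache
t1 = time.perf_counter()
llm.generate([prompt], params)      # warm prefill: reuses cached KV blocks
t2 = time.perf_counter()

print(f"cold prefill: {t1 - t0:.3f}s, prefill-again with cached KV: {t2 - t1:.3f}s")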

@TopIdiot

TopIdiot commented Jul 9, 2024

Sounds great!
And I think a scheduler is needed to decide which two instances each request should be scheduled to.

@leo6022

leo6022 commented Jul 22, 2024

How will the KV cache transfer be implemented, via NCCL or RDMA?

@Playerrrrr

Is this still going?

@dhandhalyabhavik

Is this still going?

@Wh1isper

I noticed one paper that seems to implement KV cache migration: https://arxiv.org/abs/2406.03243

Their project: https://github.com/AlibabaPAI/llumnix

Sorry, I'm just getting into vLLM and only now seeing this issue. I'm curious how they did it if vLLM doesn't have the relevant interface support. Or do we already have a way to do this without implementing the feature inside vLLM?

@5symx

5symx commented Nov 26, 2024

I noticed one paper that seems to implement KV cache migration: https://arxiv.org/abs/2406.03243

Their project: https://github.com/AlibabaPAI/llumnix

Sorry, I'm just getting into vLLM and only now seeing this issue. I'm curious how they did it if vLLM doesn't have the relevant interface support. Or do we already have a way to do this without implementing the feature inside vLLM?

Yes, I think their project implements KV cache migration. But they are doing it across continuous decoding steps, not between prefill and decode or for future reuse. This means the overlap between KV migration and decoding computation no longer exists, because the source will not generate new tokens while the KV transfer happens for disaggregated prefilling.

The current interface for moving the KV cache copies memory between GPU and CPU using cudaMemcpyDeviceToHost, and copies memory within the same device using cudaMemcpyDeviceToDevice.
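
For illustration, the same two copies expressed with PyTorch, which dispatches to cudaMemcpyDeviceToHost / cudaMemcpyDeviceToDevice under the hood; the block shape and dtype here are arbitrary placeholders:

# Illustration only: block shape and dtype are arbitrary placeholders.
import torch

kv_block = torch.randn(2, 16, 32, 128, dtype=torch.float16, device="cuda:0")

# GPU -> CPU swap (cudaMemcpyDeviceToHost under the hood).
cpu_block = kv_block.to("cpu", non_blocking=True)

# Copy within the same device (cudaMemcpyDeviceToDevice under the hood).
dst_block = torch.empty_like(kv_block)
dst_block.copy_(kv_block)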
