
Conversation

@orozery (Contributor) commented Oct 29, 2025

Core of RFC #27742.
Following this PR, connectors can turn on and adapt to the new layout.

This PR enables the GPU model runner to allocate the KV cache tensors so that the KV data for all layers is contiguous per block. This can significantly speed up KV transfers (e.g. 4x), such as when using NixlConnector or OffloadingConnector. Currently, the new layout is disabled by default and is only enabled when using a connector that explicitly prefers it. It is also currently only supported for uniform (non-HMA) models.
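To make the layout concrete, here is a minimal illustrative sketch (not the PR's actual allocation code) of what a cross-layer block layout could look like: a single physical buffer whose outermost dimension is the block index, with per-layer views exposed in the conventional per-layer shape. All sizes and dictionary keys below are assumptions for illustration only.

import torch

# Illustrative sizes only.
num_layers, num_blocks, block_size, num_heads, head_size = 4, 8, 16, 2, 64

# Cross-layer layout: each block holds the K/V data of all layers contiguously.
buffer = torch.zeros(num_blocks, num_layers, 2, block_size, num_heads, head_size)

# Each layer still gets a view in the conventional
# (2, num_blocks, block_size, num_heads, head_size) shape.
kv_caches = {
    f"layer_{i}": buffer[:, i].permute(1, 0, 2, 3, 4) for i in range(num_layers)
}

# All per-layer views share the same underlying storage, so transferring one
# block moves the data of every layer in a single contiguous copy.
assert all(
    t.untyped_storage().data_ptr() == buffer.untyped_storage().data_ptr()
    for t in kv_caches.values()
)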

@mergify bot commented Oct 29, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @orozery.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request introduces an optimization for KV cache transfers by allowing contiguous allocation of KV data across layers. The changes are well-structured, introducing a new allocation path in GPUModelRunner that is conditionally enabled for uniform models and specific connectors. The interface changes in attention backends are appropriate. The implementation includes a graceful fallback to the existing allocation mechanism, ensuring backward compatibility. I've identified a minor issue with incorrect comments in vllm/v1/attention/backends/flashinfer.py which could be misleading for future maintenance.

@orozery orozery force-pushed the contiguous-kv-layers branch from 4fa8cec to b466bf3 Compare October 29, 2025 12:41
@mergify mergify bot removed the needs-rebase label Oct 29, 2025
@orozery orozery force-pushed the contiguous-kv-layers branch from b466bf3 to 5392229 Compare October 29, 2025 12:46
This commit enables the GPU model runner to allocate the KV cache tensors
so that the KV data for all layers is contiguous per block.
This can significantly speed up KV transfers (e.g. 4x),
such as when using NixlConnector or OffloadingConnector.
Currently, the new layout is disabled by default and is only enabled when using
a connector that explicitly prefers it.
It is also currently only supported for uniform (non-HMA) models.

Signed-off-by: Or Ozeri <oro@il.ibm.com>
@orozery orozery force-pushed the contiguous-kv-layers branch from 5392229 to 6e68176 Compare October 29, 2025 14:46
@orozery orozery changed the title GPUModelRunner: Support contiguous KV data across layers GPUModelRunner: Support cross-layer KV blocks Oct 29, 2025
@ApostaC (Collaborator) commented Oct 29, 2025

Quick question:

  • How should the connector know whether the layout is the new one or the old one?
  • Should the connector have 2 different code paths to support both layouts? This will potentially increase the connector code by a lot.

@orozery (Contributor, Author) commented Oct 29, 2025

  • How should the connector know whether the layout is the new one or the old one?

My initial thought is to check, in register_kv_caches, whether two layer tensors share the same .storage().data_ptr().
Of course, we can also expose a new API.
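For illustration, a minimal sketch of such a check, assuming the cross-layer layout backs all layer tensors with a single storage; the helper name is made up.

import torch

def layers_share_storage(kv_caches: dict[str, torch.Tensor]) -> bool:
    # If the cross-layer layout is in use, all layer tensors are views into
    # one allocation, so they report a single distinct storage pointer.
    ptrs = {t.untyped_storage().data_ptr() for t in kv_caches.values()}
    return len(ptrs) == 1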

  • Should the connector have 2 different code paths to support both layouts? This will potentially increase the connector code by a lot.

Connectors don't need to support the new layout if they don't want to.
For the OffloadingConnector and NixlConnector it does not seem like a big change; for the OffloadingConnector, for example, it is just a matter of writing a new CUDA transfer function.

@ApostaC (Collaborator) left a comment


Like the high-level idea, but I think we need to further discuss the design.

Also, the correctness of this PR is not tested. Will the new code have any correctness & performance impact on the underlying attention modules?

Comment on lines +4352 to +4356
buffer = (
    torch.zeros(total_size, dtype=torch.int8, device=self.device)
    .view(kv_cache_spec.dtype)
    .view(kv_cache_shape)
).permute(*inv_order)
@ApostaC (Collaborator) commented:

By calling view + permute, we basically create a non-contiguous view of the raw KV cache buffer. This may have an unexpected impact (both performance and correctness) on other modules.

For example, calling reshape on a non-contiguous tensor may introduce an extra memory copy. Therefore, this code adds implicit pitfalls for all the other logic (either first-party or third-party) that may need to do a reshape.

Quick code snippet:

import torch

x = torch.randn((2,3,5))
y = x.permute(2,0,1)
z = y.reshape(30)

print("X -- is contiguous:", x.is_contiguous(), "\tdata_ptr:", x.data_ptr())
print("Y -- is contiguous:", y.is_contiguous(), "\tdata_ptr:", y.data_ptr())
print("Z -- is contiguous:", z.is_contiguous(), "\tdata_ptr:", z.data_ptr())

And the output:

X -- is contiguous: True        data_ptr: 1076219648
Y -- is contiguous: False       data_ptr: 1076219648
Z -- is contiguous: True        data_ptr: 986565184   ############ Y is being implicitly copied to Z.

@orozery (Contributor, Author) commented:

Correct.
This is already done (creating a non-contiguous view) when using the HND layout; note that a permute flow already exists today in _reshape_kv_cache_tensors.

Note that this layout is currently applied only when all three of the following parties agree (see the sketch after the list):

  1. The KV cache spec (no HMA)
  2. The KV connector (prefer_cross_layer_blocks)
  3. The attention backend (get_kv_cache_stride_order)
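A sketch of how this gating might read, purely for illustration; the parameter names below are placeholders, not the PR's actual code.

def use_cross_layer_blocks(
    spec_is_uniform: bool,           # 1. the KV cache spec is uniform (no HMA)
    connector_prefers_layout: bool,  # 2. the connector sets prefer_cross_layer_blocks
    backend_has_stride_order: bool,  # 3. the backend implements get_kv_cache_stride_order
) -> bool:
    # All three parties must agree; otherwise the existing per-layer
    # allocation path is used.
    return spec_is_uniform and connector_prefers_layout and backend_has_stride_order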

Comment on lines +4358 to +4365
kv_caches = {}
for i, kv_cache_tensor in enumerate(kv_cache_config.kv_cache_tensors):
    tensor = buffer[i]
    for layer_name in kv_cache_tensor.shared_by:
        kv_caches[layer_name] = tensor

return kv_caches

@ApostaC (Collaborator) commented:

Although I like the high-level idea, I strongly feel we should directly expose the raw KV cache tensor to the connector rather than the permuted KV caches dictionary.

Otherwise, it's very hard for the connector to know what the layout was before the permutation, and this may introduce a lot of debugging pain, especially when we need to write C/CUDA code that directly uses the raw pointers.

@orozery (Contributor, Author) commented:

The current API exposes a dict[str, torch.Tensor] to the connector (via register_kv_caches).
The tensors themselves could be permuted.
The physical layout of the tensors can be determined using PyTorch APIs, but I agree this is somewhat "hidden".
If we want to be more explicit, we can add an auxiliary argument to register_kv_caches that holds the physical layout spec for the tensors. What do you think?
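For concreteness, a rough sketch of what such an auxiliary argument could look like; every name below is hypothetical and none of it exists in the PR.

from dataclasses import dataclass

import torch

@dataclass
class KVCacheLayoutSpec:
    # Hypothetical descriptor; field names are illustrative only.
    cross_layer_blocks: bool       # all layers share one buffer, contiguous per block
    stride_order: tuple[int, ...]  # physical stride order of each per-layer view

class MyConnector:
    def register_kv_caches(
        self,
        kv_caches: dict[str, torch.Tensor],
        layout: KVCacheLayoutSpec | None = None,  # proposed extra argument
    ) -> None:
        # A connector that does not care about the physical layout can ignore
        # the new argument; one that does can branch on it explicitly.
        self.use_cross_layer_path = layout is not None and layout.cross_layer_blocks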

@ApostaC (Collaborator) commented:

I like the idea of adding the auxiliary variable to the register_kv_caches function!

@orozery (Contributor, Author) commented Oct 30, 2025

Like the high-level idea, but I think we need to further discuss the design.

Also, the correctness of this PR is not tested. Will the new code have any correctness & performance impact on the underlying attention modules?

Good point.
I was not sure myself that this would work, so I just ran a "how much is two plus two" prompt to verify it works (for FlashAttention and FlashInfer).
Is there a way to make a CI test out of it? Or some other way?

@ApostaC (Collaborator) commented Oct 30, 2025

Is there a way to make a CI test out of it? Or some other way?

I think lm_eval can be used for correctness benchmarks.
