[P/D][NixlConnector] Enable FlashInfer backend #19090

Merged
merged 2 commits into from Jun 5, 2025
Changes from all commits
65 changes: 50 additions & 15 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -15,13 +15,15 @@
import zmq

from vllm import envs
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
get_tp_group)
from vllm.logger import init_logger
from vllm.platforms import _Backend
from vllm.utils import make_zmq_path, make_zmq_socket, round_down
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import RequestStatus
@@ -57,6 +59,7 @@ class NixlAgentMetadata(
num_blocks: int
tp_size: int
block_len: int
attn_backend_name: str


@dataclass
@@ -384,11 +387,25 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):

self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config

# TODO(mgoin): remove this once we have hybrid memory allocator
# Optimization for models with local attention (Llama 4)
# List of block window sizes for each layer for local attention
self.block_window_per_layer: list[Optional[int]] = []
self.use_mla = self.model_config.use_mla

backend = get_attn_backend(self.model_config.get_head_size(),
self.model_config.dtype,
self.cache_config.cache_dtype,
self.block_size,
self.model_config.is_attention_free,
use_mla=self.use_mla)
self.backend_name = backend.get_name()
attn_backend = backend_name_to_enum(self.backend_name)
self._use_flashinfer = attn_backend == _Backend.FLASHINFER_VLLM_V1
logger.debug("Detected attention backend %s", self.backend_name)

self._tp_size: dict[str, int] = {self.engine_id: self.world_size}
# With heterogeneous TP, P must wait for all assigned D TP workers to
@@ -472,37 +489,44 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
kv_elem_size = first_kv_cache.element_size()

# TODO(tms): Find a more robust way to detect and handle MLA
self.use_mla = len(first_kv_cache.shape) == 3
# NOTE (NickLucche) To move blocks efficiently with NIXL, the expected
# KV memory layout is HND, as opposed to the default NHD. Note that it
# only affects the strides. For MLA instead, we require no such thing
# and resort to the standard layout.
if self.use_mla:
use_mla = len(first_kv_cache.shape) == 3
assert use_mla == self.use_mla
Contributor Author: we could just replace self.use_mla at this point; I just wanted to make sure the two match for review.
# TODO (NickLucche) not compatible with hybrid allocator. Enforce check
# once it goes live, as a single kv layout is expected for xfers.
if use_mla:
# MLA case.
self.num_blocks = first_kv_cache.shape[0]
block_rank = 2 # [block_size, latent_dim]
block_shape = first_kv_cache.shape[-block_rank:]
block_size, kv_latent_dim = block_shape
self.slot_size_bytes = kv_elem_size * kv_latent_dim
else:
# [2 (k and v), num_blocks, block_size, kv_heads, head_dim]
self.num_blocks = first_kv_cache.shape[1]
block_rank = 3 # [block_size, kv_heads, head_dim]
# [2 (k and v), num_blocks, ...]
if self._use_flashinfer:
# FlashInfer swaps 2<->num_blocks dimensions.
self.num_blocks = first_kv_cache.shape[0]
block_rank = 4 # [2, block_size, kv_heads, head_dim]
else:
self.num_blocks = first_kv_cache.shape[1]
block_rank = 3 # [block_size, kv_heads, head_dim]
block_shape = first_kv_cache.shape[-block_rank:]
block_size, n_kv_heads, head_dim = block_shape
block_size, n_kv_heads, head_dim = block_shape[-3:]
# head size in bytes.
self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
assert block_size == self.block_size
# TODO(tms): self.block_len needs to be per-layer for sliding window,
# hybrid attn, etc
# block size in bytes
self.block_len = kv_elem_size * math.prod(block_shape)

logger.debug("Registering KV_Caches. use_mla: %s, shape %s",
self.use_mla, first_kv_cache.shape)
logger.debug("num_blocks: %s, block_shape: %s", self.num_blocks,
block_shape)
logger.debug("Per layer kv cache size: %s", first_kv_cache.shape)
logger.info(
"Registering KV_Caches: use_mla: %s, num_blocks: %s, "
"block_shape: %s, per_layer_kv_cache_shape: %s", use_mla,
self.num_blocks, block_shape, first_kv_cache.shape)
self.dst_num_blocks[self.engine_id] = self.num_blocks
self.kv_caches = kv_caches
kv_caches_base_addr = []
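For reference, a minimal self-contained sketch of the shape handling in this hunk, using assumed example dimensions (128 blocks, 16-token blocks, 8 KV heads, head_dim 128, latent_dim 576, fp16); the three tensors below are illustrative stand-ins for what the MLA, default, and FlashInfer backends allocate per layer, not the real cache tensors:

```python
import math
import torch

# Assumed example dimensions (illustrative only).
num_blocks, block_size, n_kv_heads, head_dim, latent_dim = 128, 16, 8, 128, 576
dtype = torch.float16

caches = {
    # MLA: [num_blocks, block_size, latent_dim]
    "mla": torch.empty(num_blocks, block_size, latent_dim, dtype=dtype),
    # Default split layout: [2 (K/V), num_blocks, block_size, kv_heads, head_dim]
    "default": torch.empty(2, num_blocks, block_size, n_kv_heads, head_dim,
                           dtype=dtype),
    # FlashInfer: the 2 (K/V) and num_blocks dimensions are swapped.
    "flashinfer": torch.empty(num_blocks, 2, block_size, n_kv_heads, head_dim,
                              dtype=dtype),
}

for name, cache in caches.items():
    kv_elem_size = cache.element_size()
    if name == "mla":
        nb, block_rank = cache.shape[0], 2   # [block_size, latent_dim]
        block_shape = cache.shape[-block_rank:]
        slot_size_bytes = kv_elem_size * block_shape[-1]
    elif name == "flashinfer":
        nb, block_rank = cache.shape[0], 4   # [2, block_size, kv_heads, head_dim]
        block_shape = cache.shape[-block_rank:]
        slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
    else:
        nb, block_rank = cache.shape[1], 3   # [block_size, kv_heads, head_dim]
        block_shape = cache.shape[-block_rank:]
        slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
    # Bytes per block; for FlashInfer this covers K and V jointly.
    block_len = kv_elem_size * math.prod(block_shape)
    print(f"{name}: num_blocks={nb} block_shape={tuple(block_shape)} "
          f"slot={slot_size_bytes}B block_len={block_len}B")
```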
@@ -514,9 +538,12 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
# are non-contiguous (it's not locally guaranteed that they will be)
# Disadvantage is that the encoded NixlAgentMetadata is now larger
# (roughly 8KB vs 5KB).
# Conversely for FlashInfer, K and V are transferred in the same tensor
# to better exploit the memory layout (ie num_blocks is the first dim).
for cache_or_caches in kv_caches.values():
# Normalize to always be a list of caches
cache_list = [cache_or_caches] if self.use_mla else cache_or_caches
cache_list = [cache_or_caches] if use_mla or self._use_flashinfer \
else cache_or_caches
for cache in cache_list:
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len
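As a rough illustration of the region accounting described in the comments above (one joint-KV region per layer for FlashInfer versus separate K and V regions otherwise), under the same assumed dimensions as the previous sketch:

```python
# Assumed fp16 cache: 128 blocks, block_size=16, 8 KV heads, head_dim=128.
num_blocks, kv_elem_size, block_size, n_kv_heads, head_dim = 128, 2, 16, 8, 128

# Default split-K/V layout: two tensors per layer -> two registered regions.
block_len_split = kv_elem_size * block_size * n_kv_heads * head_dim
region_len_split = num_blocks * block_len_split
total_bytes_split = 2 * region_len_split

# FlashInfer joint-KV layout: one tensor per layer -> a single, twice-as-long region.
block_len_joint = 2 * block_len_split
region_len_joint = num_blocks * block_len_joint

assert total_bytes_split == region_len_joint  # same bytes, half the regions to register
```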
@@ -581,7 +608,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
num_blocks=self.num_blocks,
tp_size=self.world_size,
block_len=self.block_len)
block_len=self.block_len,
attn_backend_name=self.backend_name)
ready_event = threading.Event()
self._nixl_handshake_listener_t = threading.Thread(
target=self._nixl_handshake_listener,
@@ -641,6 +669,10 @@ def add_remote_agent(self,
assert self._tp_size[engine_id] == nixl_agent_meta.tp_size
else:
self._tp_size[engine_id] = nixl_agent_meta.tp_size
# We may eventually relax this and allow heterogeneous attention backends,
# after asserting equality in cache layout and closeness of outputs.
assert nixl_agent_meta.attn_backend_name == self.backend_name

self._remote_agents[engine_id][
remote_tp_rank] = self.nixl_wrapper.add_remote_agent(
nixl_agent_meta.agent_metadata)
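A simplified illustration of the handshake-time check added here; the real NixlAgentMetadata carries more fields (agent blob, base addresses, num_blocks), and the dataclass and helper below are hypothetical stand-ins, not vLLM APIs:

```python
from dataclasses import dataclass

@dataclass
class AgentMetaSketch:
    """Hypothetical, trimmed-down stand-in for NixlAgentMetadata."""
    engine_id: str
    tp_size: int
    block_len: int
    attn_backend_name: str

def check_backend_match(local_backend_name: str, remote: AgentMetaSketch) -> None:
    # Mirrors the assert above: P and D must run the same attention backend,
    # since the KV block layout (and hence the transfer descriptors) depends on it.
    if remote.attn_backend_name != local_backend_name:
        raise RuntimeError(
            f"Attention backend mismatch: local={local_backend_name}, "
            f"remote={remote.attn_backend_name}")

check_backend_match(
    "FLASHINFER_VLLM_V1",
    AgentMetaSketch("prefill-0", tp_size=2, block_len=65536,
                    attn_backend_name="FLASHINFER_VLLM_V1"))
```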
@@ -659,13 +691,16 @@ def add_remote_agent(self,
else:
remote_block_size = nixl_agent_meta.block_len // (
self.slot_size_bytes * tp_ratio)
if self._use_flashinfer:
# Account for joint KV in FlashInfer.
remote_block_size //= 2

assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
"Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
)

assert self.block_size == remote_block_size, "Remote P worker with "
assert self.block_size == remote_block_size, "Remote P worker with " \
"different block size is not supported"

Contributor Author: otherwise the assert message prints a tab in the console.

assert self.num_blocks >= nixl_agent_meta.num_blocks
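A worked check of the block-size arithmetic above, with assumed numbers and homogeneous TP (tp_ratio=1), showing why the FlashInfer case divides out an extra factor of 2 for the joint K/V block:

```python
# Assumed numbers: fp16, 16-token blocks, 8 local KV heads, head_dim=128, tp_ratio=1.
kv_elem_size, block_size, n_kv_heads, head_dim, tp_ratio = 2, 16, 8, 128, 1

slot_size_bytes = kv_elem_size * n_kv_heads * head_dim  # bytes per token slot

# Split-K/V layout: block_len covers only K (or only V) for one block.
block_len = kv_elem_size * block_size * n_kv_heads * head_dim
remote_block_size = block_len // (slot_size_bytes * tp_ratio)
assert remote_block_size == block_size  # 16

# FlashInfer joint-KV layout: block_len covers K *and* V for one block, so the
# naive division yields 2 * block_size and the extra factor of 2 is divided out.
block_len_fi = 2 * block_len
remote_block_size_fi = block_len_fi // (slot_size_bytes * tp_ratio) // 2
assert remote_block_size_fi == block_size
```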
1 change: 1 addition & 0 deletions vllm/platforms/interface.py
@@ -47,6 +47,7 @@ class _Backend(enum.Enum):
ROCM_AITER_MLA_VLLM_V1 = enum.auto()
TORCH_SDPA = enum.auto()
FLASHINFER = enum.auto()
FLASHINFER_VLLM_V1 = enum.auto()
TRITON_MLA = enum.auto() # Supported by V1
TRITON_MLA_VLLM_V1 = enum.auto()
FLASHMLA_VLLM_V1 = enum.auto()
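The second file only teaches _Backend about the V1 FlashInfer name so the connector's backend detection can map it to an enum member. A minimal sketch of that mapping, assuming backend_name_to_enum is essentially a name lookup over _Backend.__members__ (an assumption about its behavior, not a verbatim copy of the vLLM implementation):

```python
import enum
from typing import Optional

class _Backend(enum.Enum):
    FLASHINFER = enum.auto()
    FLASHINFER_VLLM_V1 = enum.auto()  # member added by this PR

def backend_name_to_enum(backend_name: str) -> Optional[_Backend]:
    # Assumed behavior: look the reported name up among the enum members,
    # returning None when there is no match.
    return _Backend[backend_name] if backend_name in _Backend.__members__ else None

assert backend_name_to_enum("FLASHINFER_VLLM_V1") is _Backend.FLASHINFER_VLLM_V1
assert backend_name_to_enum("NOT_A_BACKEND") is None
```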